gaoqiong / composable_kernel · Commits

Commit 478df149, authored Jan 18, 2023 by fsx950223

Merge remote-tracking branch 'origin/develop' into embeddings

Parents: 8941136f, 80e05267
Changes: 211
Showing 20 changed files with 3260 additions and 37 deletions (+3260, -37).
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp  (+2, -0)
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp  (+2, -2)
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp  (+12, -4)
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp  (+11, -2)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp  (+1111, -0)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp  (+394, -0)
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp  (+1, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  (+641, -0)
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp  (+1, -1)
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  (+507, -0)
include/ck/utility/amd_inline_asm.hpp  (+6, -0)
include/ck/utility/amd_wmma.hpp  (+100, -3)
include/ck/utility/math_v2.hpp  (+2, -0)
include/ck/utility/reduction_operator.hpp  (+5, -5)
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp  (+4, -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp  (+435, -0)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp  (+10, -2)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp  (+4, -4)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp  (+8, -8)
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp  (+4, -4)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
        for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
        {
#if DEBUG_LOG
            std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
                      << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
@@ -520,6 +521,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                      << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
                      << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
                      << std::endl;
#endif

            if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
                                            arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp
@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
         static_for<0, NumReduction, 1>{}([&](auto I) {
             using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

-            flag = flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                               OutDataType>::value;
+            flag = flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                                OutDataType>::value;
         });

         return flag;
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
@@ -40,8 +40,16 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");

     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
     static constexpr bool use_multiblock =
         (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

-    static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                      OutDataType>::value,
+    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                       OutDataType>::value,
                   "The OutDataType must support the specified OutMemoryDataOperation!");

     static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp
@@ -35,8 +35,17 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceThreadWise
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
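Both reduce device operations above switch from the old four-parameter DeviceReduce base to a fully-typed interface. As a rough illustration only (the parameter names are taken verbatim from the hunks above; the enclosing namespace and any default arguments of the real interface are assumptions, not confirmed by this diff), the derived classes now name their base like this:

// Hedged sketch of the widened base-class specialization used by
// DeviceReduceMultiBlock and DeviceReduceThreadWise after this change.
// Namespace is an assumption; the template arguments are the ones shown in the diff.
using Base = ck::tensor_operation::device::DeviceReduce<InDataType,
                                                        AccDataType,
                                                        OutDataType,
                                                        Rank,
                                                        NumReduceDim,
                                                        ReduceOperation,
                                                        InElementwiseOperation,
                                                        AccElementwiseOperation,
                                                        PropagateNan,
                                                        OutputIndex>;

The practical effect is that the data types, the reduce operation, and the NaN/index policies are now visible through the abstract base, rather than only through the concrete derived type.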
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
(new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"

namespace ck {

// GEMM:
//   input : A[M, K]
//   input : B[N, K]
//   input : D0[M, N], D1[M, N], ...
//   output : E[M, N]
//   output : F[M, N0], where N0 is number of blocks along N dimension
//   output : G[M, N0], where N0 is number of blocks along N dimension
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
//   F, G = welford(E)
// Assume:
//   D0, D1, ... and E have the same layout
//   Calculate mean & variance along N dimension for E
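// Illustration (not part of the original header): written out element-wise, the Welford
// outputs of this "first half" kernel are, for each row m and each N-block b that covers
// columns n in [b * NPerBlock, (b + 1) * NPerBlock):
//   count[b] = number of valid columns of E covered by block b (NRaw-aware for the tail block)
//   F[m, b]  = (1 / count[b]) * sum_n E[m, n]                 (partial mean)
//   G[m, b]  = sum_n (E[m, n] - F[m, b])^2                    (partial sum of squared deviations;
//              whether it is pre-divided by count[b] depends on how BlockwiseWelford is
//              configured further below)
// The "second half" kernel merges the per-N-block (F, G, count) triples before normalizing,
// so no single workgroup ever has to see a whole row of E.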
template <typename ABDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EMeanVarDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          typename AGridDesc_M_K,
          typename BGridDesc_N_K,
          typename DsGridDesc_M_N,
          typename EGridDesc_M_N,
          typename MeanVarGridDesc_M_NBlock,
          typename CountGridDesc_M_NBlock,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename PostShuffleThreadClusterSize_M_N,
          index_t PostShuffleScalarPerVector,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    // ck::Tuple<const D0DataType*, const D1DataType*, ...>
    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(ABDataType),
                         c_block_size * sizeof(CShuffleDataType));
    }
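    // Illustration (not part of the original file), worked out with hypothetical tuning values:
    // MPerBlock = NPerBlock = 128, KPerBlock = 32, AK1 = BK1 = 8, ABlockLdsExtraM = BBlockLdsExtraN = 0,
    // ABDataType = half (2 bytes), CShuffleDataType = float (4 bytes), and a 64 x 64 C-shuffle tile:
    //   A tile: AK0PerBlock * MPerBlock * AK1 = 4 * 128 * 8 = 4096 elements
    //   B tile: BK0PerBlock * NPerBlock * BK1 = 4 * 128 * 8 = 4096 elements
    //   A + B  = 8192 elements * 2 B = 16384 bytes
    //   C-shuffle tile: 1 * 64 * 1 * 64 = 4096 elements * 4 B = 16384 bytes
    // GetSharedMemoryNumberOfByte() then returns max(16384, 16384) = 16384 bytes, because the
    // A/B staging tiles and the C-shuffle tile reuse the same LDS allocation.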
    // A desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    __host__ __device__ static constexpr auto
    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // E desc for destination in blockwise copy
    template <typename EGridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            e_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    // Ds desc for source in blockwise copy
    template <typename DsGridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDescriptor_M_N& ds_grid_desc_m_n)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
            },
            Number<NumDTensor>{});
    }

    template <typename GridDescriptor_M_N>
    __host__ __device__ static constexpr auto
    MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
    {
        const auto M      = grid_desc_m_n.GetLength(I0);
        const auto NBlock = grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;

        const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
            grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_pass_through_transform(NBlock)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        return grid_desc_mblock_mperblock_nblock;
    }

    // return block_id to E matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
            e_grid_desc_m_n);
    }
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2ETileMap>
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                                                            const BGridDesc_N_K& b_grid_desc_n_k,
                                                            const DsGridDesc_M_N& ds_grid_desc_m_n,
                                                            const EGridDesc_M_N& e_grid_desc_m_n,
                                                            const Block2ETileMap& block_2_etile_map)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        // check consistency of desc
        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
        {
            return false;
        }

        bool valid = true;
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
                              N == ds_grid_desc_m_n[i].GetLength(I1));
        });

        if(!valid)
        {
            return false;
        }

        // check tile size
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
        {
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        // check block-to-E-tile
        if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        // check tensor size: cannot be larger than 2GB each
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
             e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
        {
            return false;
        }

        return true;
    }
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

    using DefaultAGridDesc_AK0_M_AK1 =
        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using DefaultBGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
        remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
    using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock =
        remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
    using CountGridDescriptor_MBlock_MPerBlock_NBlock =
        remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
    using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
        remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
    using DefaultBlock2ETileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

    using DsGridPointer = decltype(MakeDsGridPointer());
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename Block2ETileMap>
    __device__ static void
    Run(const ABDataType* __restrict__ p_a_grid,
        const ABDataType* __restrict__ p_b_grid,
        DsGridPointer p_ds_grid,
        EMeanVarDataType* __restrict__ p_e_grid,
        EMeanVarDataType* __restrict__ p_welford_mean_grid,
        EMeanVarDataType* __restrict__ p_welford_var_grid,
        int32_t* __restrict__ p_welford_count,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CDEElementwiseOperation& cde_element_op,
        const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
        const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock,
        const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& e_grid_desc_mblock_mperblock_nblock_nperblock,
        const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock& mean_var_grid_desc_mblock_mperblock_nblock,
        const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
        const Block2ETileMap& block_2_etile_map,
        index_t NRaw)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
        auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
        auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_etile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0PerBlock, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABDataType,
            ABDataType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true,
            NumGemmKPrefetchStage>(a_grid_desc_ak0_m_ak1,
                                   make_multi_index(0, m_block_data_idx_on_grid, 0),
                                   a_element_op,
                                   a_block_desc_ak0_m_ak1,
                                   make_multi_index(0, 0, 0),
                                   ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0PerBlock, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            ABDataType,
            ABDataType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true,
            NumGemmKPrefetchStage>(b_grid_desc_bk0_n_bk1,
                                   make_multi_index(0, n_block_data_idx_on_grid, 0),
                                   b_element_op,
                                   b_block_desc_bk0_n_bk1,
                                   make_multi_index(0, 0, 0),
                                   ck::tensor_operation::element_wise::PassThrough{});
        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
            BlockSize,
            ABDataType,
            AccDataType,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
            MPerXdl,
            NPerXdl,
            MXdlPerWave,
            NXdlPerWave,
            KPack,
            LoopSched>();

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                               a_block_desc_ak0_m_ak1,
                                                               a_blockwise_copy,
                                                               a_grid_buf,
                                                               a_block_buf,
                                                               a_block_slice_copy_step,
                                                               b_grid_desc_bk0_n_bk1,
                                                               b_block_desc_bk0_n_bk1,
                                                               b_blockwise_copy,
                                                               b_grid_buf,
                                                               b_block_buf,
                                                               b_block_slice_copy_step,
                                                               blockwise_gemm,
                                                               c_thread_buf,
                                                               num_k_block_main_loop);
        // shuffle C, Welford and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         I1,
                         I1,
                         M2,
                         I1,
                         M4,
                         I1>,
                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                7,
                1,
                InMemoryDataOperationEnum::Set,
                1,
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(0,
                                       0,
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       n_thread_data_on_block_idx[I2]),
                      ck::tensor_operation::element_wise::PassThrough{}};
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>,
                                  false>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_der_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                                  false>{};

            // LDS c_shuffle_block_desc_mperblock_nperblock
            constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_pass_through_transform(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
                    make_freeze_transform(I0),
                    make_pass_through_transform(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
            static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
                                  PostShuffleThreadClusterSize_M_N::At(I1) ==
                              BlockSize,
                          "wrong!");

            static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
                                      PostShuffleThreadClusterSize_M_N::At(I0) ==
                                  0 &&
                              (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
                                      PostShuffleThreadClusterSize_M_N::At(I1) ==
                                  0,
                          "wrong!");

            constexpr index_t PostShuffleThreadSliceSize_M =
                (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
                PostShuffleThreadClusterSize_M_N::At(I0);

            constexpr index_t PostShuffleThreadSliceSize_N =
                (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
                PostShuffleThreadClusterSize_M_N::At(I1);

            constexpr auto PostShuffleThreadSliceSize_M_N =
                Sequence<PostShuffleThreadSliceSize_M, PostShuffleThreadSliceSize_N>{};

            // VGPR post_shuffle_thread_desc_m_n
            constexpr auto post_shuffle_thread_desc_m_n =
                make_naive_tensor_descriptor_packed(make_tuple(
                    Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));

            auto e_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                post_shuffle_thread_desc_m_n.GetElementSpaceSize());

            // To apply D0, D1, ... and Welford.
            // threadwise copy from LDS to VGPR
            constexpr auto post_shuffle_thread_cluster_desc =
                make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});

            const auto post_shuffle_thread_cluster_idx =
                post_shuffle_thread_cluster_desc.CalculateBottomIndex(
                    make_multi_index(get_thread_local_1d_id()));

            const auto post_shuffle_thread_data_idx_begin =
                post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;

            // To apply D0, D1, ... and Welford.
            // Copy c shuffle from LDS back to VGPR
            auto post_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
                CShuffleDataType,
                AccDataType,
                decltype(c_shuffle_block_desc_mperblock_nperblock),
                decltype(post_shuffle_thread_desc_m_n),
                decltype(PostShuffleThreadSliceSize_M_N),
                Sequence<0, 1>,
                1,
                PostShuffleScalarPerVector,
                1,
                true>{c_shuffle_block_desc_mperblock_nperblock, post_shuffle_thread_data_idx_begin};
            // D0, D1, ..., Dn
            constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
                make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                               Number<PostShuffleThreadSliceSize_M>{},
                                                               I1,
                                                               Number<PostShuffleThreadSliceSize_N>{}));

            // FIXME: Decrease usage of VGPR
            // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
            auto ds_thread_buf = generate_tuple(
                [&](auto) {
                    return make_static_buffer<AddressSpaceEnum::Vgpr, CShuffleDataType>(
                        post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
                },
                Number<NumDTensor>{});

            // Copy D0, D1, ..., Dn from global to VGPR
            auto ds_thread_copy_global_to_vgpr = generate_tuple(
                [&](auto I) {
                    using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
                    return ThreadwiseTensorSliceTransfer_v2<
                        DDataType,
                        AccDataType,
                        decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
                        decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
                        Sequence<I1, PostShuffleThreadSliceSize_M, I1, PostShuffleThreadSliceSize_N>,
                        Sequence<0, 1, 2, 3>,
                        3,
                        PostShuffleScalarPerVector,
                        1,
                        true>(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
                        make_multi_index(
                            I0,
                            m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
                            I0,
                            n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
                },
                Number<NumDTensor>{});

            auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                EMeanVarDataType,
                decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
                decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tensor_operation::element_wise::PassThrough,
                Sequence<I1, PostShuffleThreadSliceSize_M, I1, PostShuffleThreadSliceSize_N>, // SliceLengths
                Sequence<0, 1, 2, 3>,                                                         // DimAccessOrder
                3,                                                                            // DstVectorDim
                PostShuffleScalarPerVector,
                InMemoryDataOperationEnum::Set,
                1,
                true>{
                e_grid_desc_mblock_mperblock_nblock_nperblock,
                make_multi_index(I0,
                                 m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
                                 I0,
                                 n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
                tensor_operation::element_wise::PassThrough{}};
            // Welford
            constexpr auto thread_welford_src_desc_m_k =
                make_naive_tensor_descriptor_packed(make_tuple(
                    Number<PostShuffleThreadSliceSize_M>{}, Number<PostShuffleThreadSliceSize_N>{}));

            constexpr auto thread_welford_dst_desc_m =
                make_naive_tensor_descriptor_packed(make_tuple(Number<PostShuffleThreadSliceSize_M>{}));

            using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
                                                        decltype(thread_welford_src_desc_m_k),
                                                        decltype(thread_welford_dst_desc_m)>;

            using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                                      BlockSize,
                                                      PostShuffleThreadClusterSize_M_N,
                                                      Sequence<0, 1>,
                                                      false>;

            constexpr int num_shuffleM =
                MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
            constexpr int num_shuffleN =
                NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);

            using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                thread_welford_dst_desc_m.GetElementSpaceSize()));

            using welford_count_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                thread_welford_dst_desc_m.GetElementSpaceSize()));

            Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
            Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
            Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
            Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;

            int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;

            const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
            // tail block
            if(block_work_idx[I1] % nblock == nblock - 1)
            {
                constexpr index_t NPerShuffleBlock =
                    CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;

                int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);

                int thread_max_len =
                    PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);

                int shuffle_step = 0;
                while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
                {
                    ++shuffle_step;
                    thread_max_len += NPerShuffleBlock;
                }

                int delta = 0;
                if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
                    delta = 0;
                else if(NPerBlockTail > thread_max_len)
                    delta = PostShuffleThreadSliceSize_N;
                else
                    delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;

                max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
            }
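            // Illustration (not part of the original file), with hypothetical parameters:
            // NRaw = 1000, NPerBlock = 128, nblock = 8 (so NPerBlockTail = 1000 - 128 * 7 = 104),
            // PostShuffleThreadSliceSize_N = 4, NPerShuffleBlock = 64, num_shuffleN = 2, and
            // 16 threads along N (post_shuffle_thread_cluster_idx[I1] = 0..15):
            //   thread 5:  thread_max_len starts at 4 * 6 = 24; the loop runs twice (24 -> 88 -> 152),
            //              delta = 0, so max_count = 2 * 4 = 8: both of its 4-wide slices
            //              (columns 20..23 and 84..87) lie inside the 104 valid columns.
            //   thread 15: thread_max_len starts at 4 * 16 = 64; the loop runs once (64 -> 128),
            //              delta = 0, so max_count = 1 * 4 = 4: only columns 60..63 are valid,
            //              and its second slice (124..127) past NRaw is excluded from the Welford count.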
            static_for<0, num_shuffleM, 1>{}([&](auto i) {
                threadwise_welfords(i).max_count_ = max_count;

                mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                    thread_welford_dst_desc_m.GetElementSpaceSize());
                var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                    thread_welford_dst_desc_m.GetElementSpaceSize());
                welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
                    thread_welford_dst_desc_m.GetElementSpaceSize());

                static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                    mean_thread_bufs(i)(j)          = type_convert<AccDataType>(0.0f);
                    var_thread_bufs(i)(j)           = type_convert<AccDataType>(0.0f);
                    welford_count_thread_bufs(i)(j) = 0;
                });
            });

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
            int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to read from LDS
                block_sync_lds();

                // each thread shuffle data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to write to LDS
                block_sync_lds();

                // Get shuffle data from LDS to VGPR
                post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
                                                         c_shuffle_block_buf,
                                                         post_shuffle_thread_desc_m_n,
                                                         make_tuple(I0, I0),
                                                         e_thread_buf);

                // Global read D0, D1, ...
                static_for<0, NumDTensor, 1>{}([&](auto Id) {
                    auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
                    d_thread_copy_global_to_vgpr.Run(
                        ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
                        ds_grid_buf[Id],
                        post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
                        make_tuple(I0, I0, I0, I0),
                        ds_thread_buf(Id));

                    if constexpr(access_id < num_access - 1)
                    {
                        // move on D0, D1, ...
                        constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
                        d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
                            ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
                    }
                });

                // cde_element_op(e, c, d0, d1, ...);
                static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
                    const auto c_ds_src_data_refs = concat_tuple_of_reference(
                        tie(e_thread_buf[i]),
                        generate_tie(
                            [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
                            Number<NumDTensor>{}));

                    auto e_dst_data_refs = tie(e_thread_buf(i));
                    unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
                });

                // Global write E
                e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
                                                 make_tuple(I0, I0, I0, I0),
                                                 e_thread_buf,
                                                 e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                 e_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    // move on E
                    constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
                    e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                        e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
                }

                // Threadwise welford
                auto& threadwise_welford = threadwise_welfords(shuffleM_index);
                auto& mean_thread_buf    = mean_thread_bufs(shuffleM_index);
                auto& var_thread_buf     = var_thread_bufs(shuffleM_index);

                threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
                    constexpr int shuffleMInc =
                        de_global_step[I1] /
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
                    shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
                }
            });
            // copy c, d, e + welford

            // Blockwise welford and write out
            static_for<0, num_shuffleM, 1>{}([&](auto i) {
                auto& mean_thread_buf  = mean_thread_bufs(i);
                auto& var_thread_buf   = var_thread_bufs(i);
                auto& count_thread_buf = welford_count_thread_bufs(i);

                static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                    block_sync_lds();
                    count_thread_buf(j) = threadwise_welfords(i).cur_count_;
                    BlockwiseWelford::Run(
                        mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
                });

                if(post_shuffle_thread_cluster_idx[I1] == 0)
                {
                    constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
                        make_tuple(I1, Number<PostShuffleThreadSliceSize_M>{}, I1));

                    constexpr int shuffleMPerBlock =
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);

                    auto mean_var_count_thread_copy_index = make_multi_index(
                        block_work_idx[I0],                                            // mblock
                        shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
                        block_work_idx[I1]);                                           // nblock

                    auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                        AccDataType,
                        EMeanVarDataType,
                        decltype(thread_welford_desc_I_m_I),
                        decltype(mean_var_grid_desc_mblock_mperblock_nblock),
                        tensor_operation::element_wise::PassThrough,
                        Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                        Sequence<0, 1, 2>,
                        1,
                        1,
                        InMemoryDataOperationEnum::Set,
                        1,
                        true>{mean_var_grid_desc_mblock_mperblock_nblock,
                              mean_var_count_thread_copy_index,
                              tensor_operation::element_wise::PassThrough{}};

                    mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                            make_tuple(I0, I0, I0),
                                                            mean_thread_buf,
                                                            mean_var_grid_desc_mblock_mperblock_nblock,
                                                            mean_grid_buf); // write mean

                    mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                            make_tuple(I0, I0, I0),
                                                            var_thread_buf,
                                                            mean_var_grid_desc_mblock_mperblock_nblock,
                                                            var_grid_buf); // write variance

                    // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
                    // to be written.
                    if(i == 0 && block_work_idx[I0] == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
                    {
                        auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                            int32_t,
                            int32_t,
                            decltype(thread_welford_desc_I_m_I),
                            decltype(count_grid_desc_mblock_mperblock_nblock),
                            tensor_operation::element_wise::PassThrough,
                            Sequence<1, PostShuffleThreadSliceSize_M, 1>,
                            Sequence<0, 1, 2>,
                            1,
                            1,
                            InMemoryDataOperationEnum::Set,
                            1,
                            false>{count_grid_desc_mblock_mperblock_nblock,
                                   mean_var_count_thread_copy_index,
                                   tensor_operation::element_wise::PassThrough{}};

                        count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
                                                             make_tuple(I0, I0, I0),
                                                             count_thread_buf,
                                                             count_grid_desc_mblock_mperblock_nblock,
                                                             welford_count_grid_buf); // write count
                    }
                }
            });
        } // shuffle C + Ds + welford + write out
    }     // run
};

} // namespace ck
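The first-half kernel above leaves one partial (mean, variance, count) triple per N-block and row; the second-half kernel below combines those triples before normalizing. As a rough host-side sketch of that merge rule only (a hypothetical free function using the textbook parallel/Chan Welford update; CK's ThreadwiseWelfordMerge and BlockwiseWelford implement the same algebra on registers and LDS and may differ in details such as zero-count handling):

// Sketch: merge two Welford partial results (mean, biased variance, count) into one.
// Variances are assumed to be M2 / count; this is not CK API, only the underlying math.
#include <cstdint>

struct WelfordPartial
{
    float mean;
    float var; // biased variance, i.e. M2 / count
    int32_t count;
};

inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    const int32_t n = a.count + b.count;
    if(n == 0)
        return {0.0f, 0.0f, 0};

    const float delta = b.mean - a.mean;
    const float mean  = a.mean + delta * static_cast<float>(b.count) / static_cast<float>(n);

    // Recover M2 = var * count on each side, add the cross term, then renormalize.
    const float m2 = a.var * a.count + b.var * b.count +
                     delta * delta * static_cast<float>(a.count) * static_cast<float>(b.count) /
                         static_cast<float>(n);

    return {mean, m2 / static_cast<float>(n), n};
}

Folding welford_merge over the per-N-block triples of one row reproduces that row's overall mean and variance, which is what step 1 of the kernel below does across its thread cluster.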
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
(new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"

namespace ck {
template <typename EMeanVarDataType,
          typename HDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename EHGridDesc_M_N,
          typename MeanVarGridDesc_M_NBlock,
          typename CountGridDesc_M_NBlock,
          typename GammaBetaGridDesc_N,
          typename HElementwiseOperation,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t NThreadClusterSize,
          index_t MThreadSliceSize,
          index_t NThreadSliceSize,
          index_t ESrcVectorSize,
          index_t HDstVectorSize,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
    static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
                      NThreadSliceSize % GammaSrcVectorSize == 0 &&
                      NThreadSliceSize % BetaSrcVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NThreadSliceSize % HDstVectorSize == 0,
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
    using ThreadClusterLengths_M_N   = Sequence<MThreadClusterSize, NThreadClusterSize>;
    using ThreadBufferDimAccessOrder = Sequence<0, 1>;
    using ThreadClusterArrangeOrder  = Sequence<0, 1>;

    static constexpr auto thread_cluster_desc_m_n =
        make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

    using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
    static constexpr auto thread_buffer_desc_n =
        make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));

    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
    using ThreadWelfordDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_N,
                                              ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
    __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
                               const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
                               const int32_t* __restrict__ p_in_welford_count_grid,
                               const GammaDataType* __restrict__ p_gamma_grid,
                               const BetaDataType* __restrict__ p_beta_grid,
                               HDataType* __restrict__ p_h_grid,
                               const EHGridDesc_M_N& e_grid_desc_m_n,
                               const EHGridDesc_M_N& h_grid_desc_m_n,
                               const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
                               const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
                               const GammaBetaGridDesc_N& gamma_grid_desc_n,
                               const GammaBetaGridDesc_N& beta_grid_desc_n,
                               index_t numMeanVarCountBlockTileIteration_N,
                               index_t NBlockClusterLength,
                               ComputeDataType epsilon,
                               HElementwiseOperation h_element_op)
    {
        // Thread/Block id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const auto block_work_idx     = make_tuple(block_global_id / NBlockClusterLength,
                                                   block_global_id % NBlockClusterLength);

        const auto thread_cluster_idx =
            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_n_cluster_id = thread_cluster_idx[I1];

        // Global Memory
        const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());

        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());

        auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            gamma_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            beta_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            h_thread_buf;
        // IO
        auto threadwise_mean_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_var_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_count_load_m_nblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             CountGridDesc_M_NBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             ThreadBufferDimAccessOrder,
                                             1,
                                             1,
                                             1,
                                             true>(
                count_grid_desc_m_nblock,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));

        auto threadwise_e_load_m_n =
            ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
                                             ComputeDataType,
                                             decltype(e_grid_desc_m_n),
                                             decltype(thread_buffer_desc_m_n),
                                             ThreadBufferLengths_M_N,
                                             ThreadBufferDimAccessOrder,
                                             1, // SrcVectorDim
                                             ESrcVectorSize,
                                             1,
                                             true>(
                e_grid_desc_m_n,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_work_idx[I1] * N_BlockTileSize +
                                     thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_gamma_load_n =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             decltype(gamma_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_beta_load_n =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             ComputeDataType,
                                             decltype(beta_grid_desc_n),
                                             decltype(thread_buffer_desc_n),
                                             ThreadBufferLengths_N,
                                             Sequence<0>, // DimAccessOrder,
                                             0,           // SrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_n,
                make_multi_index(block_work_idx[I1] * N_BlockTileSize +
                                 thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_h_store_m_n =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               HDataType,
                                               decltype(thread_buffer_desc_m_n),
                                               decltype(h_grid_desc_m_n),
                                               HElementwiseOperation,
                                               ThreadBufferLengths_M_N,
                                               ThreadBufferDimAccessOrder,
                                               1, // DstVectorDim
                                               HDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                h_grid_desc_m_n,
                make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_work_idx[I1] * N_BlockTileSize +
                                     thread_n_cluster_id * NThreadSliceSize),
                h_element_op);
        // step1: Merge mean and variance
        constexpr auto mean_var_count_thread_copy_step_I0_n =
            make_multi_index(I0, NThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
            welford_var_thread_buf(I)   = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
        {
            threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                              welford_mean_global_val_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_mean_thread_buf);

            threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
                                             welford_var_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             in_welford_var_thread_buf);

            threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                   in_welford_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   welford_mean_thread_buf,
                                   welford_var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                             mean_var_count_thread_copy_step_I0_n);
            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
                                                            mean_var_count_thread_copy_step_I0_n);
            threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
                                                              mean_var_count_thread_copy_step_I0_n);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });
// step2: normalization
// h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
threadwise_e_load_m_n
.
Run
(
e_grid_desc_m_n
,
e_global_val_buf
,
thread_buffer_desc_m_n
,
make_tuple
(
I0
,
I0
),
e_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
m
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
welford_var_thread_buf
(
m
)
+
epsilon
);
static_for
<
0
,
NThreadSliceSize
,
1
>
{}([
&
](
auto
n
)
{
constexpr
auto
m_n
=
thread_buffer_desc_m_n
.
CalculateOffset
(
make_tuple
(
m
,
n
));
h_thread_buf
(
Number
<
m_n
>
{})
=
(
e_thread_buf
(
Number
<
m_n
>
{})
-
welford_mean_thread_buf
(
m
))
*
divisor
;
});
});
threadwise_gamma_load_n
.
Run
(
gamma_grid_desc_n
,
gamma_global_val_buf
,
thread_buffer_desc_n
,
make_tuple
(
I0
),
gamma_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
m
)
{
static_for
<
0
,
NThreadSliceSize
,
1
>
{}([
&
](
auto
n
)
{
constexpr
auto
m_n
=
thread_buffer_desc_m_n
.
CalculateOffset
(
make_tuple
(
m
,
n
));
h_thread_buf
(
Number
<
m_n
>
{})
=
h_thread_buf
(
Number
<
m_n
>
{})
*
gamma_thread_buf
(
n
);
});
});
threadwise_beta_load_n
.
Run
(
beta_grid_desc_n
,
beta_global_val_buf
,
thread_buffer_desc_n
,
make_tuple
(
I0
),
beta_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
m
)
{
static_for
<
0
,
NThreadSliceSize
,
1
>
{}([
&
](
auto
n
)
{
constexpr
auto
m_n
=
thread_buffer_desc_m_n
.
CalculateOffset
(
make_tuple
(
m
,
n
));
h_thread_buf
(
Number
<
m_n
>
{})
=
h_thread_buf
(
Number
<
m_n
>
{})
+
beta_thread_buf
(
n
);
});
});
threadwise_h_store_m_n
.
Run
(
thread_buffer_desc_m_n
,
make_tuple
(
I0
,
I0
),
h_thread_buf
,
h_grid_desc_m_n
,
h_global_val_buf
);
}
// run
};
}
// namespace ck
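For readers following step1/step2 above, the following is a minimal host-side C++ sketch (not part of this commit) of what the kernel computes per row: the per-block (mean, variance, count) partials written by the first-half kernel are merged with the standard parallel-variance rule, which the ThreadwiseWelford/BlockwiseWelford helpers are assumed to implement, and the merged statistics are then applied as h = (e - mean) / sqrt(var + eps) * gamma + beta. All names in the sketch are illustrative.

// Hypothetical illustration only; not taken from the library.
#include <cmath>
#include <cstddef>

struct WelfordPartial { float mean; float var; int count; }; // var = M2 / count style running variance

// Merge two partial (mean, variance, count) results into one (Chan et al. parallel update).
WelfordPartial welford_merge(WelfordPartial a, WelfordPartial b)
{
    int n = a.count + b.count;
    if(n == 0) return {0.f, 0.f, 0};
    float delta = b.mean - a.mean;
    float mean  = a.mean + delta * b.count / n;
    // combine M2 terms: M2 = M2_a + M2_b + delta^2 * n_a * n_b / n, with M2 = var * count
    float m2 = a.var * a.count + b.var * b.count + delta * delta * a.count * b.count / n;
    return {mean, m2 / n, n};
}

// Normalize one row with the merged statistics: h = (e - mean) / sqrt(var + eps) * gamma + beta.
void layernorm_row(const float* e, const float* gamma, const float* beta, float* h,
                   std::size_t n, WelfordPartial s, float eps)
{
    float divisor = 1.f / std::sqrt(s.var + eps);
    for(std::size_t j = 0; j < n; ++j)
        h[j] = (e[j] - s.mean) * divisor * gamma[j] + beta[j];
}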
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
View file @ 478df149
...
...
@@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);

            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
...
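This hunk (and the matching one in gridwise_normalization_welford_variance.hpp further down) replaces the float-only __builtin_amdgcn_sqrtf with ck::math::sqrt, so the 1 / sqrt(var + epsilon) divisor follows the kernel's compute type. As a rough sketch of the idea only (the real ck::math::sqrt overload set is not shown in this diff, so the names below are assumptions):

// Sketch only: a type-dispatched device sqrt in the spirit of ck::math::sqrt.
__device__ inline float  generic_sqrt(float x)  { return __builtin_amdgcn_sqrtf(x); } // fast f32 path
__device__ inline double generic_sqrt(double x) { return __builtin_sqrt(x); }         // plain f64 sqrt

template <typename T>
__device__ inline T inv_stddev(T var, T epsilon)
{
    // divisor used by the layernorm kernels: 1 / sqrt(var + epsilon)
    return static_cast<T>(1) / generic_sqrt(var + epsilon);
}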
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
0 → 100644
View file @ 478df149
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {

template <typename GridwiseGemm, typename FloatA, typename FloatB, typename FloatC,
          typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation, typename BElementwiseOperation, typename CElementwiseOperation,
          typename Block2CTileMap, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_wmma(const FloatA* __restrict__ p_a_grid,
                         const FloatB* __restrict__ p_b_grid,
                         FloatC* __restrict__ p_c_grid,
                         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
                         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
                         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
                         // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
                         //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
                         const AElementwiseOperation a_element_op,
                         const BElementwiseOperation b_element_op,
                         const CElementwiseOperation c_element_op,
                         const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared,
                                                  a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op, b_element_op, c_element_op,
                                                  block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
}

template <index_t BlockSize,
          typename FloatA, typename FloatB, typename FloatAcc, typename FloatCShuffle, typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N,
          typename AElementwiseOperation, typename BElementwiseOperation, typename CElementwiseOperation,
          index_t MPerBlock, index_t NPerBlock, index_t K0PerBlock,
          index_t MPerWmma, index_t NPerWmma, index_t K1Value,
          index_t MRepeat, index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe =
        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0perblock_mperblock_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0perblock_nperblock_k1;
    }

    __host__ __device__ static constexpr auto
    // *Caution Here repeat is shuffle repeat
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB));
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                                            const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                                                            const CGridDesc_M_N& c_grid_desc_m_n,
                                                            const Block2CTileMap& block_2_ctile_map)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && (NPerBlock % (NRepeat * NPerWmma)) == 0,
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(c_grid_desc_m_n);
    }

    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
        remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
        // clang-format off
/*******************************************************************************/
// Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // Store BlockId into SGPR
        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,                            /* typename SrcElementwiseOperation, */
            ck::tensor_operation::element_wise::PassThrough,  /* typename DstElementwiseOperation, */
            InMemoryDataOperationEnum::Set,                   /* InMemoryDataOperationEnum DstInMemOp, */
            Sequence<K0PerBlock, MPerBlock, K1>,              /* typename BlockSliceLengths, */
            ABlockTransferThreadClusterLengths_K0_M_K1,       /* typename ThreadClusterLengths, */
            ABlockTransferThreadClusterArrangeOrder,          /* typename ThreadClusterArrangeOrder, */
            FloatA,                                           /* typename SrcData, */
            FloatA,                                           /* typename DstData, */
            decltype(a_grid_desc_k0_m_k1),                    /* typename SrcDesc, */
            decltype(a_block_desc_k0perblock_mperblock_k1),   /* typename DstDesc, */
            ABlockTransferSrcAccessOrder,                     /* typename SrcDimAccessOrder, */
            Sequence<0, 1, 2>,                                /* typename DstDimAccessOrder, */
            ABlockTransferSrcVectorDim,                       /* index_t SrcVectorDim, */
            2,                                                /* index_t DstVectorDim, */
            ABlockTransferSrcScalarPerVector,                 /* index_t SrcScalarPerVector, */
            ABlockTransferDstScalarPerVector_K1,              /* index_t DstScalarPerVector, */
            1,                                                /* index_t SrcScalarStrideInVector, */
            1,                                                /* index_t DstScalarStrideInVector, */
            AThreadTransferSrcResetCoordinateAfterRun,        /* bool ThreadTransferSrcResetCoordinateAfterRun, */
            true>                                             /* bool ThreadTransferDstResetCoordinateAfterRun, */
            (a_grid_desc_k0_m_k1,
             make_multi_index(0, m_block_data_idx_on_grid, 0),
             a_element_op,
             a_block_desc_k0perblock_mperblock_k1,
             make_multi_index(0, 0, 0),
             ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock, NPerBlock, K1>,
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatB,
            FloatB,
            decltype(b_grid_desc_k0_n_k1),
            decltype(b_block_desc_k0perblock_nperblock_k1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>
            (b_grid_desc_k0_n_k1,
             make_multi_index(0, n_block_data_idx_on_grid, 0),
             b_element_op,
             b_block_desc_k0perblock_nperblock_k1,
             make_multi_index(0, 0, 0),
             ck::tensor_operation::element_wise::PassThrough{});

/*******************************************************************************/
        // GEMM
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

        auto blockwise_gemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
            BlockSize, FloatA, FloatB, FloatAcc,
            decltype(a_block_desc_k0perblock_mperblock_k1),
            decltype(b_block_desc_k0perblock_nperblock_k1),
            MPerWmma, NPerWmma, MRepeat, NRepeat, KPack>{};

        // Prepare Register for C matrix
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

/*******************************************************************************/
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                          a_block_desc_k0perblock_mperblock_k1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
                                                          b_block_desc_k0perblock_nperblock_k1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);

/*******************************************************************************/
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // This API Provide All dimension (size) you need
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
            constexpr auto MSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
            constexpr auto NWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
            constexpr auto NThreadPerSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
            constexpr auto MAccVgprs =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);

            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                transform_tensor_descriptor(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    make_tuple(
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                            MWave,                               // MWave
                            MSubGroup,                           // MSubGroup * MAccVgprs = MPerWmma
                            MAccVgprs)),
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                            NWave,                               // NWave
                            NThreadPerSubGroup))),               // NThreadPerSubGroup = NPerWmma
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                    make_tuple(Sequence<0, 1, 2, 3>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatCShuffle,
                decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMRepeatPerShuffle, I1, I1, CShuffleNRepeatPerShuffle, I1, I1, MAccVgprs>,
                Sequence<0, 1, 2, 3, 4, 5, 6>,
                6,
                1, // vector write pixel
                InMemoryDataOperationEnum::Set,
                1,
                true>{
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                make_multi_index(0,
                                 m_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 0,
                                 n_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3]),
                ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,                                           // bool ThreadTransferSrcResetCoordinateAfterRun,
                false>                                          // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};

            // space filling curve for local reg & global memory
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr = SpaceFillingCurve<
                Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
                Sequence<0, 1, 2, 3, 4, 5, 6>,
                Sequence<CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, MAccVgprs>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global = SpaceFillingCurve<
                Sequence<1, MPerBlock, 1, NPerBlock>,
                Sequence<0, 2, 1, 3>,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                    c_thread_buf,
                    c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copy its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
        // clang-format on
    }
};

} // namespace ck
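To make the LDS budget concrete, here is a small stand-alone arithmetic check of what GetSharedMemoryNumberOfByte() above evaluates to for an fp16 A/B tile with no ABlockLdsExtraM/BBlockLdsExtraN padding. The tile sizes are hypothetical and not taken from this commit.

// Hypothetical numbers, for illustration only.
#include <cstdio>

int main()
{
    const int K0PerBlock = 4, K1 = 8, MPerBlock = 128, NPerBlock = 64;
    const int sizeof_fp16 = 2;

    const int a_elems = K0PerBlock * MPerBlock * K1; // 4096 elements for the A tile
    const int b_elems = K0PerBlock * NPerBlock * K1; // 2048 elements for the B tile

    // a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB)
    std::printf("LDS bytes: %d\n", (a_elems + b_elems) * sizeof_fp16); // 12288 bytes, well under 64 KiB
    return 0;
}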
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
View file @ 478df149
...
...
@@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-           auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+           auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);

            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
0 → 100644
View file @ 478df149
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {

enum struct WmmaInstr
{
    wmma_f32_16x16x16_f16 = 0,
    wmma_f32_16x16x16_bf16,
    wmma_f16_16x16x16_f16,
    wmma_bf16_16x16x16_bf16,
    wmma_i32_16x16x16_iu8,
    wmma_i32_16x16x16_iu4
};
/*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = Register for storing accumulated result
* T = Thread ID
*/
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};

// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    // * Data Pixel
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    // * num_acc_vgprs_per_wave alone M direction
    // * num_subgroups alone M direction
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, index_t Opsel, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, index_t Opsel, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};
#endif

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, bool neg_a, bool neg_b, bool clamp,
              class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
        }
    }
};

template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma>
struct WmmaSelector
{
    template <typename src_type_a_, typename src_type_b_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
    static constexpr auto GetWmma();

    template <>
    static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
    {
        return WmmaInstr::wmma_f32_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
    {
        return WmmaInstr::wmma_f32_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
    {
        return WmmaInstr::wmma_f16_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
    {
        return WmmaInstr::wmma_bf16_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
    {
        return WmmaInstr::wmma_i32_16x16x16_iu8;
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    static constexpr auto GetWmma<int4_t, int, 16, 16>()
    {
        return WmmaInstr::wmma_i32_16x16x16_iu4;
    }
#endif

    // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
    static constexpr auto selected_wmma =
        wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};

    __host__ __device__ constexpr WmmaSelector()
    {
        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size ==
                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                      "WRONG! Invalid Number of Accumulator Register");
    }
};

template <typename src_type_a,
          typename src_type_b,
          typename dst_type,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t KPack,
          bool TransposeC = false>
struct WmmaGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __host__ __device__ constexpr WmmaGemm()
    {
        static_assert(NPerWmma == 16 && MPerWmma == 16,
                      "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");

        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
    }

    // WMMA output supporting C = A * B
    // Vector Write
    // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
    {
        const auto MBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
        const auto NBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
        const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
        const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
            make_tuple(make_pass_through_transform(MBlockxRepeat),
                       make_pass_through_transform(MWave),
                       make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                         Number<wmma_instr.num_acc_vgprs_per_wave>{})),
                       make_pass_through_transform(NBlockxRepeat),
                       make_pass_through_transform(NWave),
                       make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 6>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
    }

    __device__ static constexpr index_t GetRegSizePerWmma() { return wmma_instr.num_acc_vgprs_per_wave; }

    __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(
            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
             is_same<dst_type, float>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, float>::value) ||
                (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
                 is_same<dst_type, half_t>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, bhalf_t>::value) ||
                (is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
                 is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
                    is_same<dst_type, int32_t>::value)
#endif
            ,
            "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
            "(int8, int32) or (int4, int32)!");

        if constexpr(!TransposeC)
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
        }
        else
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
        }
    }

    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }

    __device__ static auto GetSubGroupId()
    {
        return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
    }

    __device__ static auto GetLaneIdUnderSubGroup() { return GetLaneId() % wmma_instr.num_thread_per_subgroups; }

    __device__ static auto GetSwizzledLaneIdLow()
    {
        return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex() { return GetSwizzledLaneIdLow(); }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex() { return GetLaneIdUnderSubGroup(); }

    __device__ static CIndex GetBeginOfThreadBlk()
    {
        index_t n_offset = GetLaneIdUnderSubGroup();
        index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;

        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

    static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};

    static constexpr auto wmma_instr = wmma.selected_wmma;

    __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
    {
        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
    }
};

} // namespace ck
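The register counts encoded in wmma_type above line up with the wave-tile diagrams near the top of this file. Below is a tiny stand-alone check of the num_acc_vgprs_per_wave formula for the f32 accumulator case; it is an illustration only and not part of the commit.

// Illustration only: evaluating the wmma_type formula for the f32_16x16x16_f16 case.
constexpr int acc_vgprs_per_wave(int m, int n, int acc_bytes, int wave_size)
{
    return m * n * acc_bytes / wave_size / 4; // formula from wmma_type
}

int main()
{
    static_assert(acc_vgprs_per_wave(16, 16, 4, 32) == 8, "wave32: 8 acc VGPRs per lane (RC0..RC7)");
    static_assert(acc_vgprs_per_wave(16, 16, 4, 64) == 4, "wave64: 4 acc VGPRs per lane (RC0..RC3)");
    // 32 / 16 = 2 subgroups per wave32 and 64 / 16 = 4 subgroups per wave64, matching the tile diagrams.
    return 0;
}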
include/ck/utility/amd_inline_asm.hpp
View file @ 478df149
...
...
@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                   c3);
}

// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}

} // namespace ck
#endif
include/ck/utility/amd_wmma.hpp
View file @ 478df149
...
...
@@ -4,11 +4,13 @@
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "ck/utility/amd_inline_asm.hpp"
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {

// wave32 only
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
...
...
@@ -19,8 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
-       reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-           reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+       // * Inline assembly need to elimate the duplicated data load, compiler won't help you
+       // delete them.
+       amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+       // reg_c.template AsType<float8_t>()(Number<0>{}) =
+       //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
+       //     AsType<float8_t>()[Number<0>{}]);
    }
};
...
...
@@ -98,5 +105,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
    }
};

/********************************WAVE64 MODE***********************************************/

template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;

template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
            reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
            neg_a, bit_cast<int32x4_t>(reg_a), neg_b, bit_cast<int32x4_t>(reg_b),
            reg_c.template AsType<int32x4_t>()[Number<0>{}], clamp);
    }
};

} // namespace ck
#endif
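For readers unfamiliar with the raw compiler intrinsics wrapped here, the following is a minimal, hypothetical HIP sketch of the accumulate-in-place pattern the wave32 f16-to-f32 wrapper relies on. A gfx11 target is assumed, and the vector typedefs and per-thread indexing are simplifications, not library code.

// Hypothetical sketch, not part of the commit.
#include <hip/hip_runtime.h>

typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float    float8 __attribute__((ext_vector_type(8)));

__global__ void tiny_wmma(const half16* a, const half16* b, float8* c)
{
    float8 acc = c[threadIdx.x];
    // D = A * B + C, one 16x16x16 tile per wave32
    acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a[threadIdx.x], b[threadIdx.x], acc);
    c[threadIdx.x] = acc;
}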
include/ck/utility/math_v2.hpp
View file @ 478df149
...
...
@@ -3,7 +3,9 @@
#pragma once
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
...
...
include/ck/utility/reduction_operator.hpp
View file @ 478df149
...
...
@@ -251,27 +251,27 @@ constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum o
};

template <InMemoryDataOperationEnum Operation, typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType
+struct InMemoryDataOperationSupportedOnDataType
{
    static constexpr bool value = false;
};

template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
    static constexpr bool value = is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
{
    static constexpr bool value = is_same<DataType, float>::value || is_same<DataType, double>::value;
};

template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
{
    static constexpr bool value = is_same<DataType, float>::value || is_same<DataType, double>::value ||
...
@@ -280,7 +280,7 @@ struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, D
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Add
,
DataType
>
struct
InMemoryDataOperat
i
onSupportedOnDataType
<
InMemoryDataOperationEnum
::
Add
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
...
...
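This hunk fixes the spelling of the trait name (InMemoryDataOperatonSupportedOnDataType becomes InMemoryDataOperationSupportedOnDataType). A typical compile-time use of such a trait is sketched below with a hypothetical helper, assuming the trait lives in namespace ck like the rest of this header; the helper itself is not part of the library.

// Hypothetical usage sketch.
template <typename OutDataType, ck::InMemoryDataOperationEnum MemOp>
void check_output_op()
{
    static_assert(ck::InMemoryDataOperationSupportedOnDataType<MemOp, OutDataType>::value,
                  "this output memory operation is not supported on this data type");
}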
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp
View file @ 478df149
...
...
@@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator
        for(int m = 0; m < M; ++m)
        {
+           AccDataType divisor = static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
+
            for(int n = 0; n < N; ++n)
            {
                auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
-               auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
+               auto y_val = (x_val - mean(m)) * divisor;
                y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);

                arg.acc_elementwise_op_(y_val, y_val);

                arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
...
...
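The reference implementation now hoists the per-row divisor out of the inner loop, so each row costs one sqrt and one reciprocal instead of a divide-by-sqrt per element, and it uses ck::math::sqrt for type consistency with the device kernels. A plain-C++ illustration of the same hoist follows; the free function is hypothetical and not the library code.

// Hypothetical illustration of the hoist.
#include <cmath>

void layernorm_reference(const float* x, const float* mean, const float* var,
                         const float* gamma, const float* beta, float* y,
                         int M, int N, float eps)
{
    for(int m = 0; m < M; ++m)
    {
        const float divisor = 1.0f / std::sqrt(var[m] + eps); // computed once per row
        for(int n = 0; n < N; ++n)
            y[m * N + n] = (x[m * N + n] - mean[m]) * divisor * gamma[n] + beta[n];
    }
}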
library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp
0 → 100644
View file @ 478df149
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/ck.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace host {

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          index_t Rank,
          index_t NumReduceDim,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          bool PropagateNan,
          bool OutputIndex>
struct ReferenceReduce : public device::DeviceReduce<InDataType,
                                                     AccDataType,
                                                     OutDataType,
                                                     Rank,
                                                     NumReduceDim,
                                                     ReduceOperation,
                                                     InElementwiseOperation,
                                                     AccElementwiseOperation,
                                                     PropagateNan,
                                                     OutputIndex>
{
    using IndexDataType = int32_t;

    static constexpr int NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t NumSrcDim   = Rank;
    static constexpr index_t NumDstDim   = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim   = (NumInvariantDim == 0);

    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<index_t, Rank> inLengths,
                 const std::array<index_t, Rank> inStrides,
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
                 float alpha,
                 float beta,
                 const InDataType* in_host,
                 OutDataType* out_host,
                 IndexDataType* out_index_host,
                 const InElementwiseOperation in_elementwise_op,
                 const AccElementwiseOperation acc_elementwise_op)
            : reduceDims_(reduceDims),
              outLengths_(outLengths),
              outStrides_(outStrides),
              in_host_(in_host),
              out_host_(out_host),
              out_index_host_(out_index_host),
              in_elementwise_op_(in_elementwise_op),
              acc_elementwise_op_(acc_elementwise_op)
        {
            using ck::host_common::get_index_set;

            if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) {
                   return d < 0 || d >= Rank;
               }))
                throw std::runtime_error("Invalid reduce dimensions!");

            if constexpr(NumInvariantDim > 0)
            {
                // get invariant_dims[] and invariant_lengths[]
                for(int dim = 0, i = 0; dim < Rank; dim++)
                    if(std::none_of(
                           reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
                    {
                        invariantDims_[i]     = dim;
                        invariant_lengths_[i] = inLengths[dim];
                        i++;
                    };
            };

            // get reduce_lengths_[]
            for(int j = 0, i = 0; j < NumReduceDim; j++)
            {
                int dim              = reduceDims[j];
                reduce_lengths_[i++] = inLengths[dim];
            };

            if constexpr(NumInvariantDim > 0)
            {
                // check invariant_lengths_ and outLengths
                for(int i = 0; i < NumInvariantDim; i++)
                    if(invariant_lengths_[i] != outLengths_[i])
                        throw std::runtime_error("Invalid lengths parameters!");
            }

            if constexpr(NumInvariantDim > 0)
            {
                for(int j = 0, i = 0; j < NumInvariantDim; j++)
                {
                    int dim                  = invariantDims_[j];
                    in_invariant_strides_[i] = inStrides[dim];
                    i++;
                };
            };

            for(int j = 0, i = 0; j < NumReduceDim; j++)
            {
                int dim               = reduceDims_[j];
                in_reduce_strides_[i] = inStrides[dim];
                i++;
            };

            if constexpr(NumInvariantDim > 0)
                invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);

            reduce_index_set_ = get_index_set<NumReduceDim>(reduce_lengths_);

            alpha_ = type_convert<AccDataType>(alpha);
            beta_  = type_convert<AccDataType>(beta);
        };

        const std::array<int, NumReduceDim> reduceDims_;
        std::array<int, NumInvariantDim> invariantDims_;

        std::array<index_t, NumInvariantDim> invariant_lengths_;
        std::array<index_t, NumReduceDim> reduce_lengths_;

        const std::array<index_t, NumDstDim> outLengths_;
        const std::array<index_t, NumDstDim> outStrides_;

        std::array<index_t, NumInvariantDim> in_invariant_strides_;
        std::array<index_t, NumReduceDim> in_reduce_strides_;

        const InDataType* in_host_;
        OutDataType* out_host_;
        IndexDataType* out_index_host_;

        const InElementwiseOperation in_elementwise_op_;
        const AccElementwiseOperation acc_elementwise_op_;

        AccDataType alpha_;
        AccDataType beta_;

        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
        std::vector<std::array<index_t, NumReduceDim>> reduce_index_set_;
    };

    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            ignore = stream_config;

            using ck::float_equal_one;
            using ck::float_equal_zero;
            using ck::type_convert;
            using ck::host_common::get_index_set;
            using ck::host_common::get_offset_from_index;

            if constexpr(OutputIndex)
            {
                using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                                ReduceOperation,
                                                                                AccDataType,
                                                                                IndexDataType>;

                if constexpr(NumInvariantDim == 0)
                {
                    AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
                    IndexDataType accuIndex = 0;

                    for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
                    {
                        auto in_offset = get_offset_from_index<NumReduceDim>(
                            arg.in_reduce_strides_, arg.reduce_index_set_[i]);

                        auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);

                        arg.in_elementwise_op_(currVal, currVal);

                        auto currIndex = static_cast<IndexDataType>(i);

                        Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                    };

                    arg.acc_elementwise_op_(accuVal, accuVal);

                    if(!float_equal_one{}(arg.alpha_))
                        accuVal *= type_convert<AccDataType>(arg.alpha_);

                    if(!float_equal_zero{}(arg.beta_))
                        accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
                                   type_convert<AccDataType>(arg.beta_);

                    arg.out_host_[0]       = type_convert<OutDataType>(accuVal);
                    arg.out_index_host_[0] = accuIndex;
                }
                else
                {
                    auto thread_reduce_func = [&](auto invariant_index) {
                        AccDataType accuVal =
                            ReduceOperation::template GetIdentityValue<AccDataType>();
                        IndexDataType accuIndex = 0;

                        auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
                            arg.in_invariant_strides_, invariant_index);

                        for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
                        {
                            auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
                                arg.in_reduce_strides_, arg.reduce_index_set_[i]);

                            auto currVal = type_convert<AccDataType>(
                                arg.in_host_[in_invariant_offset + in_reduce_offset]);

                            arg.in_elementwise_op_(currVal, currVal);

                            auto currIndex = static_cast<IndexDataType>(i);

                            Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                        };

                        arg.acc_elementwise_op_(accuVal, accuVal);

                        if(!float_equal_one{}(arg.alpha_))
                            accuVal *= type_convert<AccDataType>(arg.alpha_);

                        auto dst_offset = get_offset_from_index<NumInvariantDim>(
                            arg.outStrides_, invariant_index);

                        if(!float_equal_zero{}(arg.beta_))
                            accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
                                       type_convert<AccDataType>(arg.beta_);

                        arg.out_host_[dst_offset]       = type_convert<OutDataType>(accuVal);
                        arg.out_index_host_[dst_offset] = accuIndex;
                    };

                    std::size_t num_thread = std::thread::hardware_concurrency();
                    std::size_t work_per_thread =
                        (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

                    std::vector<joinable_thread> threads(num_thread);

                    for(std::size_t it = 0; it < num_thread; ++it)
                    {
                        std::size_t i_begin = it * work_per_thread;
                        std::size_t i_end =
                            std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());

                        auto f = [=] {
                            for(std::size_t i = i_begin; i < i_end; i++)
                            {
                                thread_reduce_func(arg.invariant_index_set_[i]);
                            }
                        };

                        threads[it] = joinable_thread(f);
                    }
                };
            }
            else
            {
                using Accumulation =
                    ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

                if constexpr(NumInvariantDim == 0)
                {
                    AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();

                    for(const auto& reduce_index : arg.reduce_index_set_)
                    {
                        auto in_offset = get_offset_from_index<NumReduceDim>(
                            arg.in_reduce_strides_, reduce_index);

                        auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);

                        arg.in_elementwise_op_(currVal, currVal);

                        Accumulation::Calculate(accuVal, currVal);
                    };

                    arg.acc_elementwise_op_(accuVal, accuVal);

                    if(!float_equal_one{}(arg.alpha_))
                        accuVal *= type_convert<AccDataType>(arg.alpha_);

                    if(!float_equal_zero{}(arg.beta_))
                        accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
                                   type_convert<AccDataType>(arg.beta_);

                    arg.out_host_[0] = type_convert<OutDataType>(accuVal);
                }
                else
                {
                    auto thread_reduce_func = [&](auto invariant_index) {
                        AccDataType accuVal =
                            ReduceOperation::template GetIdentityValue<AccDataType>();

                        auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
                            arg.in_invariant_strides_, invariant_index);

                        for(const auto& reduce_index : arg.reduce_index_set_)
                        {
                            auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
                                arg.in_reduce_strides_, reduce_index);

                            auto currVal = type_convert<AccDataType>(
                                arg.in_host_[in_invariant_offset + in_reduce_offset]);

                            arg.in_elementwise_op_(currVal, currVal);

                            Accumulation::Calculate(accuVal, currVal);
                        };

                        arg.acc_elementwise_op_(accuVal, accuVal);

                        if(!float_equal_one{}(arg.alpha_))
                            accuVal *= type_convert<AccDataType>(arg.alpha_);

                        auto dst_offset = get_offset_from_index<NumInvariantDim>(
                            arg.outStrides_, invariant_index);

                        if(!float_equal_zero{}(arg.beta_))
                            accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
                                       type_convert<AccDataType>(arg.beta_);

                        arg.out_host_[dst_offset] = type_convert<OutDataType>(accuVal);
                    };

                    std::size_t num_thread = std::thread::hardware_concurrency();
                    std::size_t work_per_thread =
                        (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

                    std::vector<joinable_thread> threads(num_thread);

                    for(std::size_t it = 0; it < num_thread; ++it)
                    {
                        std::size_t i_begin = it * work_per_thread;
                        std::size_t i_end =
                            std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());

                        auto f = [=] {
                            for(std::size_t i = i_begin; i < i_end; i++)
                            {
                                thread_reduce_func(arg.invariant_index_set_[i]);
                            }
                        };

                        threads[it] = joinable_thread(f);
                    }
                };
            };

            return (0.0f);
        };

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        ignore = p_arg;

        return true;
    };

    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
                        const std::array<index_t, Rank> inStrides,
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
                        float alpha,
                        float beta,
                        const void* in_host,
                        const void* in_index_host,
                        void* out_host,
                        void* out_index_host,
                        const InElementwiseOperation in_elementwise_op,
                        const AccElementwiseOperation acc_elementwise_op) override
    {
        ignore = in_index_host;

        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStrides,
                                          reduceDims,
                                          alpha,
                                          beta,
                                          static_cast<const InDataType*>(in_host),
                                          static_cast<OutDataType*>(out_host),
                                          static_cast<IndexDataType*>(out_index_host),
                                          in_elementwise_op,
                                          acc_elementwise_op);
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_Reduce<" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
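The new `ReferenceReduce` implements the same `DeviceReduce` interface as the GPU kernels, so host-side validation can construct the argument and invoker exactly as it would for a device instance. A hedged usage sketch follows: the `PassThrough`/`ReduceAdd` operation types and the extra includes are assumptions based on how other reference operators in the library are typically driven, not something shown in this commit.

// Hedged sketch: summing a 2x8 row-major tensor over its last dimension on the host.
// Operation type names and the extra includes are assumptions for illustration.
#include <vector>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceAdd   = ck::reduce::Add;

int main()
{
    using RefReduce = ck::tensor_operation::host::ReferenceReduce<float,  // InDataType
                                                                  float,  // AccDataType
                                                                  float,  // OutDataType
                                                                  2,      // Rank
                                                                  1,      // NumReduceDim
                                                                  ReduceAdd,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  false,  // PropagateNan
                                                                  false>; // OutputIndex

    std::vector<float> in(2 * 8, 1.0f);
    std::vector<float> out(2, 0.0f);

    RefReduce ref;
    auto arg = ref.MakeArgumentPointer({2, 8}, {8, 1}, // input lengths / strides
                                       {2}, {1},       // output lengths / strides
                                       {1},            // reduce the last dimension
                                       1.0f, 0.0f,     // alpha, beta
                                       in.data(), nullptr, out.data(), nullptr,
                                       PassThrough{}, PassThrough{});
    ref.MakeInvokerPointer()->Run(arg.get()); // expect out[0] == out[1] == 8.0f
    return 0;
}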
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
View file @
478df149
...
...
@@ -76,8 +76,16 @@ template <typename InDataType,
          bool PropagateNan,
          bool OutputIndex>
-void add_device_reduce_instance_blockwise(std::vector<DeviceReducePtr<Rank, NumReduceDim, InElementwiseOp, AccElementwiseOp>>& device_op_instances)
+void add_device_reduce_instance_blockwise(
+    std::vector<DeviceReducePtr<InDataType,
+                                AccDataType,
+                                OutDataType,
+                                Rank,
+                                NumReduceDim,
+                                ReduceOperation,
+                                InElementwiseOp,
+                                AccElementwiseOp,
+                                PropagateNan,
+                                OutputIndex>>& device_op_instances)
{
    static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
        [&](auto i) {
...
...
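With the signature change above, `DeviceReducePtr` now encodes the full problem description (data types, reduce operation, NaN propagation, index output) rather than only the rank, reduced-dimension count, and element-wise operations, and the extern-template declarations in the instance headers below are updated to match. A hedged sketch of what a caller now writes follows; the includes and the `BF16`/`F32` aliases mirror the shorthand used in these headers but are assumptions here.

// Hedged sketch: declaring and populating an instance list with the widened DeviceReducePtr.
#include <vector>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"

using BF16        = ck::bhalf_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceAdd   = ck::reduce::Add;

int main()
{
    using namespace ck::tensor_operation::device;

    // The pointer type itself now carries every template parameter of the reduction.
    std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 3, ReduceAdd,
                                PassThrough, PassThrough, false, false>>
        instances;

    instance::add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd,
                                                   PassThrough, PassThrough, false, false>(
        instances);

    return instances.empty() ? 1 : 0;
}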
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp
View file @
478df149
...
...
@@ -15,10 +15,10 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, PassThrough>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, PassThrough, false, false>>&);
// clang-format on
} // namespace instance
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp
View file @
478df149
...
...
@@ -15,14 +15,14 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 3, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 4, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<4, 1, UnaryAbs, PassThrough>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<2, 1, UnaryAbs, PassThrough>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 3, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 4, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>(std::vector<DeviceReducePtr<BF16, F32, BF16, 2, 1, ReduceAMax, UnaryAbs, PassThrough, false, true>>&);
// clang-format on
} // namespace instance
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp
View file @
478df149
...
...
@@ -15,10 +15,10 @@ namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 3, PassThrough, UnaryDivide>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 4, PassThrough, UnaryDivide>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<4, 1, PassThrough, UnaryDivide>>&);
-extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<2, 1, PassThrough, UnaryDivide>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 3, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 4, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 4, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
+extern template void add_device_reduce_instance_blockwise<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>(std::vector<DeviceReducePtr<BF16, F32, BF16, 2, 1, ReduceAdd, PassThrough, UnaryDivide, false, false>>&);
// clang-format on
} // namespace instance
...
...