Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0629870d
Commit
0629870d
authored
Aug 23, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
48fe8532
d4ad52d6
Changes
151
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
247 additions
and
182 deletions
+247
-182
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+9
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+12
-16
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
...gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
.../grid/normalization/gridwise_normalization_splitk_1st.hpp
+3
-3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+3
-0
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+114
-86
include/ck/version.h.in
include/ck/version.h.in
+40
-0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+1
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
...ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
+8
-8
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
...operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp
...nsor_operation_instance/gpu/batched_gemm_bias_permute.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
...n_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
...brary/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
...ry/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
...nsor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
...ry/tensor_operation_instance/gpu/contraction_bilinear.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
...brary/tensor_operation_instance/gpu/contraction_scale.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
...nsor_operation_instance/gpu/convolution_backward_data.hpp
+24
-24
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
...ary/tensor_operation_instance/gpu/convolution_forward.hpp
+7
-7
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
0629870d
...
...
@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
...
...
@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
C0GridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
...
...
@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
FloatC
,
// typename Src0Data,
FloatC
,
// typename Src1Data,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
0629870d
...
...
@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
...
...
@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
CGridDesc_M_N
{}))
>
;
using
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
C0GridDesc_M_N
{}))
>
;
using
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
C1GridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
...
...
@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
FloatC
,
// typename Src1Data,
FloatC
,
// typename Src2Data,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
// index_t ScalarPerVector,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
View file @
0629870d
...
...
@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
using
ThreadwiseWolfordDesc2D
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{},
Number
<
RowSubBlocks
*
RowVectorSize
>
{})));
using
ThreadwiseWolfordDescReduce
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
using
ThreadwiseWolfordDescReduce
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadwiseWolfordDesc2D
,
ThreadwiseWolfordDescReduce
>
;
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
View file @
0629870d
...
...
@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
int
left_kPerBlock
=
math
::
integer_divide_ceil
(
k
,
kGridSize
);
int
kRightmostBlock
=
kRaw
-
left_kPerBlock
*
(
kGridSize
-
1
);
int
kPerThread
=
kRightmostBlock
<
K_BlockTileSize
?
0
:
KThreadSliceSize
*
(
kRightmostBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kRightmostBlock
-
kPerThread
*
KThreadClusterSize
;
?
0
:
KThreadSliceSize
*
(
kRightmostBlock
/
K_BlockTileSize
);
int
kPerBlockTail
=
kRightmostBlock
-
kPerThread
*
KThreadClusterSize
;
if
(
kPerBlockTail
>
0
)
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
0629870d
...
...
@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
static_assert
(
SliceLengths
::
At
(
SrcVectorDim
)
%
SrcScalarPerVector
==
0
,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"
);
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
0629870d
...
...
@@ -236,8 +236,6 @@ struct TransformConvBwdDataToGemm_v1
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
AK0
=
K
/
AK1
;
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const
auto
out_grid_desc
=
make_out_grid_desc
<
NDimSpatial
,
ALayout
,
ConvBwdDataSpecialization
>
(
...
...
@@ -247,6 +245,8 @@ struct TransformConvBwdDataToGemm_v1
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
AK0
=
math
::
integer_divide_ceil
(
K
,
AK1
);
// A: output tensor
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
...
...
@@ -308,6 +308,9 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
AK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
AK1
);
if
constexpr
(
NDimSpatial
==
2
)
{
// A: output tensor
...
...
@@ -332,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
=
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
=
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
...
...
@@ -340,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_
unmerge_transform
(
make_tuple
(
AK0
,
AK1
)
)),
make_
pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -352,21 +355,28 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
>
{}));
Sequence
<
5
>
{}));
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
AK0
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
AK1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemm
ak0
_gemmm_
gemmak1
_grid_desc
=
const
auto
out_gemm
k
_gemmm_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
...
...
@@ -411,7 +421,7 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
7
>
{}));
const
auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
=
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
=
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
...
...
@@ -421,7 +431,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_
unmerge_transform
(
make_tuple
(
AK0
,
AK1
)
)),
make_
pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -437,22 +447,29 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
,
8
>
{}));
Sequence
<
7
>
{}));
const
auto
out_gemm
ak0
_gemmmraw_
gemmak1_
grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
ak0_ak1
_grid_desc
,
const
auto
out_gemm
k
_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
AK0
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
AK1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
8
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemm
ak0
_gemmm_
gemmak1
_grid_desc
=
const
auto
out_gemm
k
_gemmm_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
...
...
@@ -505,8 +522,6 @@ struct TransformConvBwdDataToGemm_v1
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
BK0
=
K
/
BK1
;
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
const
auto
wei_grid_desc
=
make_wei_grid_desc
<
BLayout
>
(
K
,
Z
,
Y
,
X
,
C
);
...
...
@@ -515,6 +530,8 @@ struct TransformConvBwdDataToGemm_v1
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
BK0
=
math
::
integer_divide_ceil
(
K
,
BK1
);
// B: weight tensor
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
...
...
@@ -551,6 +568,9 @@ struct TransformConvBwdDataToGemm_v1
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
index_t
BK0
=
math
::
integer_divide_ceil
(
ZDotSlice
*
YDotSlice
*
XDotSlice
*
K
,
BK1
);
// B weight tensor
if
constexpr
(
NDimSpatial
==
2
)
{
...
...
@@ -566,43 +586,47 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
4
>
{}));
const
auto
wei_k_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
3
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
BK0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
BK1
)),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
,
2
,
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemm
bk0
_gemmn_
gemmbk1
_grid_desc
=
const
auto
wei_gemm
k
_gemmn_
padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
}
...
...
@@ -631,10 +655,10 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
wei_
bk0_bk1
_zdotslice_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_
gemmk
_zdotslice_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_
unmerge_transform
(
make_tuple
(
BK0
,
BK1
)
),
make_tuple
(
make_
pass_through_transform
(
K
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
...
...
@@ -650,33 +674,37 @@ struct TransformConvBwdDataToGemm_v1
Sequence
<
4
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
5
>
{}));
Sequence
<
4
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_bk0_bk1_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
BK0
)),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
BK1
)),
make_tuple
(
Sequence
<
2
,
3
,
4
,
0
>
{},
Sequence
<
5
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemm
bk0
_gemm
n_gemmbk1
_grid_desc
=
const
auto
wei_gemm
k
_gemm
_padded
_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
const
auto
wei_gemmbk0_gemm_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
wei_gemmbk0_gemm_gemmbk1_grid_desc
;
}
else
{
...
...
include/ck/version.h.in
0 → 100644
View file @
0629870d
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
/* the configured version and settings for miopen- Composable Kernel */
#ifndef CK_VERSION_H_
#define CK_VERSION_H_
// clang-format off
#define CK_VERSION @CMAKE_PROJECT_VERSION@
#define CK_VERSION_MAJOR @CMAKE_PROJECT_VERSION_MAJOR@
#define CK_VERSION_MINOR @CMAKE_PROJECT_VERSION_MINOR@
#define CK_VERSION_PATCH @CMAKE_PROJECT_VERSION_PATCH@
#define CK_COMMIT_ID @COMMIT_ID@
// clang-format on
#endif
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
0629870d
...
...
@@ -17,6 +17,7 @@ namespace instance {
using
F64
=
double
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F8
=
ck
::
f8_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -37,7 +37,7 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -58,7 +58,7 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -79,7 +79,7 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
...
...
@@ -154,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
{
...
...
@@ -180,7 +180,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
}
}
#endif
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
...
...
@@ -206,7 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
}
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
...
...
@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
}
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp
View file @
0629870d
...
...
@@ -14,7 +14,7 @@
using
CDE0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
CDE1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Add
;
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp
View file @
0629870d
...
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
...
@@ -59,7 +59,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
void
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
...
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
&&
Acc0BiasDataType
::
Size
()
==
1
&&
...
...
@@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
&&
Acc0BiasDataType
::
Size
()
==
1
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
View file @
0629870d
...
...
@@ -19,7 +19,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
Row
,
...
...
@@ -124,7 +124,7 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmMultiD
<
Col
,
Row
,
...
...
@@ -263,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
...
...
@@ -297,7 +297,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
}
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
EDataType
,
int8_t
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
View file @
0629870d
...
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
...
@@ -59,7 +59,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
...
@@ -148,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
...
...
@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
// float
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -66,7 +66,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough
,
Bilinear
>>>&
instances
);
#endif
#ifdef
__fp64__
#ifdef
CK_ENABLE_FP64
// double
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
DDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
{
...
...
@@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
}
#endif
#ifdef
__fp64__
#ifdef
CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
DDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
// float
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -66,7 +66,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough
,
Scale
>>>&
instances
);
#endif
#ifdef
__fp64__
#ifdef
CK_ENABLE_FP64
// double
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -149,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
{
...
...
@@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
}
#endif
#ifdef
__fp64__
#ifdef
CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
View file @
0629870d
...
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
// conv1d backward data
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
...
...
@@ -30,19 +30,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
...
...
@@ -55,7 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
// conv2d backward data
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
...
@@ -69,7 +69,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
...
...
@@ -82,7 +82,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
...
...
@@ -95,7 +95,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
...
...
@@ -109,7 +109,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
>>>&
instances
);
#endif
#ifdef DL_KERNELS
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
// conv2d dl
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
...
@@ -123,7 +123,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
...
...
@@ -136,7 +136,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
...
...
@@ -150,7 +150,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
// conv3d backward data
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
...
...
@@ -164,7 +164,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
...
...
@@ -177,7 +177,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
...
...
@@ -190,7 +190,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
...
...
@@ -245,21 +245,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
op_ptrs
);
}
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
op_ptrs
);
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
...
...
@@ -278,7 +278,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
#endif
}
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
...
...
@@ -288,14 +288,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
#endif
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
...
...
@@ -314,21 +314,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
op_ptrs
);
}
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
op_ptrs
);
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
View file @
0629870d
...
...
@@ -18,7 +18,7 @@ namespace device {
namespace
instance
{
// conv2d forward
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -28,7 +28,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
...
...
@@ -41,13 +41,13 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__fp32__
#ifdef
CK_ENABLE_FP32
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
...
...
@@ -103,7 +103,7 @@ struct DeviceOperationInstanceFactory<
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
}
#ifdef
__fp16__
#ifdef
CK_ENABLE_FP16
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
{
...
...
@@ -111,7 +111,7 @@ struct DeviceOperationInstanceFactory<
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
}
#endif
#ifdef
__bf16__
#ifdef
CK_ENABLE_BF16
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
...
...
@@ -119,7 +119,7 @@ struct DeviceOperationInstanceFactory<
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
#endif
#ifdef
__int8__
#ifdef
CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
{
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment