Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4885c38a
Commit
4885c38a
authored
Sep 03, 2024
by
aska-0096
Browse files
Merge branch 'transpose_opt' of
https://github.com/ROCm/composable_kernel
into rowwise_opt
parents
cbf14ee1
7c8e92fa
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
990 additions
and
139 deletions
+990
-139
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+13
-3
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+37
-21
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
...fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
+37
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
...on_instance/gpu/grouped_convolution_forward_convscale.hpp
+83
-3
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
...stance/gpu/grouped_convolution_forward_convscale_relu.hpp
+83
-1
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
...k/library/tensor_operation_instance/gpu/permute_scale.hpp
+13
-0
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
...ance/gpu/permute_scale/device_permute_scale_instances.hpp
+52
-6
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
...uce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
+18
-9
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+3
-2
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
..._instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
+2
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
...combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+61
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
...ance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
+2
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
...onvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
+61
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
...d_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+1
-4
library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt
...ensor_operation_instance/gpu/permute_scale/CMakeLists.txt
+3
-2
library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_fp8_instances.cpp
...mute_scale/device_permute_scale_6d_fp32_fp8_instances.cpp
+28
-0
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp
...uce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp
+18
-9
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
...include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+88
-75
profiler/src/profile_grouped_conv_bwd_weight.cpp
profiler/src/profile_grouped_conv_bwd_weight.cpp
+1
-2
script/convert_miopen_driver_to_profiler.py
script/convert_miopen_driver_to_profiler.py
+386
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
4885c38a
...
...
@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
if
constexpr
(
AsyncCopyK
)
{
return
GetSmemSizeKV
<
Problem
>
()
+
GetSmemSizeDropout
<
Problem
>
();
return
GetSmemSizeKV
<
Problem
>
()
+
GetSmemSizeDropout
<
Problem
>
(
0
);
}
else
{
return
ck_tile
::
max
(
GetSmemSizeKV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
());
return
ck_tile
::
max
(
GetSmemSizeKV
<
Problem
>
(),
GetSmemSizeDropout
<
Problem
>
(
0
));
}
}
// this method is only available when Problem::kHasDropout is present
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
()
CK_TILE_HOST_DEVICE
static
constexpr
std
::
enable_if_t
<
std
::
is_convertible_v
<
decltype
(
Problem
::
kHasDropout
),
bool
>
,
ck_tile
::
index_t
>
GetSmemSizeDropout
(
int
)
{
if
constexpr
(
Problem
::
kHasDropout
)
{
...
...
@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
}
// fallback version if Problem::kHasDropout is not exist
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeDropout
(...)
{
return
0
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKDramTileDistribution
()
{
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
4885c38a
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace
ck_tile
{
...
...
@@ -32,30 +33,31 @@ struct TileFmhaTraits
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ
/* padding for seqlen_q */
,
bool
kPadSeqLenK
/* padding for seqlen_k */
,
bool
kPadHeadDimQ
/* paddding for hdim_q */
,
bool
kPadHeadDimV
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum
,
bool
kHasBiasGrad
,
bool
kStoreLSE
,
bool
kHasDropout
,
bool
kDoFp8StaticQuant
,
bool
kHasUnevenSplits_
=
true
,
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVTraits
:
TileFmhaTraits
<
kPadSeqLenQ
,
kPadSeqLenK
,
kPadHeadDimQ
,
kPadHeadDimV
,
BiasEnum
,
kHasBiasGrad
,
kStoreLSE
,
kHasDropout
,
kDoFp8StaticQuant
,
kBlockPerCu
>
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadSeqLenK_
/* padding for seqlen_k */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kDoFp8StaticQuant_
,
bool
kIsPagedKV_
,
bool
kHasUnevenSplits_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadSeqLenK
=
kPadSeqLenK_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kIsPagedKV
=
kIsPagedKV_
;
// determine if some split (length) is not divisible by tile size
static
constexpr
bool
kHasUnevenSplits
=
kHasUnevenSplits_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
...
...
@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadSeqLenK_
/* padding for seqlen_k */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdAppendKVTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadSeqLenK
=
kPadSeqLenK_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
index_t
kBlockPerCu_
=
2
/* hint to occupancy */
>
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
View file @
4885c38a
...
...
@@ -184,6 +184,43 @@ using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
// clang-format on
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
,
typename
OutElementOp
>
using
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef CK_ENABLE_FP8
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F8
,
F8
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F8
,
F8
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
OutElementOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
,
F8
>
#endif
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
View file @
4885c38a
...
...
@@ -8,9 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
...
...
@@ -177,6 +175,88 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
};
using
CombConvScale
=
ck
::
tensor_operation
::
element_wise
::
ScaleScalePass
;
#ifdef CK_ENABLE_FP8
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScale
,
F8
,
F8
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DLayouts
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DDataTypes
,
typename
OutDataType
,
typename
AComputeType
,
typename
BComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScale
,
AComputeType
,
BComputeType
>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScale
,
AComputeType
,
BComputeType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataType
,
f8_t
>
&&
is_same_v
<
WeiDataType
,
f8_t
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
AComputeType
,
f8_t
>
&&
is_same_v
<
BComputeType
,
f8_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
View file @
4885c38a
...
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/
unary
_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/
combined
_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
...
...
@@ -99,6 +99,88 @@ struct DeviceOperationInstanceFactory<
}
};
using
CombConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ScaleScaleRelu
;
#ifdef CK_ENABLE_FP8
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
F8
,
F8
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
DLayouts
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
DDataTypes
,
typename
OutDataType
,
typename
AComputeType
,
typename
BComputeType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
AComputeType
,
BComputeType
>>
{
using
DeviceOp
=
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
DLayouts
,
OutLayout
,
InDataType
,
WeiDataType
,
DDataTypes
,
OutDataType
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
AComputeType
,
BComputeType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataType
,
f8_t
>
&&
is_same_v
<
WeiDataType
,
f8_t
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
AComputeType
,
f8_t
>
&&
is_same_v
<
BComputeType
,
f8_t
>
)
{
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp
View file @
4885c38a
...
...
@@ -70,6 +70,12 @@ void add_device_permute_scale_6d_f32_instances(
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
element_wise
::
Scale
,
6
>>>&
);
#endif
#ifdef CK_ENABLE_FP8
void
add_device_permute_scale_6d_f32_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
element_wise
::
Scale
,
6
>>>&
);
#endif
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
...
...
@@ -184,6 +190,13 @@ struct DeviceOperationInstanceFactory<
{
add_device_permute_scale_6d_f16_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_FP8
if
constexpr
(
is_same_v
<
InDataTypeTuple
,
ck
::
Tuple
<
F32
>>
&&
is_same_v
<
OutDataTypeTuple
,
ck
::
Tuple
<
F8
>>
)
{
add_device_permute_scale_6d_f32_f8_instances
(
op_ptrs
);
}
#endif
}
return
op_ptrs
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp
View file @
4885c38a
...
...
@@ -10,6 +10,7 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
@@ -46,7 +47,7 @@ using device_permute_scale_f16_instances =
#if 0
// Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters
// They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...
...
@@ -57,7 +58,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...
...
@@ -97,7 +98,7 @@ using device_permute_scale_f16_instances =
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<
F16
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
template
<
index_t
NDims
,
...
...
@@ -131,7 +132,7 @@ using device_permute_scale_f32_instances = std::tuple<
#if 0
// Disabled instances to improve compilation time
// They listed here to show other possible combinations of parameters
// They listed here to show other possible combinations of parameters
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
...
...
@@ -142,7 +143,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
...
...
@@ -168,7 +169,7 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
#endif
#endif
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
...
...
@@ -183,6 +184,51 @@ using device_permute_scale_f32_instances = std::tuple<
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F32
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
#ifdef CK_ENABLE_FP8
template
<
index_t
NDims
,
typename
ElementwiseOp
>
using
device_permute_scale_f32_f8_instances
=
std
::
tuple
<
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
32
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
16
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
16
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
4
>
,
ck
::
Sequence
<
4
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
256
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
256
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
256
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
256
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
128
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
128
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
64
,
32
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
64
,
8
,
8
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
64
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
128
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
256
,
32
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
64
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
32
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
16
,
128
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
128
,
128
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
32
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
16
,
64
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
64
,
64
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
32
,
16
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
,
DeviceElementwiseImpl
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
ElementwiseOp
,
NDims
,
32
,
16
,
32
,
4
,
4
,
ck
::
Sequence
<
1
,
0
>
,
ck
::
Sequence
<
1
>
,
ck
::
Sequence
<
1
>>
>
;
#endif
// clang-format on
}
// namespace instance
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.hpp
View file @
4885c38a
...
...
@@ -14,15 +14,24 @@ namespace device {
namespace
instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
extern
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
// clang-format on
}
// namespace instance
...
...
library/include/ck/library/utility/check_err.hpp
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -272,7 +272,8 @@ check_err(const Range& out,
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
" number of errors: "
<<
err_count
<<
std
::
endl
;
}
return
res
;
}
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
View file @
4885c38a
...
...
@@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_CONVSCALE
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
)
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_instance
${
GROUPED_CONV3D_FWD_CONVSCALE
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScale
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
CombConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
CombConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
CombConvScale
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
View file @
4885c38a
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_CONVSCALE_RELU
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
)
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_relu_instance
${
GROUPED_CONV3D_FWD_CONVSCALE_RELU
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instance.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_combconvscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_f8_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
CombConvScaleRelu
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
CombConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_f8_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
CombConvScaleRelu
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
View file @
4885c38a
...
...
@@ -3,15 +3,13 @@
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation
/gpu/element/unary_element_wise_operation
.hpp"
#include "ck/
library/
tensor_operation
_instance/gpu/grouped_convolution_forward_convscale_relu
.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
ConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ConvScaleRelu
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
...
...
@@ -56,7 +54,6 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_in
ConvFwd1x1S1P0
,
ConvScaleRelu
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt
View file @
4885c38a
add_instance_library
(
device_permute_scale_instance
add_instance_library
(
device_permute_scale_instance
device_permute_scale_1d_fp16_instances.cpp
device_permute_scale_2d_fp16_instances.cpp
device_permute_scale_3d_fp16_instances.cpp
...
...
@@ -10,4 +10,5 @@ add_instance_library(device_permute_scale_instance
device_permute_scale_3d_fp32_instances.cpp
device_permute_scale_4d_fp32_instances.cpp
device_permute_scale_5d_fp32_instances.cpp
device_permute_scale_6d_fp32_instances.cpp
)
device_permute_scale_6d_fp32_instances.cpp
device_permute_scale_6d_fp32_fp8_instances.cpp
)
library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_fp8_instances.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
Scale
=
element_wise
::
Scale
;
void
add_device_permute_scale_6d_f32_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceElementwise
<
ck
::
Tuple
<
F32
>
,
ck
::
Tuple
<
F8
>
,
Scale
,
6
>>>&
instances
)
{
#ifdef CK_ENABLE_FP8
add_device_operation_instances
(
instances
,
device_permute_scale_f32_f8_instances
<
6
,
Scale
>
{});
#else
ignore
=
instances
;
#endif
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp
View file @
4885c38a
...
...
@@ -10,15 +10,24 @@ namespace device {
namespace
instance
{
// clang-format off
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
// InDataType | AccDataType | OutDataType | Rank | NumReduceDim | ReduceOperation | InElementwiseOp | AccElementwiseOp | PropagateNan | UseIndex
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
1
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
false
,
true
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
6
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
5
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
4
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
6
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
5
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
4
,
3
,
ReduceAMax
,
UnaryAbs
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
3
,
3
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
2
,
2
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
template
void
add_device_reduce_instance_blockwise
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>(
std
::
vector
<
DeviceReducePtr
<
F32
,
F32
,
F32
,
1
,
1
,
ReduceAMax
,
PassThrough
,
PassThrough
,
true
,
false
>>&
);
// clang-format on
}
// namespace instance
...
...
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
View file @
4885c38a
...
...
@@ -136,9 +136,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
float
best_avg_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_avg_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
ck
::
index_t
best_split_k
=
1
;
// profile device Conv instances
bool
all_pass
=
true
;
...
...
@@ -167,99 +168,111 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy
(
conv_param
.
input_left_pads_
,
begin
(
input_left_pads
));
range_copy
(
conv_param
.
input_right_pads_
,
begin
(
input_right_pads
));
std
::
vector
<
ck
::
index_t
>
split_k_list
=
{
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
};
if
(
split_k
>
0
)
{
split_k_list
=
{
split_k
};
}
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
const
std
::
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
for
(
std
::
size_t
split_k_id
=
0
;
split_k_id
<
split_k_list
.
size
();
split_k_id
++
)
{
// using atomic add, so need to reset input
wei_device_buf
.
SetZero
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k_list
[
split_k_id
]);
const
std
::
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// using atomic add, so need to reset input
wei_device_buf
.
SetZero
();
std
::
size_t
flop
=
conv_param
.
GetFlops
();
std
::
size_t
num_btype
=
conv_param
.
GetByte
<
InDataType
,
WeiDataType
,
OutDataType
>
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
})
;
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
}
std
::
size_t
flop
=
conv_param
.
GetFlops
();
std
::
size_t
num_btype
=
conv_param
.
GetByte
<
InDataType
,
WeiDataType
,
OutDataType
>
();
if
(
do_verification
)
{
wei_device_buf
.
FromDevice
(
weight_device_result
.
mData
.
data
());
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
bool
pass
=
ck
::
utils
::
check_err
(
weight_device_result
,
weight_host_result
);
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", SplitK "
<<
split_k_list
[
split_k_id
]
<<
std
::
endl
;
if
(
!
pas
s
)
if
(
tflops
>
best_tflop
s
)
{
std
::
cout
<<
"Fail info: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
best_split_k
=
split_k_list
[
split_k_id
];
}
all_pass
&=
pass
;
if
(
do_log
)
if
(
do_verification
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"output : "
,
output
.
mData
,
","
)
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (device): "
,
weight_device_result
.
mData
,
","
)
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (host): "
,
weight_host_result
.
mData
,
","
)
<<
std
::
endl
;
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"input: "
,
input
.
mData
,
","
)
<<
std
::
endl
;
;
wei_device_buf
.
FromDevice
(
weight_device_result
.
mData
.
data
());
bool
pass
=
ck
::
utils
::
check_err
(
weight_device_result
,
weight_host_result
);
if
(
!
pass
)
{
std
::
cout
<<
"Fail info: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
}
all_pass
&=
pass
;
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"output : "
,
output
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (device): "
,
weight_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"weight (host): "
,
weight_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"input: "
,
input
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
}
std
::
cout
<<
"Best configuration parameters:"
<<
"
\n
name: "
<<
best_op_name
<<
"
\n
avg_time: "
<<
best_avg_time
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
std
::
endl
;
<<
"
\n
tflops: "
<<
best_tflops
<<
"
\n
GB/s: "
<<
best_gb_per_sec
<<
", SplitK "
<<
best_split_k
<<
std
::
endl
;
return
all_pass
;
}
...
...
profiler/src/profile_grouped_conv_bwd_weight.cpp
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
...
...
@@ -81,7 +81,6 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
const
auto
params
=
ck
::
utils
::
conv
::
parse_conv_param
(
num_dim_spatial
,
9
,
argv
);
ck
::
index_t
split_k
=
std
::
stoi
(
argv
[
8
+
1
+
4
+
6
*
num_dim_spatial
]);
split_k
=
std
::
max
(
1
,
split_k
);
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
script/convert_miopen_driver_to_profiler.py
0 → 100644
View file @
4885c38a
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# Convert miopen driver command to ck Profiler
# Example: python3 ../script/convert_miopen_driver_to_profiler.py
# /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3
# -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -m conv -g 32 -F 1 -t 1
import
argparse
import
subprocess
def
init_const_args
(
args
):
args
.
ck_profiler_cmd
=
'../build/bin/ckProfiler'
# use decimal values
args
.
init_method
=
2
# don't print tensor values
args
.
log_value
=
0
def
run_ck_profiler_cmd
(
cmd
):
print
(
"ckProfiler command:"
)
print
(
cmd
)
subprocess
.
run
(
cmd
)
def
parse_data_type
(
args
):
if
args
.
data_type
==
"fp32"
:
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
or
\
args
.
ck_profier_op
==
"grouped_conv_bwd_data"
or
\
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
data_type
=
0
if
args
.
data_type
==
"fp16"
:
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
or
\
args
.
ck_profier_op
==
"grouped_conv_bwd_data"
or
\
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
data_type
=
1
if
args
.
data_type
==
"int8"
:
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
:
args
.
data_type
=
4
if
args
.
ck_profier_op
==
"grouped_conv_bwd_data"
:
print
(
'Not supported data type for grouped_conv_bwd_data'
)
exit
(
1
)
if
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
data_type
=
3
if
args
.
data_type
==
"bfp16"
:
if
args
.
ck_profier_op
==
"grouped_conv_bwd_weight"
or
\
args
.
ck_profier_op
==
"grouped_conv_bwd_data"
or
\
args
.
ck_profier_op
==
"grouped_conv_fwd"
:
args
.
data_type
=
2
def
add_conv_params_to_cmd
(
args
,
cmd
):
if
args
.
spatial_dim
==
1
:
cmd
+=
[
str
(
args
.
fil_w
),
str
(
args
.
in_w
)]
cmd
+=
[
str
(
args
.
conv_stride_w
),
str
(
args
.
dilation_w
)]
cmd
+=
[
str
(
args
.
pad_w
),
str
(
args
.
pad_w
)]
elif
args
.
spatial_dim
==
2
:
cmd
+=
[
str
(
args
.
fil_h
),
str
(
args
.
fil_w
)]
cmd
+=
[
str
(
args
.
in_h
),
str
(
args
.
in_w
)]
cmd
+=
[
str
(
args
.
conv_stride_h
),
str
(
args
.
conv_stride_w
)]
cmd
+=
[
str
(
args
.
dilation_h
),
str
(
args
.
dilation_w
)]
cmd
+=
[
str
(
args
.
pad_h
),
str
(
args
.
pad_w
)]
cmd
+=
[
str
(
args
.
pad_h
),
str
(
args
.
pad_w
)]
elif
args
.
spatial_dim
==
3
:
cmd
+=
[
str
(
args
.
fil_d
),
str
(
args
.
fil_h
),
str
(
args
.
fil_w
)]
cmd
+=
[
str
(
args
.
in_d
),
str
(
args
.
in_h
),
str
(
args
.
in_w
)]
cmd
+=
[
str
(
args
.
conv_stride_d
),
str
(
args
.
conv_stride_h
)]
cmd
+=
[
str
(
args
.
conv_stride_w
)]
cmd
+=
[
str
(
args
.
dilation_d
),
str
(
args
.
dilation_h
),
str
(
args
.
dilation_w
)]
cmd
+=
[
str
(
args
.
pad_d
),
str
(
args
.
pad_h
),
str
(
args
.
pad_w
)]
cmd
+=
[
str
(
args
.
pad_d
),
str
(
args
.
pad_h
),
str
(
args
.
pad_w
)]
else
:
print
(
'Not supported spatial dim (supported: 1, 2, 3)'
)
exit
(
1
)
def
run_ck_grouped_conv_fwd
(
args
):
args
.
ck_profier_op
=
"grouped_conv_fwd"
parse_data_type
(
args
)
# default for MIOpen NHWGC
args
.
layout
=
1
# use int32 by default
args
.
index_type
=
0
cmd
=
[
str
(
args
.
ck_profiler_cmd
),
str
(
args
.
ck_profier_op
)]
cmd
+=
[
str
(
args
.
data_type
),
str
(
args
.
layout
),
str
(
args
.
index_type
)]
cmd
+=
[
str
(
args
.
verify
),
str
(
args
.
init_method
)]
cmd
+=
[
str
(
args
.
log_value
),
str
(
args
.
time
)]
cmd
+=
[
str
(
args
.
spatial_dim
),
str
(
args
.
group_count
)]
cmd
+=
[
str
(
args
.
batchsize
),
str
(
args
.
out_channels
)]
cmd
+=
[
str
(
args
.
in_channels
)]
add_conv_params_to_cmd
(
args
,
cmd
)
run_ck_profiler_cmd
(
cmd
)
def
run_ck_grouped_conv_bwd_data
(
args
):
args
.
ck_profier_op
=
"grouped_conv_bwd_data"
parse_data_type
(
args
)
# default for MIOpen NHWGC
args
.
layout
=
1
cmd
=
[
str
(
args
.
ck_profiler_cmd
),
str
(
args
.
ck_profier_op
)]
cmd
+=
[
str
(
args
.
data_type
),
str
(
args
.
layout
)]
cmd
+=
[
str
(
args
.
verify
),
str
(
args
.
init_method
)]
cmd
+=
[
str
(
args
.
log_value
),
str
(
args
.
time
)]
cmd
+=
[
str
(
args
.
spatial_dim
),
str
(
args
.
group_count
)]
cmd
+=
[
str
(
args
.
batchsize
),
str
(
args
.
out_channels
)]
cmd
+=
[
str
(
args
.
in_channels
)]
add_conv_params_to_cmd
(
args
,
cmd
)
run_ck_profiler_cmd
(
cmd
)
def
run_ck_grouped_conv_bwd_weight
(
args
):
args
.
ck_profier_op
=
"grouped_conv_bwd_weight"
parse_data_type
(
args
)
# default for MIOpen NHWGC
args
.
layout
=
2
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
args
.
split_k_value
=
-
1
cmd
=
[
str
(
args
.
ck_profiler_cmd
),
str
(
args
.
ck_profier_op
)]
cmd
+=
[
str
(
args
.
data_type
),
str
(
args
.
layout
)]
cmd
+=
[
str
(
args
.
verify
),
str
(
args
.
init_method
)]
cmd
+=
[
str
(
args
.
log_value
),
str
(
args
.
time
)]
cmd
+=
[
str
(
args
.
spatial_dim
),
str
(
args
.
group_count
)]
cmd
+=
[
str
(
args
.
batchsize
),
str
(
args
.
out_channels
)]
cmd
+=
[
str
(
args
.
in_channels
)]
add_conv_params_to_cmd
(
args
,
cmd
)
cmd
+=
[
str
(
args
.
split_k_value
)]
run_ck_profiler_cmd
(
cmd
)
# Get name of miopen driver, remove it from unknown
def
process_miopen_driver_name
(
args
,
unknown
):
if
"convint8"
in
unknown
:
args
.
data_type
=
'int8'
unknown
.
remove
(
"convint8"
)
elif
"convbfp16"
in
unknown
:
args
.
data_type
=
'bfp16'
unknown
.
remove
(
"convbfp16"
)
elif
"convfp16"
in
unknown
:
args
.
data_type
=
'fp16'
unknown
.
remove
(
"convfp16"
)
elif
"conv"
in
unknown
:
args
.
data_type
=
'fp32'
unknown
.
remove
(
"conv"
)
else
:
print
(
'Not supported driver (supported: conv, convfp16, convint8,'
' convbfp16).'
)
exit
(
1
)
def
run_ck_profiler
(
args
):
# MIOpen get number of channel per all groups, CK profiler get number of
# channel per group
args
.
in_channels
=
int
(
args
.
in_channels
/
args
.
group_count
)
args
.
out_channels
=
int
(
args
.
out_channels
/
args
.
group_count
)
if
args
.
forw
==
0
or
args
.
forw
==
1
or
args
.
forw
==
3
or
args
.
forw
==
5
:
run_ck_grouped_conv_fwd
(
args
)
if
args
.
forw
==
0
or
args
.
forw
==
2
or
args
.
forw
==
3
or
args
.
forw
==
6
:
run_ck_grouped_conv_bwd_data
(
args
)
if
args
.
forw
==
0
or
args
.
forw
==
4
or
args
.
forw
==
5
or
args
.
forw
==
6
:
run_ck_grouped_conv_bwd_weight
(
args
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
"converter"
,
description
=
"Convert miopen driver command to ck Profiler"
"
\n
Example: python3 "
"../script/convert_miopen_driver_to_profiler.py "
"/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
"-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
"32 -F 1 -t 1"
,
)
parser
.
add_argument
(
"-in_layout"
,
"-I"
,
default
=-
1
,
type
=
int
,
required
=
False
,
help
=
"Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
)
parser
.
add_argument
(
"-forw"
,
"-F"
,
default
=
0
,
type
=
int
,
required
=
False
,
help
=
"Flag enables fwd, bwd, wrw convolutions"
"
\n
0 fwd+bwd+wrw (default)"
"
\n
1 fwd only"
"
\n
2 bwd only"
"
\n
4 wrw only"
"
\n
3 fwd+bwd"
"
\n
5 fwd+wrw"
"
\n
6 bwd+wrw"
)
parser
.
add_argument
(
"-spatial_dim"
,
"-_"
,
default
=
2
,
type
=
int
,
required
=
False
,
help
=
"convolution spatial dimension (Default-2)"
)
parser
.
add_argument
(
"-batchsize"
,
"-n"
,
default
=
100
,
type
=
int
,
required
=
False
,
help
=
"Mini-batch size (Default=100)"
)
parser
.
add_argument
(
"-in_channels"
,
"-c"
,
default
=
3
,
type
=
int
,
required
=
False
,
help
=
"Number of Input Channels (Default=3)"
)
parser
.
add_argument
(
"-in_d"
,
"-!"
,
default
=
32
,
type
=
int
,
required
=
False
,
help
=
"Input Depth (Default=32)"
)
parser
.
add_argument
(
"-in_h"
,
"-H"
,
default
=
32
,
type
=
int
,
required
=
False
,
help
=
"Input Height (Default=32)"
)
parser
.
add_argument
(
"-in_w"
,
"-W"
,
default
=
32
,
type
=
int
,
required
=
False
,
help
=
"Input Width (Default=32)"
)
parser
.
add_argument
(
"-out_channels"
,
"-k"
,
default
=
32
,
type
=
int
,
required
=
False
,
help
=
"Number of Output Channels (Default=32)"
)
parser
.
add_argument
(
"-fil_d"
,
"-@"
,
default
=
3
,
type
=
int
,
required
=
False
,
help
=
"Filter Depth (Default=3)"
)
parser
.
add_argument
(
"-fil_h"
,
"-y"
,
default
=
3
,
type
=
int
,
required
=
False
,
help
=
"Filter Height (Default=3)"
)
parser
.
add_argument
(
"-fil_w"
,
"-x"
,
default
=
3
,
type
=
int
,
required
=
False
,
help
=
"Filter Width (Default=3)"
)
parser
.
add_argument
(
"-conv_stride_d"
,
"-#"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Convolution Stride for Depth (Default=1)"
)
parser
.
add_argument
(
"-conv_stride_h"
,
"-u"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Convolution Stride for Height (Default=1)"
)
parser
.
add_argument
(
"-conv_stride_w"
,
"-v"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Convolution Stride for Width (Default=1)"
)
parser
.
add_argument
(
"-pad_d"
,
"-$"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Zero Padding for Depth (Default=0)"
)
parser
.
add_argument
(
"-pad_h"
,
"-p"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Zero Padding for Height (Default=0)"
)
parser
.
add_argument
(
"-pad_w"
,
"-q"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Zero Padding for Width (Default=0)"
)
parser
.
add_argument
(
"-verify"
,
"-V"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Verify Each Layer (Default=1)"
)
parser
.
add_argument
(
"-time"
,
"-t"
,
default
=
0
,
type
=
int
,
required
=
False
,
help
=
"Time Each Layer (Default=0)"
)
parser
.
add_argument
(
"-dilation_d"
,
"-^"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Dilation of Filter Depth (Default=1)"
)
parser
.
add_argument
(
"-dilation_h"
,
"-l"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Dilation of Filter Height (Default=1)"
)
parser
.
add_argument
(
"-dilation_w"
,
"-j"
,
default
=
1
,
type
=
int
,
required
=
False
,
help
=
"Dilation of Filter Width (Default=1)"
)
parser
.
add_argument
(
"-group_count"
,
"-g"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"Number of Groups (Default=1)"
)
args
,
unknown
=
parser
.
parse_known_args
()
init_const_args
(
args
)
process_miopen_driver_name
(
args
,
unknown
)
print
(
"Ignored args:"
)
print
(
unknown
)
run_ck_profiler
(
args
)
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment