Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
83074f4c
Commit
83074f4c
authored
Nov 11, 2024
by
Astha Rai
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
7c56cd01
5fb150db
Changes
102
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
581 additions
and
19 deletions
+581
-19
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
...ed_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
+25
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
...v_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
+12
-8
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
...fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
+19
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
...ped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
+37
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
...wd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
+19
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+16
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
..._instance/gpu/grouped_convolution_backward_weight_xdl.inc
+100
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+43
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
...ion_instance/gpu/grouped_convolution_forward_comp_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
...peration_instance/gpu/grouped_convolution_forward_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
...ance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
+16
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
...nce/gpu/grouped_convolution_forward_xdl_merged_groups.inc
+30
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+4
-4
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+5
-5
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+4
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
..._two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
...t_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
..._two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
+41
-0
No files found.
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
View file @
83074f4c
...
@@ -131,6 +131,31 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple<
...
@@ -131,6 +131,31 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_int8_comp_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
// AGPR Spill when use permuted lds layout. so, use padding for these two.
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
View file @
83074f4c
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
View file @
83074f4c
...
@@ -87,6 +87,25 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
...
@@ -87,6 +87,25 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_large_tensor_int8_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
View file @
83074f4c
...
@@ -154,6 +154,43 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
...
@@ -154,6 +154,43 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
>
using
device_grouped_conv_fwd_xdl_int8_mem_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
64
,
16
,
64
,
8
,
8
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
64
,
64
,
8
,
8
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
128
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
16
,
256
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
View file @
83074f4c
...
@@ -90,6 +90,25 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
...
@@ -90,6 +90,25 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_int8_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
32
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
83074f4c
...
@@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances
(
...
@@ -375,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -375,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances
(
...
@@ -390,6 +394,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -390,6 +394,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
ComputeTypeB
,
half_t
>
)
is_same_v
<
ComputeTypeB
,
half_t
>
)
{
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
...
@@ -403,6 +409,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -403,6 +409,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances
(
...
@@ -464,6 +472,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -464,6 +472,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances
(
...
@@ -487,6 +497,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -487,6 +497,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
...
@@ -511,6 +523,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -511,6 +523,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
ComputeTypeB
,
half_t
>
)
is_same_v
<
ComputeTypeB
,
half_t
>
)
{
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
...
@@ -524,6 +538,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -524,6 +538,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
op_ptrs
);
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
View file @
83074f4c
...
@@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
...
@@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
NHWGC
,
...
@@ -136,6 +148,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
...
@@ -136,6 +148,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
NGCHW
,
...
@@ -173,6 +198,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
...
@@ -173,6 +198,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
NHWGC
,
...
@@ -196,6 +233,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
...
@@ -196,6 +233,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
NGCHW
,
...
@@ -298,6 +348,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
...
@@ -298,6 +348,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
NDHWGC
,
...
@@ -321,6 +383,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
...
@@ -321,6 +383,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
NGCDHW
,
...
@@ -358,6 +433,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
...
@@ -358,6 +433,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
NDHWGC
,
...
@@ -381,6 +468,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
...
@@ -381,6 +468,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
NGCDHW
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
83074f4c
...
@@ -122,6 +122,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -122,6 +122,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif // DL_KERNELS
#endif // DL_KERNELS
#ifdef CK_USE_XDL
#ifdef CK_USE_XDL
// 1D
// layout GNWC/GKXC/GNWK
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
GNWC
>
&&
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
GNWC
>
&&
is_same_v
<
WeiLayout
,
GKXC
>
&&
is_same_v
<
OutLayout
,
GNWK
>
)
is_same_v
<
WeiLayout
,
GKXC
>
&&
is_same_v
<
OutLayout
,
GNWK
>
)
{
{
...
@@ -160,7 +162,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -160,7 +162,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
// 2D
// layout GNHWC/GKYXC/GNHWK
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
{
...
@@ -191,7 +194,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -191,7 +194,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
// layout NHWGC/GKYXC/NHWGK
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
{
{
...
@@ -247,8 +250,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -247,8 +250,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
(
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
(
op_ptrs
);
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
&&
is_same_v
<
AComputeType
,
int8_t
>
&&
is_same_v
<
BComputeType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
(
op_ptrs
);
}
#endif
#endif
}
}
// layout NGCHW/GKYXC/NGKHW
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NGCHW
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NGCHW
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NGKHW
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NGKHW
>
)
{
{
...
@@ -282,8 +304,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -282,8 +304,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
&&
is_same_v
<
AComputeType
,
int8_t
>
&&
is_same_v
<
BComputeType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances
(
op_ptrs
);
}
#endif
}
}
// 3D
// layout GNDHWC/GKZYXC/GNDHWK
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
GNDHWC
>
&&
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
GNDHWK
>
)
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
GNDHWK
>
)
{
{
...
@@ -323,6 +363,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -323,6 +363,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
View file @
83074f4c
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
View file @
83074f4c
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
View file @
83074f4c
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
View file @
83074f4c
...
@@ -171,6 +171,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
...
@@ -171,6 +171,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances
(
...
@@ -204,6 +220,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
...
@@ -204,6 +220,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
View file @
83074f4c
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
View file @
83074f4c
...
@@ -85,6 +85,36 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
...
@@ -85,6 +85,36 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/include/ck/library/utility/check_err.hpp
View file @
83074f4c
...
@@ -24,7 +24,7 @@ namespace ck {
...
@@ -24,7 +24,7 @@ namespace ck {
namespace
utils
{
namespace
utils
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number
OfA
ccumulations
=
1
)
double
get_relative_threshold
(
const
int
number
_of_a
ccumulations
=
1
)
{
{
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
...
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
}
}
else
else
{
{
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
OfA
ccumulations
;
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
_of_a
ccumulations
;
}
}
return
std
::
max
(
acc_error
,
midway_error
);
return
std
::
max
(
acc_error
,
midway_error
);
}
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number
OfA
ccumulations
=
1
)
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number
_of_a
ccumulations
=
1
)
{
{
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
...
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
else
else
{
{
acc_error
=
acc_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
OfA
ccumulations
;
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
number
_of_a
ccumulations
;
}
}
return
std
::
max
(
acc_error
,
midway_error
);
return
std
::
max
(
acc_error
,
midway_error
);
}
}
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
83074f4c
...
@@ -88,19 +88,19 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -88,19 +88,19 @@ function(add_instance_library INSTANCE_NAME)
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
(
source MATCHES
"_xdl"
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906
gfx906:xnack-
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
source MATCHES
"_wmma"
)
elseif
(
source MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906
gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack-
gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
source MATCHES
"mha"
)
elseif
(
source MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906
gfx906:xnack- gfx908:xnack- gfx908:xnack+
gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
endif
()
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"f8"
)
if
(
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906
gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack-
gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
endif
()
if
(
source MATCHES
"gemm_multiply_multiply_f8"
)
if
(
source MATCHES
"gemm_multiply_multiply_f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906
gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack-
gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
endif
()
endif
()
endif
()
set
(
offload_targets
)
set
(
offload_targets
)
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
83074f4c
...
@@ -15,6 +15,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT
...
@@ -15,6 +15,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
)
)
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
0 → 100644
View file @
83074f4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v1
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
0 → 100644
View file @
83074f4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v1
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
0 → 100644
View file @
83074f4c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v1
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment