gaoqiong / composable_kernel_ROCM

Commit ebb5522c, authored Nov 12, 2024 by Mateusz Ozga; committed by root on Dec 16, 2024

Apply cshuffle to bwd_weight_cshuffle operator
parent fdfe2102

Showing 20 changed files with 2437 additions and 536 deletions (+2437, -536)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp (+25, -25)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp (+10, -10)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp (+26, -23)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (+866, -217)
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp (+269, -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp (+102, -125)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (+128, -18)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc (+676, -78)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt (+13, -3)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instance.cpp (+5, -13)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instance.cpp (+39, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instance.cpp (+39, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instance.cpp (+39, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instance.cpp (+38, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instance.cpp (+38, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instance.cpp (+5, -12)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instance.cpp (+38, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instance.cpp (+38, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instance.cpp (+38, -0)
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instance.cpp (+5, -12)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
...
@@ -39,33 +39,33 @@ using DeviceConvBwdWeightInstance =
         WeiElementOp,         // WeiElementwiseOperation
         OutElementOp,         // OutElementwiseOperation
         ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
-        256,                  // BlockSize
-        128,                  // MPerBlock
-        128,                  // NPerBlock
-        4,                    // K0PerBlock
+        64,                   // BlockSize
+        16,                   // MPerBlock
+        16,                   // NPerBlock
+        32,                   // K0PerBlock
         8,                    // K1
-        32,                   // MPerXdl
-        32,                   // NPerXdl
-        2,                    // MXdlPerWave
-        2,                    // NXdlPerWave
-        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
-        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
-        2,                    // ABlockTransferSrcVectorDim
-        8,                    // ABlockTransferSrcScalarPerVector
-        2,                    // ABlockTransferDstScalarPerVector_K1
-        true,                 // ABlockLdsAddExtraM
-        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
-        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
-        2,                    // BBlockTransferSrcVectorDim
-        8,                    // BBlockTransferSrcScalarPerVector
-        2,                    // BBlockTransferDstScalarPerVector_K1
-        true,                 // BBlockLdsAddExtraN
+        16,                   // MPerXdl
+        16,                   // NPerXdl
+        1,                    // MXdlPerWave
+        1,                    // NXdlPerWave
+        S<4, 16, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<2, 0, 1>,           // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
+        1,                    // ABlockTransferSrcVectorDim
+        1,                    // ABlockTransferSrcScalarPerVector
+        4,                    // ABlockTransferDstScalarPerVector_K1
+        false,                // ABlockLdsAddExtraM
+        S<4, 16, 1>,          // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<2, 0, 1>,           // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // BBlockTransferSrcAccessOrder
+        1,                    // BBlockTransferSrcVectorDim
+        1,                    // BBlockTransferSrcScalarPerVector
+        4,                    // BBlockTransferDstScalarPerVector_K1
+        false,                // BBlockLdsAddExtraN
         1,                    // CShuffleMXdlPerWavePerShuffle
         1,                    // CShuffleNXdlPerWavePerShuffle
-        S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-        128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
+        S<1, 8, 1, 8>,        // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        1>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
 template <ck::index_t NDimSpatial>
 using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
...
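The one non-numeric change above swaps a computed vector width for the literal 1. A minimal standalone sketch (editorial, not part of the commit; uint16_t stands in for ck::bhalf_t, which is also 16 bits wide) of what the old expression evaluated to:

// Sketch: the old CBlockTransferScalarPerVector_NWaveNPerXdl was
// "128-bit store width divided by the element width in bits".
#include <climits>
#include <cstdint>
#include <cstdio>

int main()
{
    using WeiDataType = std::uint16_t; // stand-in for ck::bhalf_t (16-bit)
    constexpr int old_scalar_per_vector = 128 / (sizeof(WeiDataType) * CHAR_BIT);
    static_assert(old_scalar_per_vector == 8, "8 bf16 elements fit in 128 bits");
    std::printf("old CBlockTransferScalarPerVector = %d, new = 1\n", old_scalar_per_vector);
    return 0;
}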
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "common.hpp"
...
@@ -41,26 +41,26 @@ using DeviceConvBwdWeightInstance =
         256,                  // BlockSize
         128,                  // MPerBlock
         128,                  // NPerBlock
-        4,                    // K0PerBlock
+        32,                   // K0PerBlock
         8,                    // K1
         32,                   // MPerXdl
         32,                   // NPerXdl
         2,                    // MXdlPerWave
         2,                    // NXdlPerWave
-        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
-        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
+        S<4, 16, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<2, 0, 1>,           // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
         2,                    // ABlockTransferSrcVectorDim
         8,                    // ABlockTransferSrcScalarPerVector
         2,                    // ABlockTransferDstScalarPerVector_K1
-        true,                 // ABlockLdsAddExtraM
-        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
-        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
+        false,                // ABlockLdsAddExtraM
+        S<4, 16, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<2, 0, 1>,           // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
         2,                    // BBlockTransferSrcVectorDim
         8,                    // BBlockTransferSrcScalarPerVector
         2,                    // BBlockTransferDstScalarPerVector_K1
-        true,                 // BBlockLdsAddExtraN
+        false,                // BBlockLdsAddExtraN
         1,                    // CShuffleMXdlPerWavePerShuffle
         1,                    // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
...
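The only numeric change in this instance is K0PerBlock. Assuming the usual CK convention that a block's K tile is K0PerBlock * K1 (an assumption, not stated in this diff), a quick editorial check:

// Sketch (not commit code): K tile per block = K0PerBlock * K1.
static_assert(4 * 8 == 32, "old K tile: 32 elements per block per pass");
static_assert(32 * 8 == 256, "new K tile: 256 elements, an 8x deeper K slice");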
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
...
@@ -40,35 +40,38 @@ using DeviceConvBwdWeightInstance =
         WeiElementOp,         // WeiElementwiseOperation
         OutElementOp,         // OutElementwiseOperation
         ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
-        256,                  // BlockSize
-        128,                  // MPerBlock
-        128,                  // NPerBlock
-        4,                    // K0PerBlock
+        64,                   // BlockSize
+        16,                   // MPerBlock
+        16,                   // NPerBlock
+        32,                   // K0PerBlock
         8,                    // K1
-        32,                   // MPerXdl
-        32,                   // NPerXdl
-        2,                    // MXdlPerWave
-        2,                    // NXdlPerWave
-        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
-        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
-        2,                    // ABlockTransferSrcVectorDim
+        16,                   // MPerXdl
+        16,                   // NPerXdl
+        1,                    // MXdlPerWave
+        1,                    // NXdlPerWave
+        S<4, 16, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+        S<2, 0, 1>,           // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
+        1,                    // ABlockTransferSrcVectorDim
         1,                    // ABlockTransferSrcScalarPerVector
-        1,                    // ABlockTransferDstScalarPerVector_K1
-        true,                 // ABlockLdsAddExtraM
-        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
-        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
-        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
-        2,                    // BBlockTransferSrcVectorDim
+        4,                    // ABlockTransferDstScalarPerVector_K1
+        false,                // ABlockLdsAddExtraM
+        S<4, 16, 1>,          // BBlockTransferThreadClusterLengths_K0_N_K1
+        S<2, 0, 1>,           // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,           // BBlockTransferSrcAccessOrder
+        1,                    // BBlockTransferSrcVectorDim
         1,                    // BBlockTransferSrcScalarPerVector
-        1,                    // BBlockTransferDstScalarPerVector_K1
-        true,                 // BBlockLdsAddExtraN
+        4,                    // BBlockTransferDstScalarPerVector_K1
+        false,                // BBlockLdsAddExtraN
         1,                    // CShuffleMXdlPerWavePerShuffle
         1,                    // CShuffleNXdlPerWavePerShuffle
-        S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        S<1, 8, 1, 8>,        // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         2,                    // CBlockTransferScalarPerVector_NWaveNPerXdl
-        ComputeTypeA,         // ComputeTypeA
-        ComputeTypeB>;        // ComputeTypeB
+        ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
+        ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
+        1,                    // NumGroupsToMerge
+        ComputeTypeA,         // ComputeTypeA
+        ComputeTypeB>;        // ComputeTypeB
 template <ck::index_t NDimSpatial>
 using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
This diff is collapsed.
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
...
@@ -34,6 +34,94 @@ struct TransformConvBwdWeightToGemmV2
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
+    constexpr static auto make_out_grid_desc(const index_t N,
+                                             const index_t Wo,
+                                             const index_t K,
+                                             const std::array<index_t, NDimSpatial + 3>& output_strides)
+    {
+        const index_t BatchStride = output_strides[0];
+        const index_t WoStride    = output_strides[3];
+        const auto KStride        = Number<1>{};
+        return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K),
+                                            make_tuple(WoStride, BatchStride, KStride));
+    }
+
+    template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
+    constexpr static auto make_in_grid_desc(const index_t N,
+                                            const index_t Wi,
+                                            const index_t C,
+                                            const std::array<index_t, NDimSpatial + 3>& input_strides)
+    {
+        const index_t BatchStride = input_strides[0];
+        const index_t NStride     = input_strides[1];
+        const index_t WiStride    = input_strides[3];
+        const auto CStride        = input_strides[2];
+        if constexpr(ConvBackwardWeightSpecialization ==
+                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
+        {
+            return make_naive_tensor_descriptor(make_tuple(N * Wi, NumGroupsToMerge, C),
+                                                make_tuple(WiStride, BatchStride, CStride));
+        }
+        else
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(N, Wi, NumGroupsToMerge, C),
+                make_tuple(NStride, WiStride, BatchStride, CStride));
+        }
+    }
+
+    template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
+    constexpr static auto make_wei_grid_desc(const index_t K,
+                                             const index_t X,
+                                             const index_t C,
+                                             const std::array<index_t, NDimSpatial + 3>& weights_strides)
+    {
+        const auto CStride     = Number<1>{};
+        const auto KStride     = weights_strides[1];
+        const auto XStride     = weights_strides[3];
+        const auto BatchStride = weights_strides[0];
+        // Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
+        // for the Batch+N dimension.
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(NumGroupsToMerge, K, X, 1, C),
+            make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
+        // Pad 1 up to NumGroupsToMerge.
+        const auto padded_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_pass_through_transform(NumGroupsToMerge),
+                       make_pass_through_transform(K),
+                       make_pass_through_transform(X),
+                       make_pad_transform(1, 0, NumGroupsToMerge - 1),
+                       make_pass_through_transform(C)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+        // We only need the matrices on the diagonal. Xor returns 0 for equal values,
+        // so any matrix that is not on the diagonal ends up stored in the padding.
+        // To avoid a modulo after the xor, NumGroupsToMerge must be a power of 2.
+        static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+                      NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+                      NumGroupsToMerge == 64);
+        const auto unmerged_padded_desc = transform_tensor_descriptor(
+            padded_desc,
+            make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
+                       make_pass_through_transform(K),
+                       make_pass_through_transform(X),
+                       make_pass_through_transform(C)),
+            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
+            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
+        // Merge to M, N.
+        return transform_tensor_descriptor(
+            unmerged_padded_desc,
+            make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
+                       make_merge_transform(make_tuple(X, NumGroupsToMerge, C))),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+    }
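The xor trick above is easiest to see on a small grid. A standalone editorial sketch (not part of the commit): printing m ^ n over a 4x4 block grid shows 0 exactly on the diagonal, so only diagonal blocks map to real data while every off-diagonal block is routed into the padded region.

#include <cstdio>

int main()
{
    constexpr int NumGroupsToMerge = 4; // power of two, as the static_assert requires

    // m indexes the Batch+M block, n the Batch+N block; the xor transform
    // stores block (m, n) at coordinate m ^ n, which is 0 iff m == n.
    for(int m = 0; m < NumGroupsToMerge; ++m)
    {
        for(int n = 0; n < NumGroupsToMerge; ++n)
            std::printf("%d ", m ^ n); // first row: 0 1 2 3, second: 1 0 3 2, ...
        std::printf("\n");
    }
    return 0;
}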
     template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
     constexpr static auto
     make_out_grid_desc(const index_t N,
...
@@ -221,6 +309,187 @@ struct TransformConvBwdWeightToGemmV2
                            make_tuple(Sequence<0>{}, Sequence<1>{}));
     }
+    template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
+    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        const index_t N,
+        const index_t K,
+        const index_t C,
+        const std::array<index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& input_strides,
+        const std::array<index_t, NDimSpatial + 3>& weights_strides,
+        const std::array<index_t, NDimSpatial + 3>& output_strides,
+        const std::array<index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<index_t, NDimSpatial>& input_left_pads,
+        const std::array<index_t, NDimSpatial>& input_right_pads,
+        const index_t batch_k)
+    {
+        using namespace ck;
+
+        const index_t Wi = input_spatial_lengths[0];
+        const index_t Wo = output_spatial_lengths[0];
+        const index_t X  = filter_spatial_lengths[0];
+
+        const index_t ConvStrideW   = conv_filter_strides[0];
+        const index_t ConvDilationW = conv_filter_dilations[0];
+        const index_t InLeftPadW    = input_left_pads[0];
+        const index_t InRightPadW   = input_right_pads[0];
+
+        const index_t GemmKTotal = N * Wo;
+        const index_t GemmM      = K * NumGroupsToMerge;
+        const index_t GemmN      = C * X * NumGroupsToMerge;
+
+        const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
+        const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
+
+        const index_t GemmKBatch = batch_k;
+        const index_t GemmK0 =
+            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
+            K0PerBlock;
+        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
+
+        const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
+        const auto in_grid_desc  = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
+        const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, X, C, weights_strides);
+
+        if constexpr(ConvBackwardWeightSpecialization ==
+                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
+        {
+            // A: output tensor
+            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
+                out_grid_desc,
+                make_tuple(
+                    make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+                    make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+                out_gemmkpad_gemmm_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
+                           make_pass_through_transform(GemmM)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            // B: input tensor
+            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
+                in_grid_desc,
+                make_tuple(
+                    make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+                    make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_gemmkpad_gemmn_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
+                           make_pass_through_transform(GemmN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+                              wei_grid_desc);
+        }
+        else
+        {
+            // A: output tensor
+            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
+                out_grid_desc,
+                make_tuple(
+                    make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+                    make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+                out_gemmkpad_gemmm_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
+                           make_pass_through_transform(GemmM)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            // B: input tensor
+            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
+                in_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(X, Wo),
+                                                make_tuple(ConvDilationW, ConvStrideW)),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
+            const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
+                in_n_y_ho_x_wo_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(X, NumGroupsToMerge, C)),
+                           make_merge_transform(make_tuple(N, Wo))),
+                make_tuple(Sequence<1, 3, 4>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<1>{}, Sequence<0>{}));
+            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
+                in_gemmktotal_gemmn_grid_desc,
+                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
+                           make_pass_through_transform(GemmN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_gemmkpad_gemmn_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
+                           make_pass_through_transform(GemmN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            // Pad
+            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
+                transform_tensor_descriptor(
+                    out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                    make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
+                               make_right_pad_transform(GemmM, PadGemmM),
+                               make_pass_through_transform(GemmK1Number)),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
+                transform_tensor_descriptor(
+                    in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+                    make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
+                               make_right_pad_transform(GemmN, PadGemmN),
+                               make_pass_through_transform(GemmK1Number)),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+            const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(
+                wei_grid_desc,
+                make_tuple(make_right_pad_transform(GemmM, PadGemmM),
+                           make_right_pad_transform(GemmN, PadGemmN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
+                              in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
+                              wei_gemmm_gemmn_pad_grid_desc);
+        }
+    } // function end
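The GemmK0 / GemmKPad computation above rounds the reduction length up so it divides evenly across the batch_k splits and K0PerBlock tiles. A worked example (editorial sketch with made-up sizes; integer_divide_ceil expanded inline):

#include <cstdio>

int main()
{
    // Illustrative 1-D sizes: GemmKTotal = N * Wo.
    const int GemmKTotal = 1000;
    const int GemmK1     = 8;  // GemmK1Number
    const int K0PerBlock = 32;
    const int GemmKBatch = 2;  // batch_k

    // math::integer_divide_ceil(GemmKTotal, GemmK1 * K0PerBlock * GemmKBatch) * K0PerBlock
    const int chunk    = GemmK1 * K0PerBlock * GemmKBatch;              // 512
    const int GemmK0   = (GemmKTotal + chunk - 1) / chunk * K0PerBlock; // ceil(1000/512)*32 = 64
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;                  // 2*64*8 = 1024 >= 1000

    std::printf("GemmK0 = %d, GemmKPad = %d (padded up from %d)\n",
                GemmK0, GemmKPad, GemmKTotal);
    return 0;
}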
     template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
         const index_t N,
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
This diff is collapsed.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...
@@ -276,7 +276,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
                      is_same_v<ComputeTypeB, float>)
         {
-            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_FP16
...
@@ -284,7 +291,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
                      is_same_v<ComputeTypeB, half_t>)
         {
-            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_BF16
...
@@ -293,7 +307,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -309,7 +329,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
                      is_same_v<ComputeTypeB, float>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -318,7 +344,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
                      is_same_v<ComputeTypeB, half_t>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev2_instances(
+                op_ptrs);
         }
 #endif
...
@@ -328,7 +360,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -341,7 +379,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
                      is_same_v<ComputeTypeB, float>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -350,7 +394,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
                      is_same_v<ComputeTypeB, half_t>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instances(
+                op_ptrs);
             add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
                 op_ptrs);
...
@@ -366,7 +416,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_pad0_pipev5_instances(
+                op_ptrs);
         }
         if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
...
@@ -375,7 +431,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instances(
+                op_ptrs);
             add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
                 op_ptrs);
...
@@ -429,7 +491,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
                      is_same_v<ComputeTypeB, float>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -438,7 +506,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
                      is_same_v<ComputeTypeB, half_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -448,7 +522,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
@@ -461,7 +541,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
                      is_same_v<ComputeTypeB, float>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
+                op_ptrs);
         }
 #endif
...
@@ -470,7 +556,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
                      is_same_v<ComputeTypeB, half_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instances(
+                op_ptrs);
             add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
                 op_ptrs);
...
@@ -486,7 +578,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_pad0_pipev5_instances(
+                op_ptrs);
         }
         if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
...
@@ -495,7 +593,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<ComputeTypeA, ck::bhalf_t> &&
                      is_same_v<ComputeTypeB, ck::bhalf_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instances(
+                op_ptrs);
             add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
                 op_ptrs);
...
@@ -510,7 +614,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
                      is_same_v<ComputeTypeB, f8_t>)
         {
-            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
-                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_default_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_default_pipev5_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_pad0_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_pad0_pipev5_instances(
+                op_ptrs);
         }
 #endif
...
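For context, client code reaches the registrations above through the instance factory. A hedged usage sketch (the type list mirrors the DeviceGroupedConvBwdWeight signature used throughout this diff; treat the exact header path and the GetInstances() spelling as assumptions about the CK client API):

#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

int main()
{
    using namespace ck::tensor_operation;
    using PassThrough = element_wise::PassThrough;

    // 1-D grouped conv, GNWC/GKXC/GNWK layouts, f32 throughout - the same
    // signature whose factory branch above now registers the pipev2/pipev5
    // default and pad0 variants.
    using DeviceOp = device::DeviceGroupedConvBwdWeight<1,
                                                        ck::tensor_layout::convolution::GNWC,
                                                        ck::tensor_layout::convolution::GKXC,
                                                        ck::tensor_layout::convolution::GNWK,
                                                        float,
                                                        float,
                                                        float,
                                                        PassThrough,
                                                        PassThrough,
                                                        PassThrough>;

    const auto op_ptrs =
        device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    return op_ptrs.empty() ? 1 : 0; // non-empty once the new instances are linked in
}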
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
This diff is collapsed.
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt
 # ONLY XDL_AND_DL_KERNELS
 set(GROUPED_CONV1D_BWD_WEIGHT
-    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
-    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
-    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instance.cpp
+    xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev5_instance.cpp
 )
 if(DL_KERNELS)
     list(APPEND GROUPED_CONV1D_BWD_WEIGHT
...
@@ -15,3 +24,4 @@ if(DL_KERNELS)
 endif()
 add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT})
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp
→ library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instance.cpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
+void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                            GNWC,
                                                            GKXC,
...
@@ -21,7 +21,6 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances)
 {
-    // 1. Default
     add_device_operation_instances(
         instances,
         device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
...
@@ -29,16 +28,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
             GNWC,
             GKXC,
             GNWK,
-            ConvBwdWeightDefault>{});
-    // 2. Filter1x1Stride1Pad0
-    add_device_operation_instances(
-        instances,
-        device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
-            1,
-            GNWC,
-            GKXC,
-            GNWK,
-            ConvBwdWeightFilter1x1Stride1Pad0>{});
+            ConvBwdWeightDefault,
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v2>{});
 }
 } // namespace instance
...
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instance.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
            1,
            GNWC,
            GKXC,
            GNWK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instance.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
            1,
            GNWC,
            GKXC,
            GNWK,
            ConvBwdWeightFilter1x1Stride1Pad0,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instance.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           BF16,
                                                           F32,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
            1,
            GNWC,
            GKXC,
            GNWK,
            ConvBwdWeightFilter1x1Stride1Pad0,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instance.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
            1,
            GNWC,
            GKXC,
            GNWK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instance.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
            1,
            GNWC,
            GKXC,
            GNWK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
→ library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instance.cpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
+void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                            GNWC,
                                                            GKXC,
@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances)
 {
-    // 1. Default
-    add_device_operation_instances(
-        instances,
-        device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1,
-                                                                   GNWC,
-                                                                   GKXC,
-                                                                   GNWK,
-                                                                   ConvBwdWeightDefault>{});
-    // 2. Filter1x1Stride1Pad0
     add_device_operation_instances(
         instances,
         device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1,
                                                                    GNWC,
                                                                    GKXC,
                                                                    GNWK,
-                                                                   ConvBwdWeightFilter1x1Stride1Pad0>{});
+                                                                   ConvBwdWeightFilter1x1Stride1Pad0,
+                                                                   BlockGemmPipelineScheduler::Intrawave,
+                                                                   BlockGemmPipelineVersion::v2>{});
 }
...
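Note the behavioral split in the rename above: the old single f16 entry point registered both the Default and the Filter1x1Stride1Pad0 specializations, whereas each renamed or new entry point now registers exactly one specialization/pipeline pair. A hypothetical aggregator (not in this commit) could restore the old all-in-one behavior by chaining the four new f16 functions:

// Hypothetical wrapper, not part of this commit; uses only entry points added here.
template <typename OpPtrs> // the std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<...>>> from above
void add_all_f16_gnwc_instances(OpPtrs& instances)
{
    add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instances(instances);
    add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instances(instances);
    add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instances(instances);
    add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instances(instances);
}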
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1,
                                                                   GNWC,
                                                                   GKXC,
                                                                   GNWK,
                                                                   ConvBwdWeightFilter1x1Stride1Pad0,
                                                                   BlockGemmPipelineScheduler::Intrawave,
                                                                   BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F32,
                                                           F32,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
                                                                   GNWC,
                                                                   GKXC,
                                                                   GNWK,
                                                                   ConvBwdWeightDefault,
                                                                   BlockGemmPipelineScheduler::Intrawave,
                                                                   BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                           GNWC,
                                                           GKXC,
                                                           GNWK,
                                                           F32,
                                                           F32,
                                                           F32,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
                                                                   GNWC,
                                                                   GKXC,
                                                                   GNWK,
                                                                   ConvBwdWeightDefault,
                                                                   BlockGemmPipelineScheduler::Intrawave,
                                                                   BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp → library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instance.cpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
@@ -9,7 +9,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
+void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
                                                            GNWC,
                                                            GKXC,
@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances)
 {
-    // 1. Default
-    add_device_operation_instances(
-        instances,
-        device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
-                                                                   GNWC,
-                                                                   GKXC,
-                                                                   GNWK,
-                                                                   ConvBwdWeightDefault>{});
-    // 2. Filter1x1Stride1Pad0
     add_device_operation_instances(
         instances,
         device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
                                                                    GNWC,
                                                                    GKXC,
                                                                    GNWK,
-                                                                   ConvBwdWeightFilter1x1Stride1Pad0>{});
+                                                                   ConvBwdWeightFilter1x1Stride1Pad0,
+                                                                   BlockGemmPipelineScheduler::Intrawave,
+                                                                   BlockGemmPipelineVersion::v2>{});
 }
...
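Across all of these files the structure is identical; only the trailing template arguments to the instance tuple change. The two parameters that this commit appends select the block-GEMM schedule: every instance here pins BlockGemmPipelineScheduler::Intrawave and varies only BlockGemmPipelineVersion (v2 or v5). An illustrative alias (not in the commit) makes the pattern explicit:

// Illustrative only: naming the f32 pad0/pipev2 tuple used in the diff above.
using F32Pad0PipeV2Instances =
    device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
                                                               GNWC,
                                                               GKXC,
                                                               GNWK,
                                                               ConvBwdWeightFilter1x1Stride1Pad0,
                                                               BlockGemmPipelineScheduler::Intrawave,
                                                               BlockGemmPipelineVersion::v2>;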