Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
40fabdcd
Unverified
Commit
40fabdcd
authored
Feb 22, 2023
by
zjing14
Committed by
GitHub
Feb 22, 2023
Browse files
Merge branch 'develop' into builtin_fastgelu
parents
9cc99065
246ceee4
Changes
9
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1923 additions
and
45 deletions
+1923
-45
CHANGELOG.md
CHANGELOG.md
+1
-0
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+6
-0
example/20_grouped_conv_bwd_weight/common.hpp
example/20_grouped_conv_bwd_weight/common.hpp
+0
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+59
-0
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+42
-0
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+42
-0
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+14
-43
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+1216
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...de/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
+543
-1
No files found.
CHANGELOG.md
View file @
40fabdcd
...
@@ -18,6 +18,7 @@ Full documentation for Composable Kernel is not yet available.
...
@@ -18,6 +18,7 @@ Full documentation for Composable Kernel is not yet available.
-
Added multi-D GEMM client APIs (#534).
-
Added multi-D GEMM client APIs (#534).
-
Added multi-embeddings support (#542).
-
Added multi-embeddings support (#542).
-
Added Navi3x blockwise GEMM and real GEMM support (#541).
-
Added Navi3x blockwise GEMM and real GEMM support (#541).
-
Added Navi grouped ConvBwdWeight support (#505).
### Changed
### Changed
-
Changed ...
-
Changed ...
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
40fabdcd
...
@@ -6,3 +6,9 @@ add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd
...
@@ -6,3 +6,9 @@ add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
example_grouped_conv_bwd_weight_xdl_bf16
)
example_grouped_conv_bwd_weight_xdl_bf16
)
add_custom_target
(
example_grouped_conv_bwd_weight_dl
)
add_example_executable
(
example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp
)
add_dependencies
(
example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16
)
example/20_grouped_conv_bwd_weight/common.hpp
View file @
40fabdcd
...
@@ -9,7 +9,6 @@
...
@@ -9,7 +9,6 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
0 → 100644
View file @
40fabdcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
OutDataType
=
F16
;
using
AccDataType
=
F32
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
16
,
// K0PerBlock
2
,
// K1
4
,
// M1PerThread
4
,
// N1PerThread
1
,
// KPerThread
S
<
8
,
2
>
,
// M1N1ThreadClusterM1Xs
S
<
8
,
2
>
,
// M1N1ThreadClusterN1Xs
S
<
1
,
8
,
1
,
1
,
2
>
,
// ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S
<
1
,
2
,
1
,
128
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcAccessOrder
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcVectorTensorContiguousDimOrder
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S
<
1
,
1
,
1
,
8
,
2
>
,
// BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S
<
1
,
16
,
1
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcAccessOrder
S
<
1
,
1
,
1
,
8
,
1
>
,
// BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcVectorTensorContiguousDimOrder
S
<
1
,
1
,
1
,
1
,
2
>
,
// BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
4
>
;
// CThreadTransferDstScalarPerVector
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
40fabdcd
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#include "common.hpp"
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
using
InDataType
=
BF16
;
using
InDataType
=
BF16
;
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
using
WeiDataType
=
F32
;
using
WeiDataType
=
F32
;
...
@@ -13,6 +15,46 @@ using InElementOp = PassThrough;
...
@@ -13,6 +15,46 @@ using InElementOp = PassThrough;
using
WeiElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128
/
(
sizeof
(
WeiDataType
)
*
CHAR_BIT
)
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
40fabdcd
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#include "common.hpp"
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
using
InDataType
=
F16
;
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
WeiDataType
=
F16
;
using
OutDataType
=
F16
;
using
OutDataType
=
F16
;
...
@@ -12,6 +14,46 @@ using InElementOp = PassThrough;
...
@@ -12,6 +14,46 @@ using InElementOp = PassThrough;
using
WeiElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128
/
(
sizeof
(
WeiDataType
)
*
CHAR_BIT
)
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
40fabdcd
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128
/
(
sizeof
(
WeiDataType
)
*
CHAR_BIT
)
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
InDataType
,
InDataType
,
...
@@ -54,7 +14,18 @@ template <ck::index_t NDimSpatial>
...
@@ -54,7 +14,18 @@ template <ck::index_t NDimSpatial>
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
{
constexpr
ck
::
index_t
split_k
=
2
;
ck
::
index_t
split_k
;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
))
{
split_k
=
2
;
}
else
{
split_k
=
1
;
}
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
...
@@ -144,7 +115,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -144,7 +115,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std
::
cerr
<<
"wrong! device_conv with the specified compilation parameters does "
std
::
cerr
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
"not support this Conv problem"
<<
std
::
endl
;
<<
std
::
endl
;
return
fals
e
;
return
tru
e
;
}
}
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
0 → 100644
View file @
40fabdcd
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
40fabdcd
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment