Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c9013009
Commit
c9013009
authored
Sep 25, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
114c2646
84dcf5d0
Changes
267
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
184 additions
and
182 deletions
+184
-182
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
+29
-14
example/17_convnd_bwd_data/CMakeLists.txt
example/17_convnd_bwd_data/CMakeLists.txt
+7
-6
example/18_batched_gemm_reduce/CMakeLists.txt
example/18_batched_gemm_reduce/CMakeLists.txt
+0
-2
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+9
-11
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+50
-39
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+1
-13
example/21_gemm_layernorm/CMakeLists.txt
example/21_gemm_layernorm/CMakeLists.txt
+1
-2
example/22_cgemm/CMakeLists.txt
example/22_cgemm/CMakeLists.txt
+8
-8
example/24_batched_gemm/CMakeLists.txt
example/24_batched_gemm/CMakeLists.txt
+12
-10
example/24_batched_gemm/batched_gemm_xdl_bf16.cpp
example/24_batched_gemm/batched_gemm_xdl_bf16.cpp
+0
-0
example/25_gemm_bias_e_permute/CMakeLists.txt
example/25_gemm_bias_e_permute/CMakeLists.txt
+2
-4
example/26_contraction/CMakeLists.txt
example/26_contraction/CMakeLists.txt
+4
-8
example/27_layernorm/CMakeLists.txt
example/27_layernorm/CMakeLists.txt
+2
-4
example/28_grouped_gemm_bias_e_permute/CMakeLists.txt
example/28_grouped_gemm_bias_e_permute/CMakeLists.txt
+1
-3
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
+3
-5
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
+16
-16
example/31_batched_gemm_gemm/CMakeLists.txt
example/31_batched_gemm_gemm/CMakeLists.txt
+5
-13
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+21
-14
example/35_splitK_gemm/CMakeLists.txt
example/35_splitK_gemm/CMakeLists.txt
+13
-10
example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp
+0
-0
No files found.
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
View file @
c9013009
...
@@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int8
)
endif
()
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32
)
endif
()
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16
)
endif
()
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16
)
endif
()
endif
()
...
@@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
endif
()
endif
()
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
...
...
example/17_convnd_bwd_data/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_example_executable
(
example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp
)
add_example_executable
(
example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp
)
target_link_libraries
(
example_convnd_bwd_data_xdl_fp16 PRIVATE utility
)
if
(
result EQUAL 0
)
target_link_libraries
(
example_convnd_bwd_data_xdl_fp16 PRIVATE utility
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
if
(
DL_KERNELS
)
add_example_executable
(
example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp
)
add_example_executable
(
example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp
)
target_link_libraries
(
example_convnd_bwd_data_dl_fp16 PRIVATE utility
)
if
(
result EQUAL 0
)
endif
(
)
target_link_libraries
(
example_convnd_bwd_data_dl_fp16 PRIVATE utility
)
endif
()
endif
()
example/18_batched_gemm_reduce/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
c9013009
...
@@ -3,22 +3,20 @@ set(target 0)
...
@@ -3,22 +3,20 @@ set(target 0)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16
)
endif
()
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
example_grouped_conv_bwd_weight_dl
)
if
(
DL_KERNELS
)
add_example_executable
(
example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp
)
add_custom_target
(
example_grouped_conv_bwd_weight_dl
)
if
(
result EQUAL 0
)
add_example_executable
(
example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp
)
add_dependencies
(
example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16
)
add_dependencies
(
example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16
)
endif
()
endif
()
endif
()
\ No newline at end of file
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
c9013009
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "common.hpp"
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_
gnwc_gkxc_gnwk_
dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
using
InDataType
=
F16
;
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
WeiDataType
=
F16
;
...
@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough;
...
@@ -15,44 +15,55 @@ using WeiElementOp = PassThrough;
using
OutElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Dl
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
// NDimSpatial
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
InDataType
,
// InDataType
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
WeiDataType
,
// WeiDataType
ck
::
tensor_layout
::
convolution
::
GNHWC
,
OutDataType
,
// OutDataType
ck
::
tensor_layout
::
convolution
::
GNDHWC
>>
,
// InLayout
AccDataType
,
// AccDataType
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
InElementOp
,
// InElementwiseOperation
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GKXC
,
WeiElementOp
,
// WeiElementwiseOperation
ck
::
tensor_layout
::
convolution
::
GKYXC
,
OutElementOp
,
// OutElementwiseOperation
ck
::
tensor_layout
::
convolution
::
GKZYXC
>>
,
// WeiLayout
ConvBwdWeightDefault
,
// ConvBackwardWeightSpecialization
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
256
,
// BlockSize
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
128
,
// MPerBlock
ck
::
tensor_layout
::
convolution
::
GNHWK
,
128
,
// NPerBlock
ck
::
tensor_layout
::
convolution
::
GNDHWK
>>
,
// OutLayout
16
,
// K0PerBlock
InDataType
,
// InDataType
2
,
// K1
WeiDataType
,
// WeiDataType
4
,
// M1PerThread
OutDataType
,
// OutDataType
4
,
// N1PerThread
AccDataType
,
// AccDataType
1
,
// KPerThread
InElementOp
,
// InElementwiseOperation
S
<
8
,
2
>
,
// M1N1ThreadClusterM1Xs
WeiElementOp
,
// WeiElementwiseOperation
S
<
8
,
2
>
,
// M1N1ThreadClusterN1Xs
OutElementOp
,
// OutElementwiseOperation
S
<
1
,
8
,
1
,
1
,
2
>
,
// ABlockTransferThreadSliceLengths_K0_M0_M1_K1
ConvBwdWeightDefault
,
// ConvBackwardWeightSpecialization
S
<
1
,
2
,
1
,
128
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M0_M1_K1
256
,
// BlockSize
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferThreadClusterArrangeOrder
128
,
// MPerBlock
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcAccessOrder
128
,
// NPerBlock
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
16
,
// K0PerBlock
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcVectorTensorContiguousDimOrder
2
,
// K1
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
4
,
// M1PerThread
S
<
1
,
1
,
1
,
8
,
2
>
,
// BBlockTransferThreadSliceLengths_K0_N0_N1_K1
4
,
// N1PerThread
S
<
1
,
16
,
1
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N0_N1_K1
1
,
// KPerThread
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
8
,
2
>
,
// M1N1ThreadClusterM1Xs
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcAccessOrder
S
<
8
,
2
>
,
// M1N1ThreadClusterN1Xs
S
<
1
,
1
,
1
,
8
,
1
>
,
// BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S
<
1
,
8
,
1
,
1
,
2
>
,
// ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcVectorTensorContiguousDimOrder
S
<
1
,
2
,
1
,
128
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S
<
1
,
1
,
1
,
1
,
2
>
,
// BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
4
>
;
// CThreadTransferDstScalarPerVector
S
<
0
,
2
,
3
,
1
,
4
>
,
// ABlockTransferSrcVectorTensorContiguousDimOrder
S
<
1
,
1
,
1
,
1
,
1
>
,
// ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S
<
1
,
1
,
1
,
8
,
2
>
,
// BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S
<
1
,
16
,
1
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcAccessOrder
S
<
1
,
1
,
1
,
8
,
1
>
,
// BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S
<
0
,
1
,
4
,
2
,
3
>
,
// BBlockTransferSrcVectorTensorContiguousDimOrder
S
<
1
,
1
,
1
,
1
,
2
>
,
// BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
4
>
;
// CThreadTransferDstScalarPerVector
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
...
...
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
c9013009
...
@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
...
@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
{
ck
::
index_t
split_k
;
// Set split_k = 2 for xdl op, split_k = 1 for dl
// Dl op doesn't support split_k > 1
// Dl op doesn't support split_k > 1
// TODO: Add Dl op split_k > 1 support
constexpr
ck
::
index_t
split_k
=
1
;
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx1030"
||
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
))
{
split_k
=
2
;
}
else
{
split_k
=
1
;
}
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
...
...
example/21_gemm_layernorm/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
@@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
example/22_cgemm/CMakeLists.txt
View file @
c9013009
add_custom_target
(
example_cgemm_xdl
)
add_custom_target
(
example_cgemm_xdl
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp
)
add_example_executable
(
example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_bf16
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_bf16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp
)
add_example_executable
(
example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp16
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp
)
add_example_executable
(
example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp32
)
if
(
result EQUAL 0
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_fp32
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_cgemm_xdl_int8 cgemm_xdl_int8.cpp
)
add_example_executable
(
example_cgemm_xdl_int8 cgemm_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_int8
)
add_dependencies
(
example_cgemm_xdl example_cgemm_xdl_int8
)
endif
()
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
...
...
example/24_batched_gemm/CMakeLists.txt
View file @
c9013009
add_custom_target
(
example_batched_gemm_xdl
)
add_custom_target
(
example_batched_gemm_xdl
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_fp32
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_fp32
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_fp16
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp
)
add_example_executable
(
example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_bf
p
16
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_bf16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp
)
add_example_executable
(
example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_int8
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_int8
)
endif
()
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp
)
add_example_executable
(
example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_int4
)
if
(
result EQUAL 0
)
add_dependencies
(
example_batched_gemm_xdl example_batched_gemm_xdl_int4
)
endif
()
endif
()
endif
()
example/24_batched_gemm/batched_gemm_xdl_bf
p
16.cpp
→
example/24_batched_gemm/batched_gemm_xdl_bf16.cpp
View file @
c9013009
File moved
example/25_gemm_bias_e_permute/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
)
endif
()
example/26_contraction/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp
)
add_example_executable
(
example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp
)
add_example_executable
(
example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp
)
add_example_executable
(
example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp
)
add_example_executable
(
example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp
)
endif
()
add_example_executable
(
example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp
)
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp
)
add_example_executable
(
example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp
)
endif
()
example/27_layernorm/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_layernorm_fp16 layernorm_fp16.cpp
)
add_example_executable
(
example_layernorm_fp16 layernorm_fp16.cpp
)
add_example_executable
(
example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp
)
add_example_executable
(
example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp
)
endif
()
example/28_grouped_gemm_bias_e_permute/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp
)
endif
()
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
View file @
c9013009
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
endif
()
endif
()
endif
()
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
View file @
c9013009
...
@@ -5,27 +5,31 @@ set(target 0)
...
@@ -5,27 +5,31 @@ set(target 0)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list1 AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list1 AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_fwd_multiple_d
)
add_custom_target
(
example_grouped_conv_fwd_multiple_d
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16
)
add_example_executable
(
example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp
)
endif
()
add_example_executable
(
example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8
)
endif
()
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4
)
endif
()
endif
()
# USE_BITINT_EXTENSION_INT4
endif
()
# USE_BITINT_EXTENSION_INT4
set
(
target 1
)
set
(
target 1
)
...
@@ -35,12 +39,8 @@ endforeach()
...
@@ -35,12 +39,8 @@ endforeach()
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list2 AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list2 AND target EQUAL 0
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
example/31_batched_gemm_gemm/CMakeLists.txt
View file @
c9013009
# Example build rules for 31_batched_gemm_gemm.
# (Reconstructed from a garbled diff view: duplicated old/new hunk lines
# collapsed to the post-commit version, one command per line restored.)
#
# gpu_list1: XDL-capable targets for which these examples are configured.
# gpu_list2: narrower target list (declared here; consumers not visible in this hunk).
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx908 gfx90a)

# Guard flag so the examples are configured only once, for the first
# matching GPU in GPU_TARGETS.
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
        # Each dtype example is built when that dtype was requested in DTYPES,
        # or unconditionally when DTYPES is unset.
        if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
            add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
        endif()
        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
            add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
        endif()
        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
            add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
        endif()
        # int4 requires the experimental _BitInt compiler extension.
        # Legacy `endif(USE_BITINT_EXTENSION_INT4)` closers replaced by bare endif().
        if(USE_BITINT_EXTENSION_INT4)
            add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
        endif()
        # NOTE(review): a few context lines are elided in the diff here; the
        # sibling CMakeLists in this commit set the guard flag at this point —
        # confirm against the repository.
        set(target 1)
    endif()
endforeach()

# int8 kernels are not built for gfx94x or gfx1x targets.
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
    if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
        add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
    endif()
endif()
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
c9013009
# Example build rules for 32_batched_gemm_scale_softmax_gemm.
# (Reconstructed from a garbled diff view; the pre-commit version — plain
# add_example_executable calls without the umbrella target — and the
# post-commit version were interleaved. Only the post-commit version is kept.)
#
# Umbrella target: `make example_gemm_scale_softmax_gemm` builds every
# example registered below via add_dependencies.
add_custom_target(example_gemm_scale_softmax_gemm)

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
    add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
    # NOTE(review): `result` is presumably set by add_example_executable
    # (0 when the example target was actually created) — confirm against the
    # macro's definition in the top-level CMake.
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
    endif()
    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
    endif()
    add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
    endif()
    add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
    endif()
    add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
    endif()
endif()

if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
    add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
    endif()
    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
    if(result EQUAL 0)
        add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
    endif()
endif()
example/35_splitK_gemm/CMakeLists.txt
View file @
c9013009
...
@@ -3,25 +3,28 @@ foreach(gpu IN LISTS GPU_TARGETS)
# Example build rules for 35_splitK_gemm (body of the diff hunk starting at
# the per-GPU loop; `set(target 0)` and the gpu_list setup precede this hunk,
# and the matching endforeach() follows it in the elided context).
# (Reconstructed from a garbled diff view: duplicated old/new hunk lines
# collapsed to the post-commit version, one command per line restored, and the
# identifier that the page split across lines — "…_bf" / "p" / "16" — repaired.)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        # Umbrella target: `make example_splitK_gemm_xdl` builds every splitK
        # example registered below.
        add_custom_target(example_splitK_gemm_xdl)
        if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
            add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
            # NOTE(review): `result` is presumably set by add_example_executable
            # (0 when the target was created) — confirm against the macro.
            if(result EQUAL 0)
                add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
            endif()
        endif()
        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
            add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
            if(result EQUAL 0)
                add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
            endif()
        endif()
        # This commit renames the misspelled "bfp16" example/source to "bf16";
        # target name and dependency updated to match.
        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
            add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
            if(result EQUAL 0)
                add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
            endif()
        endif()
        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
            add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
            if(result EQUAL 0)
                add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
            endif()
        endif()
        # int4 requires the experimental _BitInt compiler extension. The new
        # side of the diff moves add_dependencies inside the result check, so a
        # dependency is only added when the example target actually exists.
        if(USE_BITINT_EXTENSION_INT4)
            add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
            if(result EQUAL 0)
                add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
            endif()
        endif()
        # Stop after the first matching GPU so examples are configured once.
        set(target 1)
    endif()
...
...
example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp
→
example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp
View file @
c9013009
File moved
Prev
1
2
3
4
5
6
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment