gaoqiong / composable_kernel · Commit 8f9c0243

Authored Sep 22, 2023 by Alan Turner
Merge branch 'develop' into migx-jit-lib
Parents: 181ea79a, c8a8385f

Changes: 609 files in total (paginated); this page shows 20 changed files with 267 additions and 170 deletions (+267 −170).
Files changed on this page:

  example/02_gemm_bilinear/CMakeLists.txt                                   +2   -0
  example/03_gemm_bias_relu/CMakeLists.txt                                  +2   -0
  example/04_gemm_add_add_fastgelu/CMakeLists.txt                           +17  -13
  example/09_convnd_fwd/CMakeLists.txt                                      +26  -8
  example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp                              +1   -1
  example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp                              +1   -1
  example/09_convnd_fwd/convnd_fwd_dl_int8.cpp                              +1   -1
  example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt           +16  -8
  example/13_pool2d_fwd/CMakeLists.txt                                      +6   -3
  example/13_pool2d_fwd/pool2d_fwd_common.hpp                               +24  -18
  example/13_pool2d_fwd/pool2d_fwd_fp16.cpp                                 +32  -26
  example/13_pool2d_fwd/pool2d_fwd_fp32.cpp                                 +32  -26
  example/14_gemm_quantization/CMakeLists.txt                               +6   -2
  example/15_grouped_gemm/CMakeLists.txt                                    +21  -17
  example/16_gemm_multi_d_multi_reduces/CMakeLists.txt                      +27  -27
  example/17_convnd_bwd_data/CMakeLists.txt                                 +6   -2
  example/18_batched_gemm_reduce/CMakeLists.txt                             +2   -0
  example/20_grouped_conv_bwd_weight/CMakeLists.txt                         +15  -11
  example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp   +15  -3
  example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp   +15  -3
example/02_gemm_bilinear/CMakeLists.txt  (view file @ 8f9c0243)

+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
  list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
  list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
  set(target 0)
  ...
@@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
      set(target 1)
    endif()
  endforeach()
+ endif()
example/03_gemm_bias_relu/CMakeLists.txt  (view file @ 8f9c0243)

+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
  list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
  set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
  ...
@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
      set(target 1)
    endif()
  endforeach()
+ endif()
example/04_gemm_add_add_fastgelu/CMakeLists.txt  (view file @ 8f9c0243)

...
@@ -3,22 +3,26 @@ set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
      add_custom_target(example_gemm_add_add_fastgelu_xdl)
-     add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
-     add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
-     add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
+     if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
+       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
+     endif()
+     if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
+       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
+     endif()
+     if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
+       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
+     endif()
      if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
+       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
      endif(USE_BITINT_EXTENSION_INT4)
-     add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
-     add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
-     add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
-     add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
-     if(USE_BITINT_EXTENSION_INT4)
-       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
-     endif(USE_BITINT_EXTENSION_INT4)
-     add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
+     if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
+       add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
+     endif()
      set(target 1)
    endif()
  endforeach()
\ No newline at end of file
example/09_convnd_fwd/CMakeLists.txt  (view file @ 8f9c0243)

...
@@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
  set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-     add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
-     add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
-     add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
-     add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+     if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
+     endif()
+     if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
+     endif()
+     if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
+     endif()
+     if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+     endif()
      # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
-     add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
+     if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
+       add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
+     endif()
      set(target 1)
    endif()
  endforeach()
- add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
- add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
- add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
+ if(DL_KERNELS)
+   if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+     add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
+   endif()
+   if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+     add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
+   endif()
+   if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+     add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
+   endif()
+ endif()
example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp  (view file @ 8f9c0243)

...
@@ -3,7 +3,7 @@
  #include "convnd_fwd_dl_common.hpp"

- #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
  #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
...
example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp  (view file @ 8f9c0243)

...
@@ -3,7 +3,7 @@
  #include "convnd_fwd_dl_common.hpp"

- #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
  #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
...
example/09_convnd_fwd/convnd_fwd_dl_int8.cpp  (view file @ 8f9c0243)

...
@@ -3,7 +3,7 @@
  #include "convnd_fwd_dl_common.hpp"

- #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
+ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
  #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
...
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt  (view file @ 8f9c0243)

...
@@ -3,14 +3,22 @@ set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
      add_custom_target(example_convnd_fwd_reduce_xdl)
-     add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
-     add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
-     add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
-     add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
-     add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
-     add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
-     add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
-     add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
+     if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
+       add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
+     endif()
+     if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+       add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
+       add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
+     endif()
+     if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+       add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
+       add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
+     endif()
+     if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+       add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
+       add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
+     endif()
      if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
        add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
...
example/13_pool2d_fwd/CMakeLists.txt  (view file @ 8f9c0243)

- add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
- add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+   add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
+ endif()
+ if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+   add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
+ endif()
example/13_pool2d_fwd/pool2d_fwd_common.hpp  (view file @ 8f9c0243)

...
@@ -39,31 +39,35 @@ bool pool_test(bool do_verification,
                 ck::index_t Wi,
                 ck::index_t window_stride_h,
                 ck::index_t window_stride_w,
+                ck::index_t window_dilation_h,
+                ck::index_t window_dilation_w,
                 ck::index_t in_left_pad_h,
                 ck::index_t in_left_pad_w,
                 ck::index_t in_right_pad_h,
                 ck::index_t in_right_pad_w)
  {
      using DevicePoolFwdInstance =
-         ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
-             InDataType,      // InDataType
-             OutDataType,     // OutDataType
-             IndexDataType,   // IndexDataType
-             ComputeDataType, // ComputeDataType
-             ReduceOpId,
-             OutputIndex,
-             64, // BlockSize
-             64, // ReduceMThreadClusterSize
-             1,  // ReduceKThreadClusterSize
-             4,  // ReduceMThreadSliceSize
-             1,  // ReduceKThreadSliceSize
-             4>; // InSrcOutDstVectorSize
-
-     const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
-     const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
+         ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC<InDataType,
+                                                                 OutDataType,
+                                                                 IndexDataType,
+                                                                 ComputeDataType,
+                                                                 ReduceOpId,
+                                                                 OutputIndex,
+                                                                 64, // BlockSize
+                                                                 64, // ReduceMThreadClusterSize
+                                                                 1,  // ReduceKThreadClusterSize
+                                                                 4,  // ReduceMThreadSliceSize
+                                                                 1,  // ReduceKThreadSliceSize
+                                                                 1>; // InSrcOutDstVectorSize
+
+     const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
+     const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
+     const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
+     const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;

      const std::vector<ck::index_t> window_spatial_lengths{Y, X};
      const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
+     const std::vector<ck::index_t> window_dilations{window_dilation_h, window_dilation_w};
      const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
      const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
...
@@ -123,6 +127,7 @@ bool pool_test(bool do_verification,
          {C * Ho * Wo, 1, Wo * C, C},
          {C * Ho * Wo, 1, Wo * C, C},
          window_strides,
+         window_dilations,
          input_left_pads,
          input_right_pads,
          {2, 3});
...
@@ -144,8 +149,8 @@ bool pool_test(bool do_verification,
      float gb_per_sec = num_btype / 1.E6 / ave_time;

-     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
-               << std::endl;
+     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB / s "
+               << std::endl;

      bool pass = true;
...
@@ -169,6 +174,7 @@ bool pool_test(bool do_verification,
                            out_indices_n_c_ho_wo_host,
                            window_spatial_lengths,
                            window_strides,
+                           window_dilations,
                            input_left_pads,
                            input_right_pads);
...
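For reference, the dilation handling added above amounts to replacing the raw window size Y (or X) with an effective size (Y - 1) * dilation + 1 before the usual stride arithmetic. Below is a minimal standalone sketch of that output-size computation, using plain ints instead of ck::index_t; the helper name pooled_length is made up for illustration and is not part of the CK API.

```cpp
#include <cassert>
#include <iostream>

// Output length of a pooling window along one spatial dimension: compute the
// dilated (effective) window first, then apply the usual stride formula.
int pooled_length(int in_len, int window, int stride, int dilation, int left_pad, int right_pad)
{
    const int effective_window = (window - 1) * dilation + 1; // == window when dilation == 1
    return (in_len + left_pad + right_pad - effective_window) / stride + 1;
}

int main()
{
    // Defaults from the fp16/fp32 examples: Hi = 71, Y = 3, stride 2, pads 1/1.
    // With dilation 1 the window still covers 3 inputs: Ho = (71 + 1 + 1 - 3) / 2 + 1 = 36.
    assert(pooled_length(71, 3, 2, /*dilation=*/1, 1, 1) == 36);
    // With dilation 2 the window covers (3 - 1) * 2 + 1 = 5 inputs: Ho = (73 - 5) / 2 + 1 = 35.
    assert(pooled_length(71, 3, 2, /*dilation=*/2, 1, 1) == 35);
    std::cout << "pooled lengths check out\n";
}
```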
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp  (view file @ 8f9c0243)

...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
      bool time_kernel;

      // Pool shape
-     ck::index_t N               = 128;
-     ck::index_t C               = 192;
-     ck::index_t Y               = 3;
-     ck::index_t X               = 3;
-     ck::index_t Hi              = 71;
-     ck::index_t Wi              = 71;
-     ck::index_t window_stride_h = 2;
-     ck::index_t window_stride_w = 2;
-     ck::index_t in_left_pad_h   = 1;
-     ck::index_t in_left_pad_w   = 1;
-     ck::index_t in_right_pad_h  = 1;
-     ck::index_t in_right_pad_w  = 1;
+     ck::index_t N                 = 128;
+     ck::index_t C                 = 192;
+     ck::index_t Y                 = 3;
+     ck::index_t X                 = 3;
+     ck::index_t Hi                = 71;
+     ck::index_t Wi                = 71;
+     ck::index_t window_stride_h   = 2;
+     ck::index_t window_stride_w   = 2;
+     ck::index_t window_dilation_h = 1;
+     ck::index_t window_dilation_w = 1;
+     ck::index_t in_left_pad_h     = 1;
+     ck::index_t in_left_pad_w     = 1;
+     ck::index_t in_right_pad_h    = 1;
+     ck::index_t in_right_pad_w    = 1;

      if(argc == 1)
      {
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
          init_method     = std::stoi(argv[2]);
          time_kernel     = static_cast<bool>(std::stoi(argv[3]));
      }
-     else if(argc == 16)
+     else if(argc == 18)
      {
          do_verification = std::stoi(argv[1]);
          init_method     = std::stoi(argv[2]);
          time_kernel     = static_cast<bool>(std::stoi(argv[3]));
-         N               = std::stoi(argv[4]);
-         C               = std::stoi(argv[5]);
-         Y               = std::stoi(argv[6]);
-         X               = std::stoi(argv[7]);
-         Hi              = std::stoi(argv[8]);
-         Wi              = std::stoi(argv[9]);
-         window_stride_h = std::stoi(argv[10]);
-         window_stride_w = std::stoi(argv[11]);
-         in_left_pad_h   = std::stoi(argv[12]);
-         in_left_pad_w   = std::stoi(argv[13]);
-         in_right_pad_h  = std::stoi(argv[14]);
-         in_right_pad_w  = std::stoi(argv[15]);
+         N                 = std::stoi(argv[4]);
+         C                 = std::stoi(argv[5]);
+         Y                 = std::stoi(argv[6]);
+         X                 = std::stoi(argv[7]);
+         Hi                = std::stoi(argv[8]);
+         Wi                = std::stoi(argv[9]);
+         window_stride_h   = std::stoi(argv[10]);
+         window_stride_w   = std::stoi(argv[11]);
+         window_dilation_h = std::stoi(argv[12]);
+         window_dilation_w = std::stoi(argv[13]);
+         in_left_pad_h     = std::stoi(argv[14]);
+         in_left_pad_w     = std::stoi(argv[15]);
+         in_right_pad_h    = std::stoi(argv[16]);
+         in_right_pad_w    = std::stoi(argv[17]);
      }
      else
      {
          printf("arg1: verification (0=no, 1=yes)\n");
          printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
          printf("arg3: time kernel (0=no, 1=yes)\n");
-         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
+         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
                 "RightPx\n");
          exit(0);
      }
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
                                        Wi,
                                        window_stride_h,
                                        window_stride_w,
+                                       window_dilation_h,
+                                       window_dilation_w,
                                        in_left_pad_h,
                                        in_left_pad_w,
                                        in_right_pad_h,
...
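With the two dilation arguments (Dy, Dx), the shape-parameter list grows from 12 to 14 values read from argv[4]..argv[17], which is why the check moves from argc == 16 to argc == 18 (the usage string above still says "arg4 to 15"). A small sketch of the expected argument ordering, using the defaults from the example; the binary name is only illustrative.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Program name + 3 flags + 14 shape parameters = argc of 18.
    const std::vector<std::string> args = {
        "example_pool2d_fwd_fp16",      // argv[0] (illustrative name)
        "1", "2", "1",                  // verification, init method, time kernel
        "128", "192",                   // N, C
        "3", "3", "71", "71",           // Y, X, Hi, Wi
        "2", "2",                       // Sy, Sx (window strides)
        "1", "1",                       // Dy, Dx (window dilations, new)
        "1", "1", "1", "1"};            // LeftPy, LeftPx, RightPy, RightPx
    std::cout << "argc would be " << args.size() << "\n"; // prints 18
}
```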
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp  (view file @ 8f9c0243)

...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
      bool time_kernel;

      // Pool shape
-     ck::index_t N               = 128;
-     ck::index_t C               = 192;
-     ck::index_t Y               = 3;
-     ck::index_t X               = 3;
-     ck::index_t Hi              = 71;
-     ck::index_t Wi              = 71;
-     ck::index_t window_stride_h = 2;
-     ck::index_t window_stride_w = 2;
-     ck::index_t in_left_pad_h   = 1;
-     ck::index_t in_left_pad_w   = 1;
-     ck::index_t in_right_pad_h  = 1;
-     ck::index_t in_right_pad_w  = 1;
+     ck::index_t N                 = 128;
+     ck::index_t C                 = 192;
+     ck::index_t Y                 = 3;
+     ck::index_t X                 = 3;
+     ck::index_t Hi                = 71;
+     ck::index_t Wi                = 71;
+     ck::index_t window_stride_h   = 2;
+     ck::index_t window_stride_w   = 2;
+     ck::index_t window_dilation_h = 1;
+     ck::index_t window_dilation_w = 1;
+     ck::index_t in_left_pad_h     = 1;
+     ck::index_t in_left_pad_w     = 1;
+     ck::index_t in_right_pad_h    = 1;
+     ck::index_t in_right_pad_w    = 1;

      if(argc == 1)
      {
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
          init_method     = std::stoi(argv[2]);
          time_kernel     = static_cast<bool>(std::stoi(argv[3]));
      }
-     else if(argc == 16)
+     else if(argc == 18)
      {
          do_verification = std::stoi(argv[1]);
          init_method     = std::stoi(argv[2]);
          time_kernel     = static_cast<bool>(std::stoi(argv[3]));
-         N               = std::stoi(argv[4]);
-         C               = std::stoi(argv[5]);
-         Y               = std::stoi(argv[6]);
-         X               = std::stoi(argv[7]);
-         Hi              = std::stoi(argv[8]);
-         Wi              = std::stoi(argv[9]);
-         window_stride_h = std::stoi(argv[10]);
-         window_stride_w = std::stoi(argv[11]);
-         in_left_pad_h   = std::stoi(argv[12]);
-         in_left_pad_w   = std::stoi(argv[13]);
-         in_right_pad_h  = std::stoi(argv[14]);
-         in_right_pad_w  = std::stoi(argv[15]);
+         N                 = std::stoi(argv[4]);
+         C                 = std::stoi(argv[5]);
+         Y                 = std::stoi(argv[6]);
+         X                 = std::stoi(argv[7]);
+         Hi                = std::stoi(argv[8]);
+         Wi                = std::stoi(argv[9]);
+         window_stride_h   = std::stoi(argv[10]);
+         window_stride_w   = std::stoi(argv[11]);
+         window_dilation_h = std::stoi(argv[12]);
+         window_dilation_w = std::stoi(argv[13]);
+         in_left_pad_h     = std::stoi(argv[14]);
+         in_left_pad_w     = std::stoi(argv[15]);
+         in_right_pad_h    = std::stoi(argv[16]);
+         in_right_pad_w    = std::stoi(argv[17]);
      }
      else
      {
          printf("arg1: verification (0=no, 1=yes)\n");
          printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
          printf("arg3: time kernel (0=no, 1=yes)\n");
-         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
+         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
                 "RightPx\n");
          exit(0);
      }
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
                                        Wi,
                                        window_stride_h,
                                        window_stride_w,
+                                       window_dilation_h,
+                                       window_dilation_w,
                                        in_left_pad_h,
                                        in_left_pad_w,
                                        in_right_pad_h,
...
example/14_gemm_quantization/CMakeLists.txt  (view file @ 8f9c0243)

+ if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
  # dlops
- add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
+ if(DL_KERNELS)
+   add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
+ endif()
  # xdlops
  list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
  ...
@@ -10,4 +13,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
      add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
      set(target 1)
    endif()
- endforeach()
\ No newline at end of file
+ endforeach()
+ endif()
\ No newline at end of file
example/15_grouped_gemm/CMakeLists.txt  (view file @ 8f9c0243)

  add_custom_target(example_grouped_gemm_xdl)
- add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
- add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
- add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
- add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
- add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
- add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
- add_dependencies(example_grouped_gemm_xdl
-                  example_grouped_gemm_xdl_fp32
-                  example_grouped_gemm_xdl_fp16
-                  example_grouped_gemm_xdl_bfp16
-                  example_grouped_gemm_xdl_int8
-                  example_grouped_gemm_multiple_d_dl_fp16
-                  example_grouped_gemm_xdl_splitk_fp16)
+ if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+   add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
+   add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
+ endif()
+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+   add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
+   add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
+   add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
+   add_dependencies(example_grouped_gemm_xdl
+                    example_grouped_gemm_xdl_fp16
+                    example_grouped_gemm_multiple_d_dl_fp16
+                    example_grouped_gemm_xdl_splitk_fp16)
+ endif()
+ if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+   add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
+   add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16)
+ endif()
+ if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+   add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
+   add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
+ endif()
  if(USE_BITINT_EXTENSION_INT4)
    add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
...
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt  (view file @ 8f9c0243)

...
@@ -6,33 +6,33 @@ foreach(gpu IN LISTS GPU_TARGETS)
      add_custom_target(example_gemm_reduce_xdl_max)
      add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
      add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
-     add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
-     add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
-     add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
-     add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
-     add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
-     add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
-     add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
-     add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
-     add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
-     add_dependencies(example_gemm_reduce_xdl_max
-                      example_gemm_max_xdl_bf16
-                      example_gemm_max_xdl_fp16
-                      example_gemm_max_xdl_fp32
-                      example_gemm_max_xdl_int8)
-     add_dependencies(example_gemm_reduce_xdl_mean_meansquare
-                      example_gemm_mean_meansquare_xdl_fp16
-                      example_gemm_mean_meansquare_xdl_fp32
-                      example_gemm_mean_meansquare_xdl_bf16
-                      example_gemm_add_addsquare_xdl_int8)
-     add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
+     if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
+       add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
+       add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
+       add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
+       add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
+       add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
+     endif()
+     if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
+       add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
+       add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
+       add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
+     endif()
+     if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
+       add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
+       add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
+       add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
+     endif()
+     if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
+       add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
+       add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
+       add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
+     endif()
      add_dependencies(example_gemm_reduce_xdl
                       example_gemm_reduce_xdl_mean_meansquare
                       example_gemm_reduce_xdl_max
...
example/17_convnd_bwd_data/CMakeLists.txt  (view file @ 8f9c0243)

+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
  list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
  set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
  ...
@@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
      set(target 1)
    endif()
  endforeach()
- add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
- target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
+ if(DL_KERNELS)
+   add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
+   target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
+ endif()
+ endif()
example/18_batched_gemm_reduce/CMakeLists.txt  (view file @ 8f9c0243)

+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
  list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
  set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
  ...
@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
      set(target 1)
    endif()
  endforeach()
+ endif()
example/20_grouped_conv_bwd_weight/CMakeLists.txt  (view file @ 8f9c0243)

...
@@ -3,18 +3,22 @@ set(target 0)
  foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
      add_custom_target(example_grouped_conv_bwd_weight)
-     add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
-     add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
-     add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
-                      example_grouped_conv_bwd_weight_xdl_bf16)
+     if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
+       add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
+     endif()
+     if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+       add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
+       add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
+     endif()
      set(target 1)
    endif()
  endforeach()
- add_custom_target(example_grouped_conv_bwd_weight_dl)
- add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
- add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
+ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+   if(DL_KERNELS)
+     add_custom_target(example_grouped_conv_bwd_weight_dl)
+     add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
+     add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
+   endif()
+ endif()
\ No newline at end of file
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp  (view file @ 8f9c0243)

...
@@ -3,7 +3,7 @@
  #include "common.hpp"

- #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

  using InDataType = BF16;
  // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
...
@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
  template <ck::index_t NDimSpatial>
  using DeviceConvBwdWeightInstance =
-     ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-         NDimSpatial, // NDimSpatial
+     ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+         NDimSpatial,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                       ck::tensor_layout::convolution::GNHWC,
+                                       ck::tensor_layout::convolution::GNDHWC>>,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                       ck::tensor_layout::convolution::GKYXC,
+                                       ck::tensor_layout::convolution::GKZYXC>>,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                       ck::tensor_layout::convolution::GNHWK,
+                                       ck::tensor_layout::convolution::GNDHWK>>,
          InDataType,  // InDataType
          WeiDataType, // WeiDataType
          OutDataType, // OutDataType
...
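The new instantiation selects the input/weight/output layout tags by spatial rank: NDimSpatial - 1 indexes a tuple holding the 1-D, 2-D and 3-D layouts. A minimal standalone sketch of that selection pattern using the standard library (the layout structs below are stand-ins for illustration, not the ck::tensor_layout types, and std::tuple stands in for ck::Tuple):

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// Stand-in layout tags for 1-D, 2-D and 3-D convolution inputs.
struct GNWC {};
struct GNHWC {};
struct GNDHWC {};

// Pick a layout by spatial rank, mirroring the
// ck::tuple_element_t<NDimSpatial - 1, ck::Tuple<...>> pattern in the diff.
template <std::size_t NDimSpatial>
using InLayout = std::tuple_element_t<NDimSpatial - 1, std::tuple<GNWC, GNHWC, GNDHWC>>;

static_assert(std::is_same_v<InLayout<1>, GNWC>);   // 1 spatial dim  -> GNWC
static_assert(std::is_same_v<InLayout<2>, GNHWC>);  // 2 spatial dims -> GNHWC
static_assert(std::is_same_v<InLayout<3>, GNDHWC>); // 3 spatial dims -> GNDHWC

int main() { return 0; }
```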
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp  (view file @ 8f9c0243)

...
@@ -3,7 +3,7 @@
  #include "common.hpp"

- #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

  using InDataType  = F16;
  using WeiDataType = F16;
...
@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
  template <ck::index_t NDimSpatial>
  using DeviceConvBwdWeightInstance =
-     ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-         NDimSpatial, // NDimSpatial
+     ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+         NDimSpatial,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                       ck::tensor_layout::convolution::GNHWC,
+                                       ck::tensor_layout::convolution::GNDHWC>>,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                       ck::tensor_layout::convolution::GKYXC,
+                                       ck::tensor_layout::convolution::GKZYXC>>,
+         ck::tuple_element_t<NDimSpatial - 1,
+                             ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                       ck::tensor_layout::convolution::GNHWK,
+                                       ck::tensor_layout::convolution::GNDHWK>>,
          InDataType,  // InDataType
          WeiDataType, // WeiDataType
          OutDataType, // OutDataType
...