Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ac76519a
Unverified
Commit
ac76519a
authored
Aug 10, 2023
by
Adam Osewski
Committed by
GitHub
Aug 10, 2023
Browse files
Merge branch 'develop' into aosewski/gemm_tile_loop
parents
a70c6283
578142db
Changes
174
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
317 additions
and
235 deletions
+317
-235
CMakeLists.txt
CMakeLists.txt
+8
-0
CONTRIBUTORS.md
CONTRIBUTORS.md
+4
-2
Jenkinsfile
Jenkinsfile
+2
-2
client_example/11_grouped_conv_bwd_weight/common.hpp
client_example/11_grouped_conv_bwd_weight/common.hpp
+41
-62
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
+10
-12
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
+11
-12
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
+11
-12
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
+17
-19
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+46
-35
example/01_gemm/gemm_xdl_fp16_f8.cpp
example/01_gemm/gemm_xdl_fp16_f8.cpp
+41
-0
example/02_gemm_bilinear/CMakeLists.txt
example/02_gemm_bilinear/CMakeLists.txt
+2
-0
example/03_gemm_bias_relu/CMakeLists.txt
example/03_gemm_bias_relu/CMakeLists.txt
+2
-0
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+17
-13
example/09_convnd_fwd/CMakeLists.txt
example/09_convnd_fwd/CMakeLists.txt
+26
-8
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
...e/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
+16
-8
example/13_pool2d_fwd/CMakeLists.txt
example/13_pool2d_fwd/CMakeLists.txt
+6
-3
example/14_gemm_quantization/CMakeLists.txt
example/14_gemm_quantization/CMakeLists.txt
+3
-1
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+21
-17
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
+27
-27
example/17_convnd_bwd_data/CMakeLists.txt
example/17_convnd_bwd_data/CMakeLists.txt
+6
-2
No files found.
CMakeLists.txt
View file @
ac76519a
...
...
@@ -89,6 +89,14 @@ else()
message
(
"Building CK for the following targets:
${
AMDGPU_TARGETS
}
"
)
endif
()
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math
(
EXPR hip_VERSION_FLAT
"(
${
hip_VERSION_MAJOR
}
* 1000 +
${
hip_VERSION_MINOR
}
) * 100000 +
${
hip_VERSION_PATCH
}
"
)
message
(
"hip_version_flat=
${
hip_VERSION_FLAT
}
"
)
if
(
${
hip_VERSION_FLAT
}
GREATER 500723302
)
message
(
"Adding the fno-offload-uniform-block compiler flag"
)
add_compile_options
(
-fno-offload-uniform-block
)
endif
()
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
option
(
USE_OPT_NAVI3X,
"Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF
)
...
...
CONTRIBUTORS.md
View file @
ac76519a
...
...
@@ -6,9 +6,11 @@ This is the list of developers and contributors to Composable Kernel library
## Developers
[
Chao Liu
](
https://github.com/asroy
)
,
[
Jing Zhang
](
https://github.com/zjing14
)
, 2018-2023
[
Letao Qin
](
https://github.com/ltqin
)
,
[
Qianfeng Zhang
](
https://github.com/qianfengz
)
,
[
Liang Huang
](
https://github.com/carlushuang
)
,
[
Shaojie Wang
](
https://github.com/shaojiewang
)
, 2019-202
2
[
Letao Qin
](
https://github.com/ltqin
)
,
[
Qianfeng Zhang
](
https://github.com/qianfengz
)
,
[
Liang Huang
](
https://github.com/carlushuang
)
,
[
Shaojie Wang
](
https://github.com/shaojiewang
)
, 2019-202
3
[
Anthony Chang
](
https://github.com/rosenrodt
)
,
[
Chunyu Lai
](
https://github.com/rocking5566
)
,
[
Illia Silin
](
https://github.com/illsilin
)
,
[
Adam Osewski
](
https://github.com/aosewski
)
,
[
Poyen Chen
](
https://github.com/poyenc
)
,
[
Rosty Geyyer
](
https://github.com/geyyer
)
, 2022
[
Anthony Chang
](
https://github.com/rosenrodt
)
,
[
Chunyu Lai
](
https://github.com/rocking5566
)
,
[
Illia Silin
](
https://github.com/illsilin
)
,
[
Adam Osewski
](
https://github.com/aosewski
)
,
[
Poyen Chen
](
https://github.com/poyenc
)
,
[
Rosty Geyyer
](
https://github.com/geyyer
)
,
[
Astha Rai
](
https://github.com/arai713
)
,
[
Shi YanXing
](
https://github.com/Yanxing-Shi
)
, 2022-2023
[
Hari Sadasivan
](
https://github.com/hsadasiv
)
,
[
Bartlomiej Kocot
](
https://github.com/bartekxk
)
,
[
Bartlomiej Wroblewski
](
https://github.com/bwroblew
)
, 2023
Hanwen Chang, 2019-2021,
...
...
Jenkinsfile
View file @
ac76519a
...
...
@@ -710,8 +710,8 @@ pipeline {
}
agent
{
label
rocmnode
(
"gfx908 || gfx90a"
)
}
environment
{
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940" """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940
;gfx941
" """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940
;gfx941
" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps
{
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
...
...
client_example/11_grouped_conv_bwd_weight/common.hpp
View file @
ac76519a
...
...
@@ -32,63 +32,49 @@ struct SimpleDeviceMem
};
template
<
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetFlops
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_spatial_lengths
)
std
::
size_t
GetFlops
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
constexpr
ck
::
index_t
spatial_offset
=
3
;
const
auto
C
=
filter_lengths
[
2
];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G
*
N
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
output_
spatial_
lengths
),
std
::
end
(
output_
spatial_
lengths
),
return
static_cast
<
std
::
size_t
>
(
2
)
*
C
*
std
::
accumulate
(
std
::
begin
(
output_lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_
spatial_
lengths
),
std
::
end
(
filter_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
filter_lengths
)
+
spatial_offset
,
std
::
end
(
filter_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
}
template
<
typename
InDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetInputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_spatial_lengths
)
std
::
size_t
GetInputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_lengths
)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G
*
N
*
C
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths
),
std
::
end
(
input_spatial_lengths
),
return
sizeof
(
InDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
input_lengths
),
std
::
end
(
input_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetWeightByte
(
ck
::
index_t
G
,
ck
::
index_t
K
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_spatial_lengths
)
std
::
size_t
GetWeightByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
std
::
end
(
filter_spatial_lengths
),
return
sizeof
(
WeiDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
filter_lengths
),
std
::
end
(
filter_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetOutputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_spatial_lengths
)
std
::
size_t
GetOutputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G
*
N
*
K
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
),
return
sizeof
(
OutDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
output_lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
}
...
...
@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial,
typename
WeiLayout
,
typename
OutLayout
>
bool
run_grouped_conv_bwd_weight
(
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
filter_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
conv_filter_dilations
,
...
...
@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{
ck
::
index_t
split_k
=
2
;
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_spatial
_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_spatial
_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_spatial
_lengths
));
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
+
3
>
(
input
_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
+
3
>
(
filter
_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
+
3
>
(
output
_lengths
));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight
<
NumDimSpatial
,
InLayout
,
...
...
@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
...
...
@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight(
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
>
(
G
,
N
,
K
,
C
,
output_spatial_lengths
,
filter_spatial_lengths
);
std
::
size_t
num_bytes
=
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_spatial_lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_spatial_lengths
)
+
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_spatial_lengths
);
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
+
3
>
(
output_lengths
,
filter_lengths
);
std
::
size_t
num_bytes
=
GetInputByte
<
InDataType
,
NumDimSpatial
+
3
>
(
input_lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
+
3
>
(
filter_lengths
)
+
GetOutputByte
<
OutDataType
,
NumDimSpatial
+
3
>
(
output_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
...
...
@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight(
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
View file @
ac76519a
...
...
@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
X
*
C
,
X
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
1
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
};
...
...
@@ -40,14 +41,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
View file @
ac76519a
...
...
@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
1
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Y
*
X
*
C
,
Y
*
X
*
C
,
1
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
1
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
};
...
...
@@ -45,14 +47,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
View file @
ac76519a
...
...
@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
1
,
Y
*
X
*
C
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
1
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
...
...
@@ -48,14 +50,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
View file @
ac76519a
...
...
@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
1
,
Y
*
X
*
C
,
X
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
1
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
...
...
@@ -48,20 +50,16 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
{
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
Do
,
Ho
,
Wo
},
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
},
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
})
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
example/01_gemm/CMakeLists.txt
View file @
ac76519a
add_custom_target
(
example_gemm_dl
)
if
(
DL_KERNELS
)
add_custom_target
(
example_gemm_dl
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_
example_executable
(
example_gemm_dl
_fp16
gemm_dl_fp
16.cpp
)
add_dependencies
(
example_gemm_dl
example_
gemm_dl_fp
32
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_
dependencies
(
example_gemm_dl
example_
gemm_dl_fp
32
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl
_fp16
gemm_dl_fp
16.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
endif
()
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
()
add_custom_target
(
example_gemm_xdl
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
endif
()
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
endif
()
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
...
...
@@ -37,22 +50,20 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
endif
()
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
if
(
DTYPES MATCHES
"fp8"
OR NOT DEFINED DTYPES
)
if
(
GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
add_example_executable
(
example_gemm_xdl_f8 gemm_xdl_f8.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_f8
)
endif
()
endif
()
add_example_executable
(
example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_f8
)
example/01_gemm/gemm_xdl_fp16_f8.cpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
LoopSched
=
ck
::
make_default_loop_scheduler
();
static
constexpr
auto
PipelineVer
=
ck
::
PipelineVersion
::
v1
;
using
ComputeType
=
ck
::
half_t
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| |
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
LoopSched
,
PipelineVer
,
ComputeType
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/02_gemm_bilinear/CMakeLists.txt
View file @
ac76519a
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list1 gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
...
...
@@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
endif
()
endforeach
()
endif
()
example/03_gemm_bias_relu/CMakeLists.txt
View file @
ac76519a
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
endif
()
endforeach
()
endif
()
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
ac76519a
...
...
@@ -3,22 +3,26 @@ set(target 0)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_gemm_add_add_fastgelu_xdl
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/09_convnd_fwd/CMakeLists.txt
View file @
ac76519a
...
...
@@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp
)
endif
()
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
add_example_executable_no_testing
(
example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
add_example_executable
(
example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp
)
add_example_executable
(
example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp
)
if
(
DL_KERNELS
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp
)
endif
()
endif
()
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
View file @
ac76519a
...
...
@@ -3,14 +3,22 @@ set(target 0)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_fwd_reduce_xdl
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4
)
...
...
example/13_pool2d_fwd/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp
)
add_example_executable
(
example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp
)
endif
()
example/14_gemm_quantization/CMakeLists.txt
View file @
ac76519a
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
# dlops
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
if
(
DL_KERNELS
)
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
endif
()
# xdlops
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
...
...
example/15_grouped_gemm/CMakeLists.txt
View file @
ac76519a
add_custom_target
(
example_grouped_gemm_xdl
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
...
...
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
View file @
ac76519a
...
...
@@ -6,32 +6,32 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max
example_gemm_max_xdl_bf16
example_gemm_max_xdl_fp16
example_gemm_max_xdl_fp32
example_gemm_max_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare
example_gemm_mean_meansquare_xdl_fp16
example_gemm_mean_meansquare_xdl_fp32
example_gemm_mean_meansquare_xdl_bf16
example_gemm_add_addsquare_xdl_int8
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16
)
endif
()
add_dependencies
(
example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
...
...
example/17_convnd_bwd_data/CMakeLists.txt
View file @
ac76519a
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
@@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
endif
()
endforeach
()
add_example_executable
(
example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp
)
target_link_libraries
(
example_convnd_bwd_data_dl_fp16 PRIVATE utility
)
if
(
DL_KERNELS
)
add_example_executable
(
example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp
)
target_link_libraries
(
example_convnd_bwd_data_dl_fp16 PRIVATE utility
)
endif
()
endif
()
Prev
1
2
3
4
5
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment