gaoqiong / composable_kernel · Commit de1afb7b

Commit de1afb7b · authored Oct 19, 2023 by Rostyslav Geyyer

    Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into lwpck-977

Parents: ce562aa6, f7331c60
Changes: 107

Showing 20 changed files with 402 additions and 380 deletions (+402 / -380)
    CHANGELOG.md                                                                  +1   -1
    CMakeLists.txt                                                                +3   -3
    Jenkinsfile                                                                   +2   -2
    client_example/05_layernorm/layernorm2d.cpp                                   +33  -9
    client_example/18_groupnorm/groupnorm_swish.cpp                               +78  -44
    cmake/EnableCompilerWarnings.cmake                                            +1   -0
    example/01_gemm/CMakeLists.txt                                                +23  -45
    example/04_gemm_add_add_fastgelu/CMakeLists.txt                               +20  -24
    example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt               +22  -25
    example/12_reduce/README.md                                                   +2   -2
    example/15_grouped_gemm/CMakeLists.txt                                        +19  -31
    example/16_gemm_multi_d_multi_reduces/CMakeLists.txt                          +44  -58
    example/20_grouped_conv_bwd_weight/CMakeLists.txt                             +21  -27
    example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp   +4   -1
    example/22_cgemm/CMakeLists.txt                                               +10  -14
    example/24_batched_gemm/CMakeLists.txt                                        +11  -16
    example/27_layernorm/layernorm_fp16.cpp                                       +12  -7
    example/27_layernorm/layernorm_splitk_fp16.cpp                                +12  -7
    example/27_layernorm/run_layernorm_example.inc                                +54  -28
    example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt                         +30  -36
CHANGELOG.md  (+1 / -1)

@@ -17,7 +17,7 @@ None
 - Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
 - Grouped convolution support for small K and C (#822 #879 #897)
 - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-- Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
+- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
 - Support for Batched Gemm DL (#732)

 ### Changes
CMakeLists.txt  (+3 / -3)

The -Wno-bit-int-extension compile option is no longer added separately inside the fp8/bf8
branches; it is now added once, unconditionally, for the f8/bf8_t types.

@@ -32,12 +32,10 @@ if (DTYPES)
     if(DTYPES MATCHES "fp8")
         add_definitions(-DCK_ENABLE_FP8)
         set(CK_ENABLE_FP8 "ON")
-        add_compile_options(-Wno-bit-int-extension)
     endif()
     if(DTYPES MATCHES "bf8")
         add_definitions(-DCK_ENABLE_BF8)
         set(CK_ENABLE_BF8 "ON")
-        add_compile_options(-Wno-bit-int-extension)
     endif()
     if(DTYPES MATCHES "fp16")
         add_definitions(-DCK_ENABLE_FP16)
@@ -59,9 +57,11 @@ if (DTYPES)
 else()
     add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
     set(CK_ENABLE_ALL_DTYPES "ON")
-    add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
 endif()
+#for f8/bf8_t type
+add_compile_options(-Wno-bit-int-extension)
 if(DL_KERNELS)
     add_definitions(-DDL_KERNELS)
     set(CK_ENABLE_DL_KERNELS "ON")
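The CK_ENABLE_* switches set above surface in translation units as ordinary preprocessor macros. As a rough illustration only (a minimal sketch, not code from the repository; the printed strings are invented for the example):

    // Minimal sketch: how code can react to the -DCK_ENABLE_* definitions added above.
    #include <cstdio>

    int main()
    {
    #ifdef CK_ENABLE_FP16
        std::printf("fp16 support compiled in\n");
    #else
        std::printf("fp16 support disabled\n");
    #endif

    #ifdef CK_ENABLE_FP8
        // The fp8/bf8 paths are what needed -Wno-bit-int-extension; after this change the
        // warning suppression is applied once for the f8/bf8_t types rather than inside
        // each dtype branch.
        std::printf("fp8 support compiled in\n");
    #endif
        return 0;
    }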
Jenkinsfile  (+2 / -2)

-DDL_KERNELS=ON is added to both the library build and the client-example build for the
navi32 (gfx1101) CI stage.

@@ -790,8 +790,8 @@ pipeline {
         }
         agent{ label rocmnode("navi32")}
         environment{
-            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
-            execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
+            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON """
+            execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
         }
         steps{
             Build_CK_and_Reboot(setup_args:setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
client_example/05_layernorm/layernorm2d.cpp  (+33 / -9)

The layernorm client example now requests the optional save_mean / save_inv_std outputs from
DeviceNormalization (guarded by a SAVE_MEAN_INV_STD macro) and accounts for them in its
bandwidth estimate.

@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"

 using XDataType              = ck::half_t;
 using GammaDataType          = ck::half_t;
 using BetaDataType           = ck::half_t;
 using YDataType              = ck::half_t;
 using ComputeDataType        = float;
+using SaveMeanInvStdDataType = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

+#define SAVE_MEAN_INV_STD
+
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
+#endif

     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                                        GammaDataType,
                                                                        BetaDataType,
                                                                        ComputeDataType,
                                                                        YDataType,
+                                                                       SaveMeanInvStdDataType,
                                                                        PassThrough,
                                                                        Rank,
                                                                        NumReduceDim>;

@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
                                                         {0, 1},      // gammaStrides
                                                         {0, 1},      // betaStrides
                                                         {Stride, 1}, // yStrides
+                                                        {1},         // save_mean Strides
+                                                        {1},         // save_inv_std Strides
                                                         {1},         // reduceDims
                                                         1e-4,
                                                         x_device_buf.GetDeviceBuffer(),
                                                         gamma_device_buf.GetDeviceBuffer(),
                                                         beta_device_buf.GetDeviceBuffer(),
                                                         y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                                        save_mean_device_buf.GetDeviceBuffer(),
+                                                        save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                                         nullptr,
                                                         nullptr,
+#endif
                                                         PassThrough{});

         auto invoker_ptr = op_ptr->MakeInvokerPointer();

@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
         std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
                                sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;

+#ifdef SAVE_MEAN_INV_STD
+        num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
+#endif
+
         float gb_per_sec = num_byte / 1.E6 / ave_time;

         std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "

@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
         auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                         {Stride, 1}, // xStrides
-                                                        {1},         // gammaStrides
-                                                        {1},         // betaStrides
+                                                        {0, 1},      // gammaStrides
+                                                        {0, 1},      // betaStrides
                                                         {Stride, 1}, // yStrides
+                                                        {1},         // save_mean Strides
+                                                        {1},         // save_inv_std Strides
                                                         {1},         // reduceDims
                                                         1e-4,
                                                         x_device_buf.GetDeviceBuffer(),
                                                         gamma_device_buf.GetDeviceBuffer(),
                                                         beta_device_buf.GetDeviceBuffer(),
                                                         y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                                        save_mean_device_buf.GetDeviceBuffer(),
+                                                        save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                                         nullptr,
                                                         nullptr,
+#endif
                                                         PassThrough{});

         auto invoker_ptr = op_ptr->MakeInvokerPointer();
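To make the two new outputs concrete, here is a standalone host-side sketch (plain C++, independent of the composable_kernel API; sizes and values are made up for illustration) of a 2-D layernorm that also records the per-row mean and inverse standard deviation. This is what the save_mean / save_inv_std buffers above receive, and why the example adds sizeof(SaveMeanInvStdDataType) * M * 2 bytes to its bandwidth estimate:

    // Minimal host-side illustration of layernorm with saved mean / inverse std.
    // Assumptions: row-major [M, N] input, reduction over N, epsilon as in the example.
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 4, N = 8;
        const float epsilon = 1e-4f;

        std::vector<float> x(M * N, 1.0f), gamma(N, 1.0f), beta(N, 0.0f);
        std::vector<float> y(M * N);
        std::vector<float> save_mean(M), save_inv_std(M); // the two optional outputs

        for(int m = 0; m < M; ++m)
        {
            float mean = 0.f, var = 0.f;
            for(int n = 0; n < N; ++n)
                mean += x[m * N + n];
            mean /= N;
            for(int n = 0; n < N; ++n)
            {
                const float d = x[m * N + n] - mean;
                var += d * d;
            }
            var /= N;

            const float inv_std = 1.f / std::sqrt(var + epsilon);
            save_mean[m]    = mean;    // what the kernel writes through the save_mean pointer
            save_inv_std[m] = inv_std; // what it writes through the save_inv_std pointer

            for(int n = 0; n < N; ++n)
                y[m * N + n] = (x[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
        }

        // Matches the diff's extra traffic: two float vectors of length M.
        const std::size_t extra_bytes = sizeof(float) * M * 2;
        std::printf("extra bytes for saved statistics: %zu\n", extra_bytes);
        return 0;
    }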
client_example/18_groupnorm/groupnorm_swish.cpp  (+78 / -44)

Same extension as the layernorm client example: the groupnorm + Swish example now wires up the
optional save_mean / save_inv_std outputs.

@@ -12,12 +12,14 @@
 #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"

 using XDataType              = ck::half_t;
 using GammaDataType          = float;
 using BetaDataType           = float;
 using YDataType              = ck::half_t;
 using ComputeDataType        = float;
+using SaveMeanInvStdDataType = float;
 using Swish = ck::tensor_operation::element_wise::Swish;

+#define SAVE_MEAN_INV_STD
+
 constexpr int Rank         = 5;
 constexpr int NumReduceDim = 3;

@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
     std::size_t xy_size         = N * H * W * G * C;
     std::size_t gamma_beta_size = G * C;

     std::vector<ck::index_t> xy_strides         = {H * W * G * C, W * G * C, G * C, C, 1};
     std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
+    std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};

     SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
     SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
     SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
     SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
+#ifdef SAVE_MEAN_INV_STD
+    SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+    SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
+#endif

     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                                        GammaDataType,
                                                                        BetaDataType,
                                                                        ComputeDataType,
                                                                        YDataType,
+                                                                       SaveMeanInvStdDataType,
                                                                        Swish,
                                                                        Rank,
                                                                        NumReduceDim>;

@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
     const auto& generic_op_ptr = op_ptrs[0];

     auto generic_argument_ptr =
         generic_op_ptr->MakeArgumentPointer({N, H, W, G, C},    // lengths
                                             xy_strides,         // xStrides
                                             gamma_beta_strides, // gammaStrides
                                             gamma_beta_strides, // betaStrides
                                             xy_strides,         // yStrides
+                                            save_mean_inv_std_strides, // save_mean Strides
+                                            save_mean_inv_std_strides, // save_inv_std Strides
                                             {1, 2, 4},          // reduceDims
                                             1e-6,
                                             x_device_buf.GetDeviceBuffer(),
                                             gamma_device_buf.GetDeviceBuffer(),
                                             beta_device_buf.GetDeviceBuffer(),
                                             y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                            save_mean_device_buf.GetDeviceBuffer(),
+                                            save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                             nullptr,
                                             nullptr,
+#endif
                                             Swish{});

     if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))

The MakeArgumentPointer call inside the loop over all returned instances (@@ -107,21 +121,29 @@)
receives the same treatment: the two save_mean_inv_std_strides arguments are inserted after
yStrides, and the trailing nullptr, nullptr pointer arguments are wrapped in the
#ifdef SAVE_MEAN_INV_STD / #else / #endif block shown above; the call is also reformatted, which
accounts for much of this file's line count.

@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
         sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
         sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;

+#ifdef SAVE_MEAN_INV_STD
+    num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
+#endif
+
     float gb_per_sec = num_byte / 1.E6 / ave_time;

     std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "

The final "Run the best instance without timing" call (@@ -169,20 +195,28 @@) is updated in the
same way as the two MakeArgumentPointer calls above.
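For orientation, the stride vectors used by this example follow directly from the packed NHWGC layout. A standalone sketch (plain C++; the sizes are chosen arbitrarily here and this is not code from the repository):

    #include <cstdio>

    int main()
    {
        // Sizes for a groupnorm call where x and y are laid out as [N, H, W, G, C].
        const int N = 2, H = 16, W = 16, G = 32, C = 8;

        // Packed NHWGC layout: innermost C, then G, W, H, N.
        const int xy_strides[5] = {H * W * G * C, W * G * C, G * C, C, 1};

        // gamma/beta only vary over (G, C); zero strides broadcast them over N, H, W.
        const int gamma_beta_strides[5] = {0, 0, 0, C, 1};

        // save_mean / save_inv_std hold one value per (N, G) pair.
        const int save_mean_inv_std_strides[2] = {G, 1};

        std::printf("xy strides:           {%d, %d, %d, %d, %d}\n",
                    xy_strides[0], xy_strides[1], xy_strides[2], xy_strides[3], xy_strides[4]);
        std::printf("gamma/beta strides:   {%d, %d, %d, %d, %d}\n",
                    gamma_beta_strides[0], gamma_beta_strides[1], gamma_beta_strides[2],
                    gamma_beta_strides[3], gamma_beta_strides[4]);
        std::printf("mean/inv_std strides: {%d, %d}\n",
                    save_mean_inv_std_strides[0], save_mean_inv_std_strides[1]);
        return 0;
    }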
cmake/EnableCompilerWarnings.cmake  (+1 / -0)

@@ -70,6 +70,7 @@ else()
         -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
+        -Wno-unused-template
     )
     if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
         list(APPEND CMAKE_COMPILER_WARNINGS
example/01_gemm/CMakeLists.txt  (+23 / -45)

Throughout the example CMake files touched by this merge, the old per-example pattern

    add_example_executable(<example> <source>.cpp)
    if(result EQUAL 0)
        add_dependencies(<target> <example>)
    endif()

is collapsed into the add_example_dependencies(<target> <example>) helper. After the change this
file reads (indentation and blank lines approximate):

    add_custom_target(example_gemm_dl)
    add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
    add_example_dependencies(example_gemm_dl example_gemm_dl_fp32)
    add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
    add_example_dependencies(example_gemm_dl example_gemm_dl_fp16)
    add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
    add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
    add_example_dependencies(example_gemm_dl example_gemm_dl_int8)
    if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
        add_example_dependencies(example_gemm_dl example_gemm_dl_int4)
    endif(USE_BITINT_EXTENSION_INT4)

    add_custom_target(example_gemm_xdl)
    add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
    add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
    add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
    if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
        add_custom_target(example_gemm_wmma)
        add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
        add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
    endif()
    add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
    add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
    add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
    if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
        add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
    endif(USE_BITINT_EXTENSION_INT4)
    # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
    add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
    add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
    add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
    add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
    add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
    add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
example/04_gemm_add_add_fastgelu/CMakeLists.txt  (+20 / -24)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate; no
newline at end of file):

    list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list AND target EQUAL 0)
            add_custom_target(example_gemm_add_add_fastgelu_xdl)
            add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
            add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
            add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
            if(USE_BITINT_EXTENSION_INT4)
                add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
                add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
            endif(USE_BITINT_EXTENSION_INT4)
            add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
            add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
            set(target 1)
        endif()
    endforeach()
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt  (+22 / -25)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate; no
newline at end of file):

    list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list AND target EQUAL 0)
            add_custom_target(example_convnd_fwd_reduce_xdl)
            add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
            add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
            add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
            add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
            add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
            if(USE_BITINT_EXTENSION_INT4)
                add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
                add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
            endif(USE_BITINT_EXTENSION_INT4)
            set(target 1)
        endif()
    endforeach()
example/12_reduce/README.md  (+2 / -2)

Two usage comments are capitalized from "3d/4d/5d" to "3D/4D/5D".

@@ -2,7 +2,7 @@
 ## Run ```example_reduce_blockwise```
 ```bash
-# -D <xxx> : input 3d/4d/5d tensor lengths
+# -D <xxx> : input 3D/4D/5D tensor lengths
 # -R <xxx> : reduce dimension ids
 # -v <x> : verification (0=no, 1=yes)
 #arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
 ## Run ```example_reduce_multiblock_atomic_add```
 ```bash
-# -D <xxx> : input 3d/4d/5d tensor lengths
+# -D <xxx> : input 3D/4D/5D tensor lengths
 # -R <xxx> : reduce dimension ids
 # -v <x> : verification (0=no, 1=yes)
 #arg1: data type (0: fp32, 1: fp64)
example/15_grouped_gemm/CMakeLists.txt  (+19 / -31)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate):

    add_custom_target(example_grouped_gemm_xdl)
    add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
    add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
    add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
    add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
    add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
    add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
    add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
    add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
    add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
    add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
    if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
        add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
    endif()
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt  (+44 / -58)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate):

    list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list AND target EQUAL 0)
            add_custom_target(example_gemm_reduce_xdl)
            add_custom_target(example_gemm_reduce_xdl_max)
            add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
            add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
            add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
            add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
            add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
            add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
            add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
            add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
            add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
            add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
            add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
            add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
            add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
            add_example_dependencies(example_gemm_reduce_xdl
                                     example_gemm_reduce_xdl_mean_meansquare
                                     example_gemm_reduce_xdl_max
                                     example_gemm_add_add_mean_meansquare_xdl)
            if(USE_BITINT_EXTENSION_INT4)
                add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
                add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
            endif()
            set(target 1)
        endif()
    endforeach()
example/20_grouped_conv_bwd_weight/CMakeLists.txt  (+21 / -27)

Same dependency-helper refactoring.

@@ -2,34 +2,28 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)

After the change the hunk reads (whitespace approximate):

    list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
            add_custom_target(example_grouped_conv_bwd_weight)
            add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
            add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
            add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
            add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
            add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
            add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
            set(target 1)
        endif()
        if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
            add_custom_target(example_grouped_conv_bwd_weight)
            add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
            add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
            set(target 1)
        endif()
    endforeach()

    add_custom_target(example_grouped_conv_bwd_weight_dl)
    add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
    add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp  (+4 / -1)

The host reference path now also produces save_mean / save_inv_std tensors and passes them to the
reference layernorm.

@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
                                                                       BetaDataType,
                                                                       HDataType,
                                                                       AccDataType,
+                                                                      AccDataType,
                                                                       HElementOp,
                                                                       2,
                                                                       1>;

     Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
     Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
+    Tensor<AccDataType> save_mean({M});
+    Tensor<AccDataType> save_inv_std({M});

     auto ref_gemm         = ReferenceGemm{};
     auto ref_gemm_invoker = ref_gemm.MakeInvoker();

@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
     auto ref_layernorm_invoker  = ref_layernorm.MakeInvoker();
     auto ref_layernorm_argument = ref_layernorm.MakeArgument(
-        e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
+        e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
     ref_layernorm_invoker.Run(ref_layernorm_argument);
 }
example/22_cgemm/CMakeLists.txt  (+10 / -14)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate):

    add_custom_target(example_cgemm_xdl)
    add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
    add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
    add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
    add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
    add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
    add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
    add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
    add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
    if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
        add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
    endif()
example/24_batched_gemm/CMakeLists.txt  (+11 / -16)

Same dependency-helper refactoring. After the change the file reads (whitespace approximate):

    add_custom_target(example_batched_gemm_xdl)
    add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
    add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
    add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
    add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
    add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
    add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
    add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
    add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
    if(USE_BITINT_EXTENSION_INT4)
        add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
        add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
    endif()
example/27_layernorm/layernorm_fp16.cpp  (+12 / -7)

@@ -3,12 +3,15 @@
 #include "common.hpp"

 using XDataType     = ck::half_t;
 using GammaDataType = ck::half_t;
 using BetaDataType  = ck::half_t;
 using YDataType     = ck::half_t;
-using ComputeDataType = float;
-using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
+using ComputeDataType        = float;
+using SaveMeanInvStdDataType = float;
+using PassThrough            = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD

 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

@@ -19,6 +22,7 @@ using DeviceInstance =
                                                    BetaDataType,
                                                    ComputeDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    PassThrough,
                                                    Rank,
                                                    NumReduceDim,

@@ -33,7 +37,8 @@ using DeviceInstance =
                                                    8, // GammaScalarPerVector
                                                    1, // BetaVecDim (0=M, 1=K)
                                                    8, // BetaScalarPerVector
-                                                   8>; // OutScalarPerVector
+                                                   8, // YScalarPerVector
+                                                   1>; // SaveMeanInvStdScalarPerVector

 #include "run_layernorm_example.inc"

 int main() { return run_groupnorm_example<DeviceInstance>(); }
example/27_layernorm/layernorm_splitk_fp16.cpp  (+12 / -7)

Same changes as layernorm_fp16.cpp above: SaveMeanInvStdDataType is introduced and passed to the
device instance, SAVE_MEAN_INV_STD is defined, and the instance gains a
SaveMeanInvStdScalarPerVector parameter.

@@ -3,12 +3,15 @@
 #include "common.hpp"

 using XDataType     = ck::half_t;
 using GammaDataType = ck::half_t;
 using BetaDataType  = ck::half_t;
 using YDataType     = ck::half_t;
-using ComputeDataType = float;
-using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
+using ComputeDataType        = float;
+using SaveMeanInvStdDataType = float;
+using PassThrough            = ck::tensor_operation::element_wise::PassThrough;
+
+#define SAVE_MEAN_INV_STD

 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

@@ -19,6 +22,7 @@ using DeviceInstance =
                                                    BetaDataType,
                                                    ComputeDataType,
                                                    YDataType,
+                                                   SaveMeanInvStdDataType,
                                                    PassThrough,
                                                    Rank,
                                                    NumReduceDim,

@@ -33,7 +37,8 @@ using DeviceInstance =
                                                    8, // GammaScalarPerVector
                                                    1, // BetaVecDim (0=M, 1=K)
                                                    8, // BetaScalarPerVector
-                                                   8>; // YScalarPerVector
+                                                   8, // YScalarPerVector
+                                                   1>; // SaveMeanInvStdScalarPerVector

 #include "run_layernorm_example.inc"
example/27_layernorm/run_layernorm_example.inc  (+54 / -28)

The shared example runner drops its host-tensor-descriptor lambdas in favor of direct tensor
construction, allocates and transfers the save_mean / save_inv_std buffers, and verifies them
against the host reference.

@@ -10,22 +10,13 @@ int run_groupnorm_example()
     ck::index_t M = 1024;
     ck::index_t N = 1024;
-    ck::index_t Stride = N;
-
-    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
-        return HostTensorDescriptor({len}, {stride});
-    };
-
-    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
-        using namespace ck::literals;
-        return HostTensorDescriptor({row, col}, {stride, 1_uz});
-    };
-
-    Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
-    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
-    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
-    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));
+
+    Tensor<XDataType> x({M, N});
+    Tensor<GammaDataType> gamma({N});
+    Tensor<BetaDataType> beta({N});
+    Tensor<YDataType> y({M, N});
+    Tensor<SaveMeanInvStdDataType> save_mean({M});
+    Tensor<SaveMeanInvStdDataType> save_inv_std({M});

     x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
     gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});

@@ -35,6 +26,11 @@ int run_groupnorm_example()
     DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
     DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+#ifdef SAVE_MEAN_INV_STD
+    DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
+    DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * save_inv_std.mDesc.GetElementSpaceSize());
+#endif

     x_dev.ToDevice(x.mData.data());
     gamma_dev.ToDevice(gamma.mData.data());

@@ -47,14 +43,23 @@ int run_groupnorm_example()
         {0, 1},
         {0, 1},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), save_mean.mDesc.GetStrides().end()},
+        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(), save_mean.mDesc.GetStrides().end()},
         {1},
         1e-4,
         x_dev.GetDeviceBuffer(),
         gamma_dev.GetDeviceBuffer(),
         beta_dev.GetDeviceBuffer(),
         y_dev.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.GetDeviceBuffer(),
+        save_inv_std_dev.GetDeviceBuffer(),
+#else
         nullptr,
         nullptr,
+#endif
         PassThrough{});

     if(!device_instance.IsSupportedArgument(argument_ptr.get()))

@@ -72,24 +77,45 @@ int run_groupnorm_example()
     bool pass = true;
     {
-        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
+        Tensor<YDataType> host_y({M, N});
+        Tensor<SaveMeanInvStdDataType> host_save_mean({M});
+        Tensor<SaveMeanInvStdDataType> host_save_inv_std({M});
+
         using ReferenceInstance =
             ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                            GammaDataType,
                                                            BetaDataType,
                                                            YDataType,
+                                                           SaveMeanInvStdDataType,
                                                            ComputeDataType,
                                                            PassThrough,
                                                            Rank,
                                                            NumReduceDim>;

         ReferenceInstance ref;
         auto ref_argument =
-            ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4);
+            ref.MakeArgument(x, gamma, beta, host_y, host_save_mean, host_save_inv_std, PassThrough{}, {M, N}, {1}, 1e-4);
         auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);

         y_dev.FromDevice(y.mData.data());
-        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
+
+#ifdef SAVE_MEAN_INV_STD
+        save_mean_dev.FromDevice(save_mean.mData.data());
+        save_inv_std_dev.FromDevice(save_inv_std.mData.data());
+        pass &= ck::utils::check_err(save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
+#endif
     }

     return (pass ? 0 : 1);
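The checks above rely on ck::utils::check_err with 1e-3 tolerances. As a rough, self-contained illustration of that style of comparison (a hypothetical helper written here for clarity, not the library function itself):

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in: passes if every element agrees within rtol/atol.
    bool almost_equal(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      float rtol = 1e-3f,
                      float atol = 1e-3f)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const float err = std::fabs(out[i] - ref[i]);
            if(err > atol + rtol * std::fabs(ref[i]))
                return false;
        }
        return true;
    }

    int main()
    {
        // Device output vs. host reference, values invented for the example.
        std::vector<float> device_y = {1.0f, 2.0f, 3.0f};
        std::vector<float> host_y   = {1.0005f, 2.001f, 2.999f};
        std::printf("pass: %d\n", almost_equal(device_y, host_y) ? 1 : 0);
        return 0;
    }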
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt  (+30 / -36)

Same dependency-helper refactoring.

@@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)

After the change the hunk reads (whitespace approximate):

    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
            add_custom_target(example_grouped_conv_fwd_multiple_d)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
            add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
            add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
            if(USE_BITINT_EXTENSION_INT4)
                add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
                add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
            endif() # USE_BITINT_EXTENSION_INT4
            set(target 1)
        endif()
    endforeach()

    set(target 0)
    foreach(gpu IN LISTS GPU_TARGETS)
        if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
            add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
            set(target 1)
        endif()
    endforeach()