Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
aa61ccf0
Unverified
Commit
aa61ccf0
authored
Oct 24, 2023
by
arai713
Committed by
GitHub
Oct 24, 2023
Browse files
Merge branch 'develop' into hip_tensor_permute
parents
4498e2a1
bec84efb
Changes
184
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
458 additions
and
313 deletions
+458
-313
CHANGELOG.md
CHANGELOG.md
+2
-2
CMakeLists.txt
CMakeLists.txt
+3
-3
Jenkinsfile
Jenkinsfile
+2
-2
client_example/05_layernorm/layernorm2d.cpp
client_example/05_layernorm/layernorm2d.cpp
+33
-9
client_example/18_groupnorm/groupnorm_swish.cpp
client_example/18_groupnorm/groupnorm_swish.cpp
+78
-44
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+23
-45
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+20
-24
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
...e/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
+22
-25
example/12_reduce/README.md
example/12_reduce/README.md
+2
-2
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+19
-31
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
+44
-58
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+22
-21
example/20_grouped_conv_bwd_weight/common.hpp
example/20_grouped_conv_bwd_weight/common.hpp
+18
-22
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
...ped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
+88
-0
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+20
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
..._weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
+20
-1
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+1
-21
No files found.
CHANGELOG.md
View file @
aa61ccf0
...
@@ -14,10 +14,10 @@ None
...
@@ -14,10 +14,10 @@ None
### Additions
### Additions
-
Added an image to a column kernel (#867)
-
Added an image to a column kernel (#867)
-
Added a column to an image kernel (#930)
-
Added a column to an image kernel (#930)
-
Support for 3D grouped convolution
forward
on RDNA 3 GPUs (#935)
-
Support for 3D grouped convolution on RDNA 3 GPUs (#935
, #950, #985
)
-
Grouped convolution support for small K and C (#822 #879 #897)
-
Grouped convolution support for small K and C (#822 #879 #897)
-
Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-
Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
-
Support for bf16/f32/f16 and NHWGC (2D and 3
d
) grouped convolution backward data (#757 #799)
-
Support for bf16/f32/f16 and NHWGC (2D and 3
D
) grouped convolution backward data (#757 #799)
-
Support for Batched Gemm DL (#732)
-
Support for Batched Gemm DL (#732)
### Changes
### Changes
...
...
CMakeLists.txt
View file @
aa61ccf0
...
@@ -32,12 +32,10 @@ if (DTYPES)
...
@@ -32,12 +32,10 @@ if (DTYPES)
if
(
DTYPES MATCHES
"fp8"
)
if
(
DTYPES MATCHES
"fp8"
)
add_definitions
(
-DCK_ENABLE_FP8
)
add_definitions
(
-DCK_ENABLE_FP8
)
set
(
CK_ENABLE_FP8
"ON"
)
set
(
CK_ENABLE_FP8
"ON"
)
add_compile_options
(
-Wno-bit-int-extension
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf8"
)
if
(
DTYPES MATCHES
"bf8"
)
add_definitions
(
-DCK_ENABLE_BF8
)
add_definitions
(
-DCK_ENABLE_BF8
)
set
(
CK_ENABLE_BF8
"ON"
)
set
(
CK_ENABLE_BF8
"ON"
)
add_compile_options
(
-Wno-bit-int-extension
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
)
if
(
DTYPES MATCHES
"fp16"
)
add_definitions
(
-DCK_ENABLE_FP16
)
add_definitions
(
-DCK_ENABLE_FP16
)
...
@@ -59,9 +57,11 @@ if (DTYPES)
...
@@ -59,9 +57,11 @@ if (DTYPES)
else
()
else
()
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
add_compile_options
(
-Wno-bit-int-extension
)
# enable fp8 and bf8
endif
()
endif
()
#for f8/bf8_t type
add_compile_options
(
-Wno-bit-int-extension
)
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
set
(
CK_ENABLE_DL_KERNELS
"ON"
)
set
(
CK_ENABLE_DL_KERNELS
"ON"
)
...
...
Jenkinsfile
View file @
aa61ccf0
...
@@ -790,8 +790,8 @@ pipeline {
...
@@ -790,8 +790,8 @@ pipeline {
}
}
agent
{
label
rocmnode
(
"navi32"
)
}
agent
{
label
rocmnode
(
"navi32"
)
}
environment
{
environment
{
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101"
-DDL_KERNELS=ON
"""
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101"
-DDL_KERNELS=ON
-D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
}
steps
{
steps
{
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
Build_CK_and_Reboot
(
setup_args:
setup_args
,
config_targets:
"install"
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
,
prefixpath:
'/usr/local'
)
...
...
client_example/05_layernorm/layernorm2d.cpp
View file @
aa61ccf0
...
@@ -12,12 +12,14 @@
...
@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
using
XDataType
=
ck
::
half_t
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
SaveMeanInvStdDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
#define SAVE_MEAN_INV_STD
constexpr
int
Rank
=
2
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
constexpr
int
NumReduceDim
=
1
;
...
@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
...
@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
N
);
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
N
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
N
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
N
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
xy_size
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
xy_size
);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem
save_mean_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
M
);
SimpleDeviceMem
save_inv_std_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
M
);
#endif
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
<
XDataType
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
PassThrough
,
PassThrough
,
Rank
,
Rank
,
NumReduceDim
>
;
NumReduceDim
>
;
...
@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
...
@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
{
0
,
1
},
// gammaStrides
{
0
,
1
},
// gammaStrides
{
0
,
1
},
// betaStrides
{
0
,
1
},
// betaStrides
{
Stride
,
1
},
// yStrides
{
Stride
,
1
},
// yStrides
{
1
},
// save_mean Strides
{
1
},
// save_inv_std Strides
{
1
},
// reduceDims
{
1
},
// reduceDims
1e-4
,
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
nullptr
,
nullptr
,
#endif
PassThrough
{});
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
...
@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
N
+
sizeof
(
GammaDataType
)
*
N
+
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
N
+
sizeof
(
GammaDataType
)
*
N
+
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
YDataType
)
*
M
*
N
;
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
YDataType
)
*
M
*
N
;
#ifdef SAVE_MEAN_INV_STD
num_byte
+=
sizeof
(
SaveMeanInvStdDataType
)
*
M
*
2
;
#endif
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
...
@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
...
@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
M
,
N
},
// lengths
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
M
,
N
},
// lengths
{
Stride
,
1
},
// xStrides
{
Stride
,
1
},
// xStrides
{
1
},
// gammaStrides
{
0
,
1
},
// gammaStrides
{
1
},
// betaStrides
{
0
,
1
},
// betaStrides
{
Stride
,
1
},
// yStrides
{
Stride
,
1
},
// yStrides
{
1
},
// save_mean Strides
{
1
},
// save_inv_std Strides
{
1
},
// reduceDims
{
1
},
// reduceDims
1e-4
,
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
nullptr
,
nullptr
,
#endif
PassThrough
{});
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
client_example/18_groupnorm/groupnorm_swish.cpp
View file @
aa61ccf0
...
@@ -12,12 +12,14 @@
...
@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using
XDataType
=
ck
::
half_t
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
float
;
using
GammaDataType
=
float
;
using
BetaDataType
=
float
;
using
BetaDataType
=
float
;
using
YDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
SaveMeanInvStdDataType
=
float
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
#define SAVE_MEAN_INV_STD
constexpr
int
Rank
=
5
;
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
constexpr
int
NumReduceDim
=
3
;
...
@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
...
@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
std
::
size_t
xy_size
=
N
*
H
*
W
*
G
*
C
;
std
::
size_t
xy_size
=
N
*
H
*
W
*
G
*
C
;
std
::
size_t
gamma_beta_size
=
G
*
C
;
std
::
size_t
gamma_beta_size
=
G
*
C
;
std
::
vector
<
ck
::
index_t
>
xy_strides
=
{
H
*
W
*
G
*
C
,
W
*
G
*
C
,
G
*
C
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
xy_strides
=
{
H
*
W
*
G
*
C
,
W
*
G
*
C
,
G
*
C
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
gamma_beta_strides
=
{
0
,
0
,
0
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
gamma_beta_strides
=
{
0
,
0
,
0
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
save_mean_inv_std_strides
=
{
G
,
1
};
SimpleDeviceMem
x_device_buf
(
sizeof
(
XDataType
)
*
xy_size
);
SimpleDeviceMem
x_device_buf
(
sizeof
(
XDataType
)
*
xy_size
);
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
gamma_beta_size
);
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
gamma_beta_size
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
gamma_beta_size
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
gamma_beta_size
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
xy_size
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
xy_size
);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem
save_mean_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
N
*
G
);
SimpleDeviceMem
save_inv_std_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
N
*
G
);
#endif
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
<
XDataType
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
Swish
,
Swish
,
Rank
,
Rank
,
NumReduceDim
>
;
NumReduceDim
>
;
...
@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
...
@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
const
auto
&
generic_op_ptr
=
op_ptrs
[
0
];
const
auto
&
generic_op_ptr
=
op_ptrs
[
0
];
auto
generic_argument_ptr
=
auto
generic_argument_ptr
=
generic_op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
generic_op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
xy_strides
,
// xStrides
xy_strides
,
// xStrides
gamma_beta_strides
,
// gammaStrides
gamma_beta_strides
,
// gammaStrides
gamma_beta_strides
,
// betaStrides
gamma_beta_strides
,
// betaStrides
xy_strides
,
// yStrides
xy_strides
,
// yStrides
{
1
,
2
,
4
},
// reduceDims
save_mean_inv_std_strides
,
// save_mean Strides
save_mean_inv_std_strides
,
// save_inv_std Strides
{
1
,
2
,
4
},
// reduceDims
1e-6
,
1e-6
,
x_device_buf
.
GetDeviceBuffer
(),
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
nullptr
,
nullptr
,
#endif
Swish
{});
Swish
{});
if
(
!
generic_op_ptr
->
IsSupportedArgument
(
generic_argument_ptr
.
get
()))
if
(
!
generic_op_ptr
->
IsSupportedArgument
(
generic_argument_ptr
.
get
()))
...
@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
...
@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
auto
argument_ptr
=
xy_strides
,
// xStrides
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
gamma_beta_strides
,
// gammaStrides
xy_strides
,
// xStrides
gamma_beta_strides
,
// betaStrides
gamma_beta_strides
,
// gammaStrides
xy_strides
,
// yStrides
gamma_beta_strides
,
// betaStrides
{
1
,
2
,
4
},
// reduceDims
xy_strides
,
// yStrides
1e-6
,
save_mean_inv_std_strides
,
// save_mean Strides
x_device_buf
.
GetDeviceBuffer
(),
save_mean_inv_std_strides
,
// save_inv_std Strides
gamma_device_buf
.
GetDeviceBuffer
(),
{
1
,
2
,
4
},
// reduceDims
beta_device_buf
.
GetDeviceBuffer
(),
1e-6
,
y_device_buf
.
GetDeviceBuffer
(),
x_device_buf
.
GetDeviceBuffer
(),
nullptr
,
gamma_device_buf
.
GetDeviceBuffer
(),
nullptr
,
beta_device_buf
.
GetDeviceBuffer
(),
Swish
{});
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
#endif
Swish
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
...
@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
sizeof
(
XDataType
)
*
xy_size
+
sizeof
(
GammaDataType
)
*
gamma_beta_size
+
sizeof
(
XDataType
)
*
xy_size
+
sizeof
(
GammaDataType
)
*
gamma_beta_size
+
sizeof
(
BetaDataType
)
*
gamma_beta_size
+
sizeof
(
YDataType
)
*
xy_size
;
sizeof
(
BetaDataType
)
*
gamma_beta_size
+
sizeof
(
YDataType
)
*
xy_size
;
#ifdef SAVE_MEAN_INV_STD
num_byte
+=
sizeof
(
SaveMeanInvStdDataType
)
*
N
*
G
*
2
;
#endif
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
...
@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
...
@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
auto
argument_ptr
=
xy_strides
,
// xStrides
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
gamma_beta_strides
,
// gammaStrides
xy_strides
,
// xStrides
gamma_beta_strides
,
// betaStrides
gamma_beta_strides
,
// gammaStrides
xy_strides
,
// yStrides
gamma_beta_strides
,
// betaStrides
{
1
,
2
,
4
},
// reduceDims
xy_strides
,
// yStrides
1e-6
,
save_mean_inv_std_strides
,
// save_mean Strides
x_device_buf
.
GetDeviceBuffer
(),
save_mean_inv_std_strides
,
// save_inv_std Strides
gamma_device_buf
.
GetDeviceBuffer
(),
{
1
,
2
,
4
},
// reduceDims
beta_device_buf
.
GetDeviceBuffer
(),
1e-6
,
y_device_buf
.
GetDeviceBuffer
(),
x_device_buf
.
GetDeviceBuffer
(),
nullptr
,
gamma_device_buf
.
GetDeviceBuffer
(),
nullptr
,
beta_device_buf
.
GetDeviceBuffer
(),
Swish
{});
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
#endif
Swish
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
cmake/EnableCompilerWarnings.cmake
View file @
aa61ccf0
...
@@ -70,6 +70,7 @@ else()
...
@@ -70,6 +70,7 @@ else()
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
-Wno-unused-template
)
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"Clang"
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"Clang"
)
list
(
APPEND CMAKE_COMPILER_WARNINGS
list
(
APPEND CMAKE_COMPILER_WARNINGS
...
...
example/01_gemm/CMakeLists.txt
View file @
aa61ccf0
add_custom_target
(
example_gemm_dl
)
add_custom_target
(
example_gemm_dl
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
endif
()
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
endif
()
add_example_executable
(
example_gemm_dpp_fp16 gemm_dpp_fp16.cpp
)
add_example_executable
(
example_gemm_dpp_fp16 gemm_dpp_fp16.cpp
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int4
)
add_
example_
dependencies
(
example_gemm_dl example_gemm_dl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_custom_target
(
example_gemm_xdl
)
add_custom_target
(
example_gemm_xdl
)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
endif
()
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_custom_target
(
example_gemm_wmma
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
endif
()
endif
()
endif
()
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
endif
()
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
add_
example_
dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
endif
()
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
add_example_executable
(
example_gemm_xdl_fp8 gemm_xdl_fp8.cpp
)
add_example_executable
(
example_gemm_xdl_fp8 gemm_xdl_fp8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8
)
endif
()
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
endif
()
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
endif
()
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
aa61ccf0
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_gemm_add_add_fastgelu_xdl
)
add_custom_target
(
example_gemm_add_add_fastgelu_xdl
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8
)
set
(
target 1
)
endif
()
endif
()
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
endforeach
()
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
View file @
aa61ccf0
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_fwd_reduce_xdl
)
add_custom_target
(
example_convnd_fwd_reduce_xdl
)
add_example_executable
(
example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_example_executable
(
example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8
)
endif
()
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16
)
endif
()
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16
)
if
(
result EQUAL 0
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
endif
()
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32
)
add_example_executable
(
example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp
)
endif
()
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4
)
if
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp
)
set
(
target 1
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4
)
endif
()
endif
(
USE_BITINT_EXTENSION_INT4
)
endforeach
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/12_reduce/README.md
View file @
aa61ccf0
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
## Run ```example_reduce_blockwise```
## Run ```example_reduce_blockwise```
```
bash
```
bash
# -D <xxx> : input 3
d
/4
d
/5
d
tensor lengths
# -D <xxx> : input 3
D
/4
D
/5
D
tensor lengths
# -R <xxx> : reduce dimension ids
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
...
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
...
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
## Run ```example_reduce_multiblock_atomic_add```
## Run ```example_reduce_multiblock_atomic_add```
```
bash
```
bash
# -D <xxx> : input 3
d
/4
d
/5
d
tensor lengths
# -D <xxx> : input 3
D
/4
D
/5
D
tensor lengths
# -R <xxx> : reduce dimension ids
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp32, 1: fp64)
#arg1: data type (0: fp32, 1: fp64)
...
...
example/15_grouped_gemm/CMakeLists.txt
View file @
aa61ccf0
add_custom_target
(
example_grouped_gemm_xdl
)
add_custom_target
(
example_grouped_gemm_xdl
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16
)
endif
()
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
endif
()
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
endif
()
endif
()
endif
()
example/16_gemm_multi_d_multi_reduces/CMakeLists.txt
View file @
aa61ccf0
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_gemm_reduce_xdl
)
add_custom_target
(
example_gemm_reduce_xdl
)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_max
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_reduce_xdl_mean_meansquare
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
add_custom_target
(
example_gemm_add_add_mean_meansquare_xdl
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_executable
(
example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16
)
add_example_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
add_dependencies
(
example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16
)
endif
()
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp
)
add_example_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
endif
()
add_example_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int8
)
add_example_executable
(
example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int8
)
endif
()
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp
)
add_example_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
endif
()
add_example_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32
)
add_example_executable
(
example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32
)
endif
()
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp
)
add_example_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32
)
add_example_dependencies
(
example_gemm_reduce_xdl
endif
()
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
add_example_executable
(
example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp
)
example_gemm_add_add_mean_meansquare_xdl
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16
)
if
(
USE_BITINT_EXTENSION_INT4
)
endif
()
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
add_example_executable
(
example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp
)
add_example_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
if
(
result EQUAL 0
)
endif
()
add_dependencies
(
example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16
)
set
(
target 1
)
endif
()
endif
()
add_dependencies
(
example_gemm_reduce_xdl
example_gemm_reduce_xdl_mean_meansquare
example_gemm_reduce_xdl_max
example_gemm_add_add_mean_meansquare_xdl
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp
)
if
(
result EQUAL 0
)
add_dependencies
(
example_gemm_reduce_xdl_max example_gemm_max_xdl_int4
)
endif
()
endif
()
set
(
target 1
)
endif
()
endforeach
()
endforeach
()
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
aa61ccf0
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list_xdl AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
)
endif
()
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp
)
add_example_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16
)
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
)
endif
()
add_example_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8
)
if
(
GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
set
(
target 1
)
add_example_executable
(
example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
)
endif
()
if
(
result EQUAL 0
)
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8
)
if
(
gpu IN_LIST gpu_list_wmma AND target EQUAL 0
)
add_custom_target
(
example_grouped_conv_bwd_weight
)
add_example_executable
(
example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp
)
add_example_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16
)
set
(
target 1
)
endif
()
endif
()
endif
()
set
(
target 1
)
endif
()
endforeach
()
endforeach
()
add_custom_target
(
example_grouped_conv_bwd_weight_dl
)
add_custom_target
(
example_grouped_conv_bwd_weight_dl
)
add_example_executable
(
example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp
)
if
(
result EQUAL 0
)
add_example_dependencies
(
example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16
)
add_dependencies
(
example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16
)
endif
()
example/20_grouped_conv_bwd_weight/common.hpp
View file @
aa61ccf0
...
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
...
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
using
OutputLayout
=
OutputLay
;
using
OutputLayout
=
OutputLay
;
};
};
template
<
ck
::
index_t
NDimSpatial
>
struct
CommonLayoutSettingSelector
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
namespace
ctl
=
ck
::
tensor_layout
::
convolution
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
>
struct
CommonLayoutSettingSelector
struct
CommonLayoutSettingSelector
<
1
>
final
:
CommonLayoutSetting
<
ctl
::
GNWC
,
ctl
::
GKXC
,
ctl
::
GNWK
>
:
CommonLayoutSetting
<
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
{
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
};
ck
::
tensor_layout
::
convolution
::
GNHWC
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
>>
,
template
<
>
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
struct
CommonLayoutSettingSelector
<
2
>
final
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GKXC
,
:
CommonLayoutSetting
<
ctl
::
GNHWC
,
ctl
::
GKYXC
,
ctl
::
GNHWK
>
ck
::
tensor_layout
::
convolution
::
GKYXC
,
{
ck
::
tensor_layout
::
convolution
::
GKZYXC
>>
,
};
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
template
<
>
ck
::
tensor_layout
::
convolution
::
GNHWK
,
struct
CommonLayoutSettingSelector
<
3
>
final
ck
::
tensor_layout
::
convolution
::
GNDHWK
>>>
:
CommonLayoutSetting
<
ctl
::
GNDHWC
,
ctl
::
GKZYXC
,
ctl
::
GNDHWK
>
{
{
};
};
...
@@ -84,10 +80,10 @@ struct ExecutionConfig final
...
@@ -84,10 +80,10 @@ struct ExecutionConfig final
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
#define DefaultConvParam \
#define DefaultConvParam
\
ck::utils::conv::ConvParam \
ck::utils::conv::ConvParam
\
{ \
{
\
2
, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
3
, 4, 1, 128, 256, {3,
3,
3}, {14,
14,
14}, {1,
1,
1}, {1,
1,
1}, {1,
1,
1}, { 1,
1,
1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
View file @
aa61ccf0
...
@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp
0 → 100644
View file @
aa61ccf0
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
using
InDataType
=
F16
;
using
WeiDataType
=
F16
;
using
OutDataType
=
F16
;
using
AccDataType
=
F32
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
PassThrough
;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Wmma_CShuffle
<
NDimSpatial
,
ck
::
tensor_layout
::
convolution
::
GNDHWC
,
ck
::
tensor_layout
::
convolution
::
GKZYXC
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
,
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
16
,
// MPerWMMA
16
,
// NPerWMMA
4
,
// MRepeat
2
,
// NRepeat
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
0
,
2
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
true
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
0
,
2
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
true
,
// BBlockLdsExtraN
4
,
2
,
S
<
1
,
32
,
1
,
8
>
,
1
>
;
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
aa61ccf0
...
@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
aa61ccf0
...
@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
View file @
aa61ccf0
...
@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
...
@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
#include "run_grouped_conv_bwd_weight_example.inc"
#include "run_grouped_conv_bwd_weight_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_bwd_weight_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
1
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
!
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
!
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
!
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
default:
break
;
}
return
1
;
}
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
aa61ccf0
...
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
...
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
bool
run_grouped_conv_bwd_weight
(
const
ExecutionConfig
&
config
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
)
{
{
// Dl op do
es
n't support split_k > 1
// Dl
and WMMA
op
s
don't support split_k > 1
constexpr
ck
::
index_t
split_k
=
1
;
constexpr
ck
::
index_t
split_k
=
1
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
...
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
return
true
;
return
true
;
}
}
bool
run_grouped_conv_bwd_weight_example
(
int
argc
,
char
*
argv
[])
{
ExecutionConfig
config
;
ck
::
utils
::
conv
::
ConvParam
conv_param
=
DefaultConvParam
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_param
))
{
return
false
;
}
switch
(
conv_param
.
num_dim_spatial_
)
{
case
1
:
return
run_grouped_conv_bwd_weight
<
1
>
(
config
,
conv_param
);
case
2
:
return
run_grouped_conv_bwd_weight
<
2
>
(
config
,
conv_param
);
case
3
:
return
run_grouped_conv_bwd_weight
<
3
>
(
config
,
conv_param
);
}
return
false
;
}
Prev
1
2
3
4
5
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment