Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9f8ab221
Unverified
Commit
9f8ab221
authored
Oct 19, 2023
by
zjing14
Committed by
GitHub
Oct 19, 2023
Browse files
Merge branch 'develop' into add_int8_wmma_example_instance
parents
755ace59
b4fc4d0b
Changes
490
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
201 additions
and
208 deletions
+201
-208
client_example/22_im2col_col2im/image_to_column.cpp
client_example/22_im2col_col2im/image_to_column.cpp
+16
-10
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+2
-2
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+43
-58
example/01_gemm/gemm_xdl_fp16_fp8.cpp
example/01_gemm/gemm_xdl_fp16_fp8.cpp
+0
-0
example/01_gemm/gemm_xdl_fp8.cpp
example/01_gemm/gemm_xdl_fp8.cpp
+3
-3
example/01_gemm/gemm_xdl_fp8_bf8.cpp
example/01_gemm/gemm_xdl_fp8_bf8.cpp
+49
-0
example/02_gemm_bilinear/CMakeLists.txt
example/02_gemm_bilinear/CMakeLists.txt
+0
-2
example/03_gemm_bias_relu/CMakeLists.txt
example/03_gemm_bias_relu/CMakeLists.txt
+0
-2
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+20
-24
example/09_convnd_fwd/CMakeLists.txt
example/09_convnd_fwd/CMakeLists.txt
+4
-22
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
...e/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
+22
-25
example/12_reduce/README.md
example/12_reduce/README.md
+2
-2
example/13_pool2d_fwd/CMakeLists.txt
example/13_pool2d_fwd/CMakeLists.txt
+2
-6
example/14_gemm_quantization/CMakeLists.txt
example/14_gemm_quantization/CMakeLists.txt
+1
-6
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+27
-31
example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp
+0
-0
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+3
-5
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
+3
-5
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
+3
-5
No files found.
client_example/2
0
_im
age_to
_col
umn
/image_to_column.cpp
→
client_example/2
2
_im
2col
_col
2im
/image_to_column.cpp
View file @
9f8ab221
...
@@ -9,13 +9,14 @@
...
@@ -9,13 +9,14 @@
#include <vector>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/image_to_column.hpp"
#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
I
n
Layout
=
ck
::
tensor_layout
::
convolution
::
GNHWC
;
using
I
mage
Layout
=
ck
::
tensor_layout
::
convolution
::
GNHWC
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
2
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
2
;
static
constexpr
ck
::
index_t
G
=
1
;
static
constexpr
ck
::
index_t
G
=
1
;
...
@@ -54,8 +55,8 @@ int main()
...
@@ -54,8 +55,8 @@ int main()
// We have NHWGC in memory space (G is dummy)
// We have NHWGC in memory space (G is dummy)
// However, CK's API only accept length and stride with order of GNCHW
// However, CK's API only accept length and stride with order of GNCHW
// Hence, we need to adjust the order of stride
// Hence, we need to adjust the order of stride
std
::
array
<
ck
::
index_t
,
5
>
i
n
_strides
{
C
,
Hi
*
Wi
*
G
*
C
,
1
,
Wi
*
G
*
C
,
G
*
C
};
std
::
array
<
ck
::
index_t
,
5
>
i
mage
_strides
{
C
,
Hi
*
Wi
*
G
*
C
,
1
,
Wi
*
G
*
C
,
G
*
C
};
std
::
array
<
ck
::
index_t
,
2
>
out
_strides
{
Y
*
X
*
C
,
1
};
std
::
array
<
ck
::
index_t
,
2
>
gemm
_strides
{
Y
*
X
*
C
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_strides
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_strides
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_dilations
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_dilations
{
1
,
1
};
...
@@ -65,8 +66,13 @@ int main()
...
@@ -65,8 +66,13 @@ int main()
SimpleDeviceMem
in
(
sizeof
(
InDataType
)
*
N
*
Hi
*
Wi
*
G
*
C
);
SimpleDeviceMem
in
(
sizeof
(
InDataType
)
*
N
*
Hi
*
Wi
*
G
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
Y
*
X
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
Y
*
X
*
C
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
using
namespace
ck
::
conv_tensor_rearrange_op
;
DeviceImageToColumn
<
NumDimSpatial
,
InLayout
,
InDataType
,
OutDataType
>
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceConvTensorRearrange
<
NumDimSpatial
,
ImageLayout
,
InDataType
,
OutDataType
,
ImageToColumn
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
@@ -92,8 +98,8 @@ int main()
...
@@ -92,8 +98,8 @@ int main()
in_spatial_lengths
,
in_spatial_lengths
,
out_spatial_lengths
,
out_spatial_lengths
,
wei_spatial_lengths
,
wei_spatial_lengths
,
i
n
_strides
,
i
mage
_strides
,
out
_strides
,
gemm
_strides
,
filter_strides
,
filter_strides
,
filter_dilations
,
filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -148,8 +154,8 @@ int main()
...
@@ -148,8 +154,8 @@ int main()
in_spatial_lengths
,
in_spatial_lengths
,
out_spatial_lengths
,
out_spatial_lengths
,
wei_spatial_lengths
,
wei_spatial_lengths
,
i
n
_strides
,
i
mage
_strides
,
out
_strides
,
gemm
_strides
,
filter_strides
,
filter_strides
,
filter_dilations
,
filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
cmake/EnableCompilerWarnings.cmake
View file @
9f8ab221
...
@@ -70,6 +70,7 @@ else()
...
@@ -70,6 +70,7 @@ else()
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
-Wno-unused-template
)
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"Clang"
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"Clang"
)
list
(
APPEND CMAKE_COMPILER_WARNINGS
list
(
APPEND CMAKE_COMPILER_WARNINGS
...
...
docs/sphinx/requirements.txt
View file @
9f8ab221
...
@@ -42,7 +42,7 @@ fastjsonschema==2.18.0
...
@@ -42,7 +42,7 @@ fastjsonschema==2.18.0
# via rocm-docs-core
# via rocm-docs-core
gitdb==4.0.10
gitdb==4.0.10
# via gitpython
# via gitpython
gitpython==3.1.3
1
gitpython==3.1.3
5
# via rocm-docs-core
# via rocm-docs-core
idna==3.4
idna==3.4
# via requests
# via requests
...
@@ -103,7 +103,7 @@ requests==2.28.2
...
@@ -103,7 +103,7 @@ requests==2.28.2
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core
>
=0.2
0
.0
rocm-docs-core
=
=0.2
4
.0
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via
# via
...
...
example/01_gemm/CMakeLists.txt
View file @
9f8ab221
if
(
DL_KERNELS
)
add_custom_target
(
example_gemm_dl
)
add_custom_target
(
example_gemm_dl
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_example_executable
(
example_gemm_dl_fp32 gemm_dl_fp32.cpp
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp32
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_example_executable
(
example_gemm_dpp_fp16 gemm_dpp_fp16.cpp
)
add_example_executable
(
example_gemm_dpp_fp16 gemm_dpp_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int8
)
if
(
USE_BITINT_EXTENSION_INT4
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_example_executable
(
example_gemm_dl_int4 gemm_dl_int4.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_int4
)
add_example_dependencies
(
example_gemm_dl example_gemm_dl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
()
add_custom_target
(
example_gemm_xdl
)
add_custom_target
(
example_gemm_xdl
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_fp16 gemm_xdl_fp16.cpp
)
add_example_
executable
(
example_gemm_xdl
_fp16
gemm_xdl_fp16
.cpp
)
add_example_
dependencies
(
example_gemm_xdl
example_
gemm_xdl_fp16
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_
dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_
example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_
example_
dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_
dependencies
(
example_gemm_xdl example_
gemm_xdl_skip_b_lds_fp16
)
add_
example_executable
(
example_gemm_xdl_skip_b_lds_fp16
gemm_xdl_skip_b_lds_fp16
.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
add_custom_target
(
example_gemm_wmma
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
endif
()
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
add_
example_
dependencies
(
example_gemm_xdl example_gemm_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
add_example_executable_no_testing
(
example_gemm_xdl_fp64 gemm_xdl_fp64.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp64
)
endif
()
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
if
(
DTYPES MATCHES
"fp8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_fp8 gemm_xdl_fp8.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8
)
add_example_executable
(
example_gemm_xdl_f8 gemm_xdl_f8.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_f8
)
endif
()
endif
()
if
((
DTYPES MATCHES
"fp8"
AND DTYPES MATCHES
"fp16"
)
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_f8
)
endif
()
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
example/01_gemm/gemm_xdl_fp16_f8.cpp
→
example/01_gemm/gemm_xdl_fp16_f
p
8.cpp
View file @
9f8ab221
File moved
example/01_gemm/gemm_xdl_f8.cpp
→
example/01_gemm/gemm_xdl_f
p
8.cpp
View file @
9f8ab221
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
using
ADataType
=
ck
::
f8_t
;
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
f8_t
;
using
CDataType
=
ck
::
f
8
_t
;
using
CDataType
=
ck
::
hal
f_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
f8_
t
;
using
CShuffleDataType
=
floa
t
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Col
;
...
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
...
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
16
>
;
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/01_gemm/gemm_xdl_fp8_bf8.cpp
0 → 100644
View file @
9f8ab221
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
bf8_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
LoopSched
=
ck
::
make_default_loop_scheduler
();
static
constexpr
auto
PipelineVer
=
ck
::
PipelineVersion
::
v1
;
using
ComputeTypeA
=
ck
::
f8_t
;
using
ComputeTypeB
=
ck
::
bf8_t
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
LoopSched
,
PipelineVer
,
ComputeTypeA
,
ComputeTypeB
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
ComputeTypeA
,
ComputeTypeB
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/02_gemm_bilinear/CMakeLists.txt
View file @
9f8ab221
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list1 gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list1 gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
...
@@ -19,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -19,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
example/03_gemm_bias_relu/CMakeLists.txt
View file @
9f8ab221
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
9f8ab221
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_gemm_add_add_fastgelu_xdl
)
add_custom_target
(
example_gemm_add_add_fastgelu_xdl
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8
)
set
(
target 1
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
endforeach
()
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp
)
add_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/09_convnd_fwd/CMakeLists.txt
View file @
9f8ab221
...
@@ -2,34 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,34 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp
)
endif
()
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
add_example_executable_no_testing
(
example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
if
(
DL_KERNELS
)
add_example_executable
(
example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp
)
add_example_executable
(
example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp
)
endif
()
endif
()
example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt
View file @
9f8ab221
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_fwd_reduce_xdl
)
add_custom_target
(
example_convnd_fwd_reduce_xdl
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp
)
add_example_executable
(
example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16
)
endif
()
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16
)
add_example_executable_no_testing
(
example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
endif
()
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32
)
add_example_executable
(
example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp
)
endif
()
add_example_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4
)
if
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp
)
set
(
target 1
)
add_dependencies
(
example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4
)
endif
()
endif
(
USE_BITINT_EXTENSION_INT4
)
endforeach
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/12_reduce/README.md
View file @
9f8ab221
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
## Run ```example_reduce_blockwise```
## Run ```example_reduce_blockwise```
```
bash
```
bash
# -D <xxx> : input 3
d
/4
d
/5
d
tensor lengths
# -D <xxx> : input 3
D
/4
D
/5
D
tensor lengths
# -R <xxx> : reduce dimension ids
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
#arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)
...
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
...
@@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
## Run ```example_reduce_multiblock_atomic_add```
## Run ```example_reduce_multiblock_atomic_add```
```
bash
```
bash
# -D <xxx> : input 3
d
/4
d
/5
d
tensor lengths
# -D <xxx> : input 3
D
/4
D
/5
D
tensor lengths
# -R <xxx> : reduce dimension ids
# -R <xxx> : reduce dimension ids
# -v <x> : verification (0=no, 1=yes)
# -v <x> : verification (0=no, 1=yes)
#arg1: data type (0: fp32, 1: fp64)
#arg1: data type (0: fp32, 1: fp64)
...
...
example/13_pool2d_fwd/CMakeLists.txt
View file @
9f8ab221
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp
)
add_example_executable
(
example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp
)
add_example_executable
(
example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp
)
endif
()
example/14_gemm_quantization/CMakeLists.txt
View file @
9f8ab221
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
# dlops
# dlops
if
(
DL_KERNELS
)
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
add_example_executable
(
example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp
)
endif
()
# xdlops
# xdlops
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
...
@@ -14,4 +10,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -14,4 +10,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
\ No newline at end of file
example/15_grouped_gemm/CMakeLists.txt
View file @
9f8ab221
add_custom_target
(
example_grouped_gemm_xdl
)
add_custom_target
(
example_grouped_gemm_xdl
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32
)
endif
()
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16
)
example_grouped_gemm_xdl_fp16
example_grouped_gemm_multiple_d_dl_fp16
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
)
example_grouped_gemm_xdl_splitk_fp16
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16
)
example_grouped_gemm_xdl_fixed_nk_fp16
example_grouped_gemm_xdl_fixed_nk_bias_fp16
)
add_example_executable
(
example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp
)
endif
()
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int8
)
endif
()
if
(
DTYPES MATCHES
"f8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
add_
example_
dependencies
(
example_grouped_gemm_xdl example_grouped_gemm_xdl_int4
)
endif
()
endif
()
example/15_grouped_gemm/grouped_gemm_xdl_bf
p
16.cpp
→
example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp
View file @
9f8ab221
File moved
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
9f8ab221
...
@@ -296,13 +296,11 @@ int main(int argc, char* argv[])
...
@@ -296,13 +296,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
View file @
9f8ab221
...
@@ -297,13 +297,11 @@ int main(int argc, char* argv[])
...
@@ -297,13 +297,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
View file @
9f8ab221
...
@@ -66,13 +66,11 @@ int main(int argc, char* argv[])
...
@@ -66,13 +66,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
Prev
1
2
3
4
5
6
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment