Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
14822d71
Commit
14822d71
authored
Aug 31, 2023
by
Jing Zhang
Browse files
merge
parents
5b02dfaf
80560ef2
Changes
190
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1314 additions
and
67 deletions
+1314
-67
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
CMakeLists.txt
CMakeLists.txt
+70
-43
Dockerfile
Dockerfile
+1
-1
Jenkinsfile
Jenkinsfile
+1
-1
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp
+6
-0
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
+6
-0
client_example/20_splitk_gemm/CMakeLists.txt
client_example/20_splitk_gemm/CMakeLists.txt
+2
-0
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
+225
-0
client_example/21_grouped_gemm_bias/CMakeLists.txt
client_example/21_grouped_gemm_bias/CMakeLists.txt
+2
-0
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
.../21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
+2
-2
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_fp16.cpp
...ample/21_grouped_gemm_bias/grouped_gemm_fixed_nk_fp16.cpp
+0
-0
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_fp8.cpp
...xample/21_grouped_gemm_bias/grouped_gemm_fixed_nk_fp8.cpp
+238
-0
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_i8.cpp
...example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_i8.cpp
+238
-0
client_example/22_grouped_gemm/CMakeLists.txt
client_example/22_grouped_gemm/CMakeLists.txt
+0
-3
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bias_fp16.cpp
...ample/22_grouped_gemm/grouped_gemm_fixed_nk_bias_fp16.cpp
+244
-0
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
...nt_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
+238
-0
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
+0
-0
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
+0
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+37
-16
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+3
-0
No files found.
.pre-commit-config.yaml
View file @
14822d71
...
@@ -3,7 +3,7 @@ repos:
...
@@ -3,7 +3,7 @@ repos:
hooks
:
hooks
:
-
id
:
clang-format
-
id
:
clang-format
name
:
clang-format
name
:
clang-format
entry
:
clang-format-1
0
-i --style=file
entry
:
clang-format-1
2
-i --style=file
language
:
system
language
:
system
types_or
:
[
c++
,
inc
]
types_or
:
[
c++
,
inc
]
-
id
:
copyright-year-checker
-
id
:
copyright-year-checker
...
...
CMakeLists.txt
View file @
14822d71
cmake_minimum_required
(
VERSION 3.14
)
cmake_minimum_required
(
VERSION 3.14
)
set
(
version 1.1.0
)
# Check support for CUDA/HIP in Cmake
# Check support for CUDA/HIP in Cmake
project
(
composable_kernel
)
project
(
composable_kernel
VERSION
${
version
}
)
list
(
APPEND CMAKE_MODULE_PATH
"
${
PROJECT_SOURCE_DIR
}
/cmake"
)
list
(
APPEND CMAKE_MODULE_PATH
"
${
PROJECT_SOURCE_DIR
}
/cmake"
)
if
(
DTYPES
)
if
(
DTYPES
)
add_definitions
(
-DDTYPES
)
add_definitions
(
-DDTYPES
)
if
(
DTYPES MATCHES
"int8"
)
if
(
DTYPES MATCHES
"int8"
)
add_definitions
(
-D__int8__
)
add_definitions
(
-DCK_ENABLE_INT8
)
endif
()
set
(
CK_ENABLE_INT8
"ON"
)
if
(
DTYPES MATCHES
"fp8"
)
endif
()
add_definitions
(
-D__fp8__
)
if
(
DTYPES MATCHES
"fp8"
)
endif
()
add_definitions
(
-DCK_ENABLE_FP8
)
if
(
DTYPES MATCHES
"fp16"
)
set
(
CK_ENABLE_FP8
"ON"
)
add_definitions
(
-D__fp16__
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
)
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-DCK_ENABLE_FP16
)
add_definitions
(
-D__fp32__
)
set
(
CK_ENABLE_FP16
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp64"
)
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-D__fp64__
)
add_definitions
(
-DCK_ENABLE_FP32
)
endif
()
set
(
CK_ENABLE_FP32
"ON"
)
if
(
DTYPES MATCHES
"bf16"
)
endif
()
add_definitions
(
-D__bf16__
)
if
(
DTYPES MATCHES
"fp64"
)
endif
()
add_definitions
(
-DCK_ENABLE_FP64
)
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
set
(
CK_ENABLE_FP64
"ON"
)
endif
()
if
(
DTYPES MATCHES
"bf16"
)
add_definitions
(
-DCK_ENABLE_BF16
)
set
(
CK_ENABLE_BF16
"ON"
)
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
else
()
else
()
add_definitions
(
-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__
)
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
endif
()
endif
()
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
set
(
CK_ENABLE_DL_KERNELS
"ON"
)
endif
()
endif
()
if
(
INSTANCES_ONLY
)
if
(
INSTANCES_ONLY
)
add_definitions
(
-DINSTANCES_ONLY
)
add_definitions
(
-DINSTANCES_ONLY
)
set
(
CK_ENABLE_INSTANCES_ONLY
"ON"
)
endif
()
endif
()
# CK config file to record supported datatypes, etc.
configure_file
(
"
${
PROJECT_SOURCE_DIR
}
/include/ck/config.h.in"
"
${
PROJECT_BINARY_DIR
}
/include/ck/config.h"
)
# CK version file to record release version as well as git commit hash
find_package
(
Git REQUIRED
)
execute_process
(
COMMAND
"
${
GIT_EXECUTABLE
}
"
rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE
)
configure_file
(
"
${
PROJECT_SOURCE_DIR
}
/include/ck/version.h.in"
"
${
PROJECT_BINARY_DIR
}
/include/ck/version.h"
)
enable_testing
()
enable_testing
()
set
(
ROCM_SYMLINK_LIBS OFF
)
set
(
ROCM_SYMLINK_LIBS OFF
)
...
@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks)
...
@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks)
include
(
ROCMCreatePackage
)
include
(
ROCMCreatePackage
)
include
(
CheckCXXCompilerFlag
)
include
(
CheckCXXCompilerFlag
)
include
(
ROCMCheckTargetIds
)
include
(
ROCMCheckTargetIds
)
rocm_setup_version
(
VERSION 0.2.0
)
include
(
TargetFlags
)
include
(
TargetFlags
)
rocm_setup_version
(
VERSION
${
version
}
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
message
(
"GPU_TARGETS=
${
GPU_TARGETS
}
"
)
message
(
"GPU_TARGETS=
${
GPU_TARGETS
}
"
)
...
@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
...
@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set
(
CMAKE_ARCHIVE_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
set
(
CMAKE_ARCHIVE_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/lib
)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/bin
)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
/bin
)
# set CK project include directories
include_directories
(
BEFORE
include_directories
(
BEFORE
${
PROJECT_BINARY_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
HIP_INCLUDE_DIRS
}
${
HIP_INCLUDE_DIRS
}
)
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
add_compile_options
(
-Werror
)
...
@@ -341,35 +362,35 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
...
@@ -341,35 +362,35 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
file
(
READ
"
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/
${
subdir_path
}
/CMakeLists.txt"
cmake_instance
)
file
(
READ
"
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu/
${
subdir_path
}
/CMakeLists.txt"
cmake_instance
)
set
(
add_inst 0
)
set
(
add_inst 0
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp8
\"
"
AND DTYPES MATCHES
"fp8"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp8
\"
"
AND DTYPES MATCHES
"fp8"
)
#message("fp8 instance found!")
#message("fp8 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp16
\"
"
AND DTYPES MATCHES
"fp16"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp16
\"
"
AND DTYPES MATCHES
"fp16"
)
#message("fp16 instance found!")
#message("fp16 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp32
\"
"
AND DTYPES MATCHES
"fp32"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp32
\"
"
AND DTYPES MATCHES
"fp32"
)
#message("fp32 instance found!")
#message("fp32 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp64
\"
"
AND DTYPES MATCHES
"fp64"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
fp64
\"
"
AND DTYPES MATCHES
"fp64"
)
#message("fp64 instance found!")
#message("fp64 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
bf16
\"
"
AND DTYPES MATCHES
"bf16"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
bf16
\"
"
AND DTYPES MATCHES
"bf16"
)
#message("bf16 instance found!")
#message("bf16 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
int8
\"
"
AND DTYPES MATCHES
"int8"
)
if
(
"
${
cmake_instance
}
"
MATCHES
"DTYPES MATCHES
\"
int8
\"
"
AND DTYPES MATCHES
"int8"
)
#message("int8 instance found!")
#message("int8 instance found!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
NOT
"
${
cmake_instance
}
"
MATCHES
"DTYPES"
)
if
(
NOT
"
${
cmake_instance
}
"
MATCHES
"DTYPES"
)
#message("instance should be built for all types!")
#message("instance should be built for all types!")
set
(
add_inst 1
)
set
(
add_inst 1
)
endif
()
endif
()
if
(
add_inst EQUAL 1 OR NOT DEFINED DTYPES
)
if
(
add_inst EQUAL 1 OR NOT DEFINED DTYPES
)
list
(
APPEND CK_DEVICE_INSTANCES device_
${
subdir_path
}
_instance
)
list
(
APPEND CK_DEVICE_INSTANCES device_
${
subdir_path
}
_instance
)
endif
()
endif
()
ENDIF
()
ENDIF
()
ENDFOREACH
()
ENDFOREACH
()
...
@@ -409,7 +430,6 @@ endif()
...
@@ -409,7 +430,6 @@ endif()
#Create an interface target for the include only files and call it "composablekernels"
#Create an interface target for the include only files and call it "composablekernels"
include
(
CMakePackageConfigHelpers
)
include
(
CMakePackageConfigHelpers
)
set
(
version 1.0.0
)
write_basic_package_version_file
(
write_basic_package_version_file
(
"
${
CMAKE_CURRENT_BINARY_DIR
}
/composable_kernelConfigVersion.cmake"
"
${
CMAKE_CURRENT_BINARY_DIR
}
/composable_kernelConfigVersion.cmake"
VERSION
"
${
version
}
"
VERSION
"
${
version
}
"
...
@@ -417,9 +437,9 @@ write_basic_package_version_file(
...
@@ -417,9 +437,9 @@ write_basic_package_version_file(
)
)
configure_package_config_file
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/Config.cmake.in
configure_package_config_file
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/Config.cmake.in
"
${
CMAKE_CURRENT_BINARY_DIR
}
/composable_kernelConfig.cmake"
"
${
CMAKE_CURRENT_BINARY_DIR
}
/composable_kernelConfig.cmake"
INSTALL_DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
INSTALL_DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
NO_CHECK_REQUIRED_COMPONENTS_MACRO
NO_CHECK_REQUIRED_COMPONENTS_MACRO
)
)
rocm_install
(
FILES
rocm_install
(
FILES
...
@@ -428,6 +448,13 @@ rocm_install(FILES
...
@@ -428,6 +448,13 @@ rocm_install(FILES
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
)
# Install CK version and configuration files
install
(
FILES
${
PROJECT_BINARY_DIR
}
/include/ck/version.h
${
PROJECT_BINARY_DIR
}
/include/ck/config.h
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/
)
set
(
CPACK_RESOURCE_FILE_LICENSE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/LICENSE"
)
set
(
CPACK_RESOURCE_FILE_LICENSE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/LICENSE"
)
set
(
CPACK_RPM_PACKAGE_LICENSE
"MIT"
)
set
(
CPACK_RPM_PACKAGE_LICENSE
"MIT"
)
...
...
Dockerfile
View file @
14822d71
...
@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
...
@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano
\
nano
\
zlib1g-dev
\
zlib1g-dev
\
openssh-server
\
openssh-server
\
clang-format-1
0
\
clang-format-1
2
\
kmod
&&
\
kmod
&&
\
apt-get clean
&&
\
apt-get clean
&&
\
rm
-rf
/var/lib/apt/lists/
*
rm
-rf
/var/lib/apt/lists/
*
...
...
Jenkinsfile
View file @
14822d71
...
@@ -689,7 +689,7 @@ pipeline {
...
@@ -689,7 +689,7 @@ pipeline {
-o -iname \'*.cpp.in\' \
-o -iname \'*.cpp.in\' \
-o -iname \'*.cl\' \
-o -iname \'*.cl\' \
| grep -v 'build/' \
| grep -v 'build/' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-1
0
-style=file {} | diff - {}\'"
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-1
2
-style=file {} | diff - {}\'"
}
}
steps
{
steps
{
buildHipClangJobAndReboot
(
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
buildHipClangJobAndReboot
(
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
...
...
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp
View file @
14822d71
...
@@ -191,6 +191,12 @@ int main(int argc, char* argv[])
...
@@ -191,6 +191,12 @@ int main(int argc, char* argv[])
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
SimpleDeviceMem
workspace
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
}
...
...
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
View file @
14822d71
...
@@ -187,6 +187,12 @@ int main(int argc, char* argv[])
...
@@ -187,6 +187,12 @@ int main(int argc, char* argv[])
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
SimpleDeviceMem
workspace
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
}
...
...
client_example/20_splitk_gemm/CMakeLists.txt
0 → 100644
View file @
14822d71
add_executable
(
client_splitK_gemm splitK_gemm_fp16_f8.cpp
)
target_link_libraries
(
client_splitK_gemm PRIVATE composable_kernel::device_operations
)
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
0 → 100644
View file @
14822d71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
ADataType
=
F8
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
KBatch
=
1
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
8
)
{
M
=
std
::
stoi
(
argv
[
1
]);
N
=
std
::
stoi
(
argv
[
2
]);
K
=
std
::
stoi
(
argv
[
3
]);
StrideA
=
std
::
stoi
(
argv
[
4
]);
StrideB
=
std
::
stoi
(
argv
[
5
]);
StrideC
=
std
::
stoi
(
argv
[
6
]);
KBatch
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch
\n
"
);
exit
(
0
);
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
M
,
K
,
StrideA
,
ALayout
{}));
SimpleDeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
K
,
N
,
StrideB
,
BLayout
{}));
SimpleDeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
f_matrix_space_size
(
M
,
N
,
StrideC
,
CLayout
{}));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
KBatch
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
KBatch
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
client_example/21_grouped_gemm_bias/CMakeLists.txt
0 → 100644
View file @
14822d71
add_executable
(
client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations
)
client_example/2
0
_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
→
client_example/2
1
_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
View file @
14822d71
...
@@ -20,7 +20,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
...
@@ -20,7 +20,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
Bias
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
...
@@ -36,7 +36,7 @@ using ELayout = Row;
...
@@ -36,7 +36,7 @@ using ELayout = Row;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
Bias
;
using
CDEElementOp
=
Add
;
struct
SimpleDeviceMem
struct
SimpleDeviceMem
{
{
...
...
client_example/2
0
_grouped_gemm_bias/grouped_gemm_fixed_nk_fp16.cpp
→
client_example/2
1
_grouped_gemm_bias/grouped_gemm_fixed_nk_fp16.cpp
View file @
14822d71
File moved
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_fp8.cpp
0 → 100644
View file @
14822d71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F8
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
{
std
::
vector
<
int
>
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideEs
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
Ns
.
push_back
(
768
);
Ks
.
push_back
(
4608
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideEs
.
push_back
(
std
::
is_same
<
Row
,
ELayout
>::
value
?
Ns
[
i
]
:
Ms
[
i
]);
sum_of_m
+=
Ms
[
i
];
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
std
::
vector
<
SimpleDeviceMem
>
a_dev_bufs
,
b_dev_bufs
,
e_dev_bufs
;
a_dev_bufs
.
reserve
(
group_count
);
b_dev_bufs
.
reserve
(
group_count
);
e_dev_bufs
.
reserve
(
group_count
);
std
::
vector
<
void
*>
p_e
;
p_e
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
gemm_descs
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
a_dev_bufs
.
emplace_back
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{}));
b_dev_bufs
.
emplace_back
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{}));
e_dev_bufs
.
emplace_back
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
StrideEs
[
i
],
ELayout
{}));
gemm_descs
.
push_back
({
sum_of_m
,
Ns
[
i
],
Ks
[
i
],
1
,
StrideBs
[
i
],
1
,
{
0
}});
p_e
.
push_back
(
e_dev_bufs
[
i
].
GetDeviceBuffer
());
grouped_gemm_kernel_args_
.
push_back
({
a_dev_bufs
[
i
].
GetDeviceBuffer
(),
b_dev_bufs
[
i
].
GetDeviceBuffer
(),
{},
e_dev_bufs
[
i
].
GetDeviceBuffer
(),
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
{},
StrideEs
[
i
]});
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
vector
<
const
void
*>
p_a
=
{},
p_b
=
{};
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_ds
=
{};
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_ds
,
p_e
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
grouped_gemm_kernel_args_dev
(
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
grouped_gemm_workspace_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipGetErrorString
(
hipMemcpy
(
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
));
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
grouped_gemm_workspace_dev
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
op_ptr
->
SetKBatch
(
argument_ptr
.
get
(),
16
);
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
gemm_descs
.
size
();
++
j
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
j
]
*
Ns
[
j
]
*
Ks
[
j
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
j
]
*
Ks
[
j
]
+
sizeof
(
BDataType
)
*
Ks
[
j
]
*
Ns
[
j
]
+
sizeof
(
EDataType
)
*
Ms
[
j
]
*
Ns
[
j
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
0
;
}
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_i8.cpp
0 → 100644
View file @
14822d71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using
I8
=
int8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
I8
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
{
std
::
vector
<
int
>
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideEs
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
Ns
.
push_back
(
768
);
Ks
.
push_back
(
4608
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideEs
.
push_back
(
std
::
is_same
<
Row
,
ELayout
>::
value
?
Ns
[
i
]
:
Ms
[
i
]);
sum_of_m
+=
Ms
[
i
];
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
std
::
vector
<
SimpleDeviceMem
>
a_dev_bufs
,
b_dev_bufs
,
e_dev_bufs
;
a_dev_bufs
.
reserve
(
group_count
);
b_dev_bufs
.
reserve
(
group_count
);
e_dev_bufs
.
reserve
(
group_count
);
std
::
vector
<
void
*>
p_e
;
p_e
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
gemm_descs
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
a_dev_bufs
.
emplace_back
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{}));
b_dev_bufs
.
emplace_back
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{}));
e_dev_bufs
.
emplace_back
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
StrideEs
[
i
],
ELayout
{}));
gemm_descs
.
push_back
({
sum_of_m
,
Ns
[
i
],
Ks
[
i
],
1
,
StrideBs
[
i
],
1
,
{
0
}});
p_e
.
push_back
(
e_dev_bufs
[
i
].
GetDeviceBuffer
());
grouped_gemm_kernel_args_
.
push_back
({
a_dev_bufs
[
i
].
GetDeviceBuffer
(),
b_dev_bufs
[
i
].
GetDeviceBuffer
(),
{},
e_dev_bufs
[
i
].
GetDeviceBuffer
(),
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
{},
StrideEs
[
i
]});
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
vector
<
const
void
*>
p_a
=
{},
p_b
=
{};
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_ds
=
{};
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_ds
,
p_e
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
grouped_gemm_kernel_args_dev
(
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
grouped_gemm_workspace_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipGetErrorString
(
hipMemcpy
(
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
));
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
grouped_gemm_workspace_dev
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
op_ptr
->
SetKBatch
(
argument_ptr
.
get
(),
32
);
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
gemm_descs
.
size
();
++
j
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
j
]
*
Ns
[
j
]
*
Ks
[
j
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
j
]
*
Ks
[
j
]
+
sizeof
(
BDataType
)
*
Ks
[
j
]
*
Ns
[
j
]
+
sizeof
(
EDataType
)
*
Ms
[
j
]
*
Ns
[
j
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
0
;
}
client_example/2
0
_grouped_gemm
_bias
/CMakeLists.txt
→
client_example/2
2
_grouped_gemm/CMakeLists.txt
View file @
14822d71
add_executable
(
client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations
)
add_executable
(
client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp
)
add_executable
(
client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations
)
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bias_fp16.cpp
0 → 100644
View file @
14822d71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
D0DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
>
;
using
EDataType
=
F32
;
using
ALayout
=
Row
;
using
BLayout
=
Row
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
{
std
::
vector
<
int
>
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideEs
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
Ns
.
push_back
(
768
);
Ks
.
push_back
(
4608
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideEs
.
push_back
(
std
::
is_same
<
Row
,
ELayout
>::
value
?
Ns
[
i
]
:
Ms
[
i
]);
sum_of_m
+=
Ms
[
i
];
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
std
::
vector
<
SimpleDeviceMem
>
a_dev_bufs
,
b_dev_bufs
,
d0_dev_bufs
,
e_dev_bufs
;
a_dev_bufs
.
reserve
(
group_count
);
b_dev_bufs
.
reserve
(
group_count
);
d0_dev_bufs
.
reserve
(
group_count
);
e_dev_bufs
.
reserve
(
group_count
);
std
::
vector
<
void
*>
p_e
;
p_e
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
gemm_descs
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
a_dev_bufs
.
emplace_back
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{}));
b_dev_bufs
.
emplace_back
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{}));
d0_dev_bufs
.
emplace_back
(
sizeof
(
D0DataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
0
,
D0Layout
{}));
e_dev_bufs
.
emplace_back
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
StrideEs
[
i
],
ELayout
{}));
gemm_descs
.
push_back
({
sum_of_m
,
Ns
[
i
],
Ks
[
i
],
1
,
StrideBs
[
i
],
1
,
{
0
}});
p_e
.
push_back
(
e_dev_bufs
[
i
].
GetDeviceBuffer
());
grouped_gemm_kernel_args_
.
push_back
(
{
a_dev_bufs
[
i
].
GetDeviceBuffer
(),
b_dev_bufs
[
i
].
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d0_dev_bufs
[
i
].
GetDeviceBuffer
()},
e_dev_bufs
[
i
].
GetDeviceBuffer
(),
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
std
::
array
<
ck
::
index_t
,
1
>
{
0
},
StrideEs
[
i
]});
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
vector
<
const
void
*>
p_a
=
{},
p_b
=
{};
std
::
vector
<
std
::
array
<
const
void
*
,
1
>>
p_ds
=
{};
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_ds
,
p_e
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
grouped_gemm_kernel_args_dev
(
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
grouped_gemm_workspace_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipGetErrorString
(
hipMemcpy
(
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
));
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
grouped_gemm_workspace_dev
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
op_ptr
->
SetKBatch
(
argument_ptr
.
get
(),
2
);
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
gemm_descs
.
size
();
++
j
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
j
]
*
Ns
[
j
]
*
Ks
[
j
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
j
]
*
Ks
[
j
]
+
sizeof
(
BDataType
)
*
Ks
[
j
]
*
Ns
[
j
]
+
sizeof
(
EDataType
)
*
Ms
[
j
]
*
Ns
[
j
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
0
;
}
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
0 → 100644
View file @
14822d71
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
{
std
::
vector
<
int
>
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideEs
;
int
sum_of_m
=
0
;
// Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
Ms
=
{
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
};
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
Ns
.
push_back
(
768
);
Ks
.
push_back
(
4608
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideEs
.
push_back
(
std
::
is_same
<
Row
,
ELayout
>::
value
?
Ns
[
i
]
:
Ms
[
i
]);
sum_of_m
+=
Ms
[
i
];
}
auto
f_matrix_space_size
=
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
else
{
return
(
nCol
-
1
)
*
stride
+
nRow
;
}
};
std
::
vector
<
SimpleDeviceMem
>
a_dev_bufs
,
b_dev_bufs
,
e_dev_bufs
;
a_dev_bufs
.
reserve
(
group_count
);
b_dev_bufs
.
reserve
(
group_count
);
e_dev_bufs
.
reserve
(
group_count
);
std
::
vector
<
void
*>
p_e
;
p_e
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
gemm_descs
.
reserve
(
group_count
);
std
::
vector
<
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>>
grouped_gemm_kernel_args_
;
grouped_gemm_kernel_args_
.
reserve
(
group_count
);
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
a_dev_bufs
.
emplace_back
(
sizeof
(
ADataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ks
[
i
],
StrideAs
[
i
],
ALayout
{}));
b_dev_bufs
.
emplace_back
(
sizeof
(
BDataType
)
*
f_matrix_space_size
(
Ks
[
i
],
Ns
[
i
],
StrideBs
[
i
],
BLayout
{}));
e_dev_bufs
.
emplace_back
(
sizeof
(
EDataType
)
*
f_matrix_space_size
(
Ms
[
i
],
Ns
[
i
],
StrideEs
[
i
],
ELayout
{}));
gemm_descs
.
push_back
({
sum_of_m
,
Ns
[
i
],
Ks
[
i
],
1
,
StrideBs
[
i
],
1
,
{
0
}});
p_e
.
push_back
(
e_dev_bufs
[
i
].
GetDeviceBuffer
());
grouped_gemm_kernel_args_
.
push_back
({
a_dev_bufs
[
i
].
GetDeviceBuffer
(),
b_dev_bufs
[
i
].
GetDeviceBuffer
(),
{},
e_dev_bufs
[
i
].
GetDeviceBuffer
(),
Ms
[
i
],
Ns
[
i
],
Ks
[
i
],
StrideAs
[
i
],
StrideBs
[
i
],
{},
StrideEs
[
i
]});
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
cde_element_op
=
CDEElementOp
{};
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
vector
<
const
void
*>
p_a
=
{},
p_b
=
{};
std
::
vector
<
std
::
array
<
const
void
*
,
0
>>
p_ds
=
{};
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
p_a
,
p_b
,
p_ds
,
p_e
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
SimpleDeviceMem
grouped_gemm_kernel_args_dev
(
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()));
SimpleDeviceMem
grouped_gemm_workspace_dev
(
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipGetErrorString
(
hipMemcpy
(
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
op_ptr
->
GetDeviceKernelArgSize
(
argument_ptr
.
get
()),
hipMemcpyHostToDevice
));
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
grouped_gemm_workspace_dev
.
GetDeviceBuffer
());
op_ptr
->
SetDeviceKernelArgs
(
argument_ptr
.
get
(),
grouped_gemm_kernel_args_dev
.
GetDeviceBuffer
());
op_ptr
->
SetKBatch
(
argument_ptr
.
get
(),
32
);
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
std
::
size_t
j
=
0
;
j
<
gemm_descs
.
size
();
++
j
)
{
flop
+=
std
::
size_t
(
2
)
*
Ms
[
j
]
*
Ns
[
j
]
*
Ks
[
j
];
num_btype
+=
sizeof
(
ADataType
)
*
Ms
[
j
]
*
Ks
[
j
]
+
sizeof
(
BDataType
)
*
Ks
[
j
]
*
Ns
[
j
]
+
sizeof
(
EDataType
)
*
Ms
[
j
]
*
Ns
[
j
];
}
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
0
;
}
client_example/2
0
_grouped_gemm
_bias
/grouped_gemm_fixed_nk_fp8.cpp
→
client_example/2
2
_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
View file @
14822d71
File moved
client_example/2
0
_grouped_gemm
_bias
/grouped_gemm_fixed_nk_i8.cpp
→
client_example/2
2
_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
View file @
14822d71
File moved
client_example/CMakeLists.txt
View file @
14822d71
...
@@ -3,31 +3,52 @@ project(ck_app)
...
@@ -3,31 +3,52 @@ project(ck_app)
add_compile_options
(
-std=c++17
)
add_compile_options
(
-std=c++17
)
if
(
DTYPES
)
if
(
DTYPES
)
add_definitions
(
-DDTYPES
)
add_definitions
(
-DDTYPES
)
if
(
DTYPES MATCHES
"int8"
)
if
(
DTYPES MATCHES
"int8"
)
add_definitions
(
-D__int8__
)
add_definitions
(
-DCK_ENABLE_INT8
)
if
(
NOT DEFINED
${
CK_ENABLE_INT8
}
)
set
(
CK_ENABLE_INT8
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp8"
)
endif
()
add_definitions
(
-D__fp8__
)
if
(
DTYPES MATCHES
"fp8"
)
add_definitions
(
-DCK_ENABLE_FP8
)
if
(
NOT DEFINED
${
CK_ENABLE_FP8
}
)
set
(
CK_ENABLE_FP8
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
)
endif
()
add_definitions
(
-D__fp16__
)
if
(
DTYPES MATCHES
"fp16"
)
add_definitions
(
-DCK_ENABLE_FP16
)
if
(
NOT DEFINED
${
CK_ENABLE_FP16
}
)
set
(
CK_ENABLE_FP16
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp32"
)
endif
()
add_definitions
(
-D__fp32__
)
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-DCK_ENABLE_FP32
)
if
(
NOT DEFINED
${
CK_ENABLE_FP32
}
)
set
(
CK_ENABLE_FP32
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp64"
)
endif
()
add_definitions
(
-D__fp64__
)
if
(
DTYPES MATCHES
"fp64"
)
add_definitions
(
-DCK_ENABLE_FP64
)
if
(
NOT DEFINED
${
CK_ENABLE_FP64
}
)
set
(
CK_ENABLE_FP64
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
)
endif
()
add_definitions
(
-D__bf16__
)
if
(
DTYPES MATCHES
"bf16"
)
add_definitions
(
-DCK_ENABLE_BF16
)
if
(
NOT DEFINED
${
CK_ENABLE_BF16
}
)
set
(
CK_ENABLE_BF16
"ON"
)
endif
()
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
else
()
else
()
add_definitions
(
-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__
)
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
if
(
NOT DEFINED
${
CK_ENABLE_ALL_DTYPES
}
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
endif
()
endif
()
endif
()
find_package
(
composable_kernel
1.0.0
COMPONENTS device_operations
)
find_package
(
composable_kernel COMPONENTS device_operations
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
...
...
example/01_gemm/CMakeLists.txt
View file @
14822d71
...
@@ -40,6 +40,9 @@ endif()
...
@@ -40,6 +40,9 @@ endif()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_example_executable
(
example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_rtn
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
...
...
Prev
1
2
3
4
5
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment