Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ccaea50e
Commit
ccaea50e
authored
Mar 08, 2024
by
Jing Zhang
Browse files
merge navi31_rel
parents
0b914465
10127959
Changes
126
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
793 additions
and
13 deletions
+793
-13
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
+2
-2
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
.../21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
+2
-2
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp
...nt_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp
+1
-1
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
...nt_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
+2
-2
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
+2
-2
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
+2
-2
client_example/22_im2col_col2im/image_to_column.cpp
client_example/22_im2col_col2im/image_to_column.cpp
+1
-1
client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp
...ple/23_elementwise_transpose/elementwise_transpose_3d.cpp
+1
-1
cmake/Embed.cmake
cmake/Embed.cmake
+238
-0
codegen/CMakeLists.txt
codegen/CMakeLists.txt
+49
-0
codegen/driver/main.cpp
codegen/driver/main.cpp
+71
-0
codegen/include/ck/host/device_gemm_multiple_d.hpp
codegen/include/ck/host/device_gemm_multiple_d.hpp
+42
-0
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+42
-0
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
+39
-0
codegen/include/ck/host/headers.hpp
codegen/include/ck/host/headers.hpp
+18
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+49
-0
codegen/include/ck/host/stringutils.hpp
codegen/include/ck/host/stringutils.hpp
+104
-0
codegen/include/ck/host/types.hpp
codegen/include/ck/host/types.hpp
+78
-0
codegen/include/ck/host/utils.hpp
codegen/include/ck/host/utils.hpp
+17
-0
codegen/src/device_gemm_multiple_d.cpp
codegen/src/device_gemm_multiple_d.cpp
+33
-0
No files found.
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
...
...
@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
...
...
@@ -79,7 +79,7 @@ int main()
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp
View file @
ccaea50e
...
...
@@ -77,7 +77,7 @@ int main()
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
...
...
@@ -76,7 +76,7 @@ int main()
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
...
...
@@ -77,7 +77,7 @@ int main()
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
...
...
@@ -77,7 +77,7 @@ int main()
[](
std
::
size_t
nRow
,
std
::
size_t
nCol
,
std
::
size_t
stride
,
auto
layout
)
{
using
Layout
=
decltype
(
layout
);
if
constexpr
(
std
::
is_same
<
Layout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
std
::
is_same
<
Layout
,
Row
>::
value
)
{
return
(
nRow
-
1
)
*
stride
+
nCol
;
}
...
...
client_example/22_im2col_col2im/image_to_column.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
...
...
client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
...
...
cmake/Embed.cmake
0 → 100644
View file @
ccaea50e
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
# Select how data files are embedded. The choice is stored in the EMBED_USE
# cache entry so it can be overridden on the command line:
#   * Windows:            default RC,      alternatives "RC;CArrays"
#   * shared-library build: default LD,    alternatives "LD;CArrays"
#   * otherwise:          default CArrays, alternatives "LD;CArrays"
if(WIN32)
    set(EMBED_USE RC CACHE STRING "Use RC or CArrays to embed data files")
    set_property(CACHE EMBED_USE PROPERTY STRINGS "RC;CArrays")
else()
    if(BUILD_SHARED_LIBS)
        set(EMBED_USE LD CACHE STRING "Use LD or CArrays to embed data files")
    else()
        set(EMBED_USE CArrays CACHE STRING "Use LD or CArrays to embed data files")
    endif()
    set_property(CACHE EMBED_USE PROPERTY STRINGS "LD;CArrays")
endif()

# The LD mode needs both `ld` and `objcopy`. The REQUIRED keyword of
# find_program() only exists since CMake 3.18, so the check is spelled out
# explicitly to keep the failure loud on older CMake versions as well.
if(EMBED_USE STREQUAL "LD")
    find_program(EMBED_LD ld)
    if(NOT EMBED_LD)
        message(FATAL_ERROR "EMBED_USE=LD requires 'ld', which was not found")
    endif()
    find_program(EMBED_OBJCOPY objcopy)
    if(NOT EMBED_OBJCOPY)
        message(FATAL_ERROR "EMBED_USE=LD requires 'objcopy', which was not found")
    endif()
endif()
# embed_wrap_string(VARIABLE <var> AT_COLUMN <n>)
#
# Splits the string held by <var> into consecutive chunks of at most <n>
# characters and stores the result back into <var> in the caller's scope.
# Every chunk (including the first) is preceded by a newline. An empty input
# yields an empty result.
function(embed_wrap_string)
    set(options)
    set(oneValueArgs VARIABLE AT_COLUMN)
    set(multiValueArgs)
    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # A non-positive column width would never shrink the remaining length and
    # the loop below would not terminate.
    if(NOT PARSE_AT_COLUMN GREATER 0)
        message(FATAL_ERROR "embed_wrap_string: AT_COLUMN must be a positive integer")
    endif()

    # Function scope starts as a copy of the caller's variables; reset the
    # accumulator so a caller variable named `lines` cannot leak into the result.
    set(lines "")

    # Quoted so that an empty string (or one containing ';') is still passed
    # as exactly one argument.
    string(LENGTH "${${PARSE_VARIABLE}}" string_length)
    math(EXPR offset "0")

    while(string_length GREATER 0)
        # Take a full column, or whatever is left for the final chunk.
        if(string_length GREATER ${PARSE_AT_COLUMN})
            math(EXPR length "${PARSE_AT_COLUMN}")
        else()
            math(EXPR length "${string_length}")
        endif()

        string(SUBSTRING "${${PARSE_VARIABLE}}" ${offset} ${length} line)
        set(lines "${lines}\n${line}")

        math(EXPR string_length "${string_length} - ${length}")
        math(EXPR offset "${offset} + ${length}")
    endwhile()

    set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()
# generate_embed_source(<EMBED_NAME> <EMBED_DIR> <BASE_DIRECTORY>
#                       SYMBOLS <symbol>... FILES <file>...)
#
# Writes the C++ lookup source for an embed library into <EMBED_DIR>:
#   <EMBED_DIR>/include/<EMBED_NAME>.hpp  declares  <EMBED_NAME>()
#   <EMBED_DIR>/<EMBED_NAME>.cpp          defines it, returning a map from each
#                                         file's path relative to
#                                         <BASE_DIRECTORY> to its contents.
# In RC mode it additionally writes include/resource.h and resource.rc.
# SYMBOLS and FILES must have the same length; entry i of one belongs to
# entry i of the other. The list of generated files is returned to the caller
# in EMBED_FILES.
function(generate_embed_source EMBED_NAME EMBED_DIR BASE_DIRECTORY)
    set(options)
    set(oneValueArgs)
    set(multiValueArgs SYMBOLS FILES)
    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # Function scope starts as a copy of the caller's variables; reset every
    # accumulator so same-named caller variables cannot leak into the output.
    set(FILE_IDS "")
    set(RC_FILE_MAPPING "")
    set(INIT_KERNELS "")
    set(EXTERNS "")
    set(EMBED_FILES "")

    set(RESOURCE_ID 100)

    list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN)
    list(LENGTH PARSE_FILES FILES_LEN)
    if(NOT ${SYMBOLS_LEN} EQUAL ${FILES_LEN})
        message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${FILES_LEN}")
    endif()

    # foreach(RANGE <stop>) needs a non-negative stop, so skip the loop
    # entirely when there is nothing to embed.
    if(SYMBOLS_LEN GREATER 0)
        math(EXPR LEN "${SYMBOLS_LEN} - 1")
        foreach(idx RANGE ${LEN})
            list(GET PARSE_SYMBOLS ${idx} SYMBOL)
            list(GET PARSE_FILES ${idx} FILE)
            file(RELATIVE_PATH BASE_NAME "${BASE_DIRECTORY}" "${FILE}")
            if(EMBED_USE STREQUAL "RC")
                # One numeric resource id per file, starting at 100.
                string(TOUPPER "${SYMBOL}" SYMBOL)
                string(APPEND FILE_IDS "#define IDR_${SYMBOL} ${RESOURCE_ID}\n")
                file(TO_NATIVE_PATH "${FILE}" NATIVE_FILE)
                # Backslashes must be doubled inside the .rc string literal.
                string(REPLACE "\\" "\\\\" NATIVE_FILE "${NATIVE_FILE}")
                string(APPEND RC_FILE_MAPPING "IDR_${SYMBOL} TEXTFILE \"${NATIVE_FILE}\"\n")
                string(APPEND INIT_KERNELS "\n    {\"${BASE_NAME}\", resource::read(IDR_${SYMBOL})},")
                math(EXPR RESOURCE_ID "${RESOURCE_ID} + 1" OUTPUT_FORMAT DECIMAL)
            else()
                set(START_SYMBOL "_binary_${SYMBOL}_start")
                set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
                if(EMBED_USE STREQUAL "LD")
                    # The size is taken from the address of the *_size symbol.
                    string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t _binary_${SYMBOL}_size;
const auto ${LENGTH_SYMBOL} = reinterpret_cast<size_t>(&_binary_${SYMBOL}_size);
")
                else()
                    string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t ${LENGTH_SYMBOL};
")
                endif()
                string(APPEND INIT_KERNELS "
    { \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },")
            endif()
        endforeach()
    endif()

    if(EMBED_USE STREQUAL "RC")
        file(WRITE "${EMBED_DIR}/include/resource.h" "
#define TEXTFILE 256

${FILE_IDS}
")
        file(WRITE "${EMBED_DIR}/resource.rc" "
#include \"resource.h\"

${RC_FILE_MAPPING}
")
        # In RC mode the generated source reads each file through the Win32
        # resource API instead of referring to linker symbols.
        set(EXTERNS "
#include <Windows.h>
#include \"resource.h\"

namespace resource {
std::string_view read(int id)
{
    HMODULE handle = GetModuleHandle(nullptr);
    HRSRC rc = FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(TEXTFILE));
    HGLOBAL data = LoadResource(handle, rc);
    return {static_cast<const char*>(LockResource(data)), SizeofResource(handle, rc)};
}
}
")
        set(EMBED_FILES ${EMBED_DIR}/include/resource.h ${EMBED_DIR}/resource.rc)
    endif()

    file(WRITE "${EMBED_DIR}/include/${EMBED_NAME}.hpp" "
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}();
")

    file(WRITE "${EMBED_DIR}/${EMBED_NAME}.cpp" "
#include <${EMBED_NAME}.hpp>
${EXTERNS}
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
{
    static std::unordered_map<std::string_view, std::string_view> result = {${INIT_KERNELS}
    };
    return result;
}
")
    list(APPEND EMBED_FILES ${EMBED_DIR}/${EMBED_NAME}.cpp ${EMBED_DIR}/include/${EMBED_NAME}.hpp)
    set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE)
endfunction()
# embed_file(<FILE> <BASE_DIRECTORY>)
#
# Prepares one data file for embedding and reports back to the caller:
#   OUTPUT_SYMBOL - C identifier derived from FILE's path relative to
#                   BASE_DIRECTORY (always set)
#   OUTPUT_FILE   - generated object (.o, LD mode) or source (.cpp, CArrays
#                   mode); not set in any other mode
# LD mode registers a build-time custom command; CArrays mode generates the
# source at configure time and re-runs configure when FILE changes.
function(embed_file FILE BASE_DIRECTORY)
    message(STATUS "${FILE}")
    file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" "${FILE}")
    string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL)
    get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
    file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
    if(EMBED_USE STREQUAL "LD")
        set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
        # ld is run from BASE_DIRECTORY with the relative path as input, then
        # objcopy renames the resulting .data section to a read-only .rodata.
        add_custom_command(
            OUTPUT "${OUTPUT_FILE}"
            COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
            COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}"
            WORKING_DIRECTORY "${BASE_DIRECTORY}"
            DEPENDS "${FILE}"
            VERBATIM)
        set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
    elseif(EMBED_USE STREQUAL "CArrays")
        set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${FILE}")
        set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
        # reads source file contents as hex string
        file(READ "${FILE}" HEX_STRING HEX)
        string(LENGTH "${HEX_STRING}" HEX_LENGTH)
        if(HEX_LENGTH EQUAL 0)
            # An empty file cannot be expressed as `T x[] = {}`; emit a
            # one-byte placeholder array and report a length of zero.
            set(ARRAY_DEFINITION "[1] = { 0 }")
            set(LENGTH_VALUE "0")
        else()
            # wraps the hex string into multiple lines
            embed_wrap_string(VARIABLE HEX_STRING AT_COLUMN 80)
            # adds '0x' prefix and comma suffix before and after every byte respectively
            string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES "${HEX_STRING}")
            # removes trailing comma
            string(REGEX REPLACE ", $" "" ARRAY_VALUES "${ARRAY_VALUES}")
            set(ARRAY_DEFINITION "[] = {
${ARRAY_VALUES}
}")
            set(LENGTH_VALUE "sizeof(_binary_${OUTPUT_SYMBOL}_start)")
        endif()
        file(WRITE "${OUTPUT_FILE}" "
#include <cstddef>
extern const char _binary_${OUTPUT_SYMBOL}_start${ARRAY_DEFINITION};
extern const size_t _binary_${OUTPUT_SYMBOL}_length = ${LENGTH_VALUE};
")
        set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
    endif()
    set(OUTPUT_SYMBOL ${OUTPUT_SYMBOL} PARENT_SCOPE)
endfunction()
# add_embed_library(<EMBED_NAME> <file>... RELATIVE <base-directory>)
#
# Creates an INTERFACE target <EMBED_NAME> that exposes the given files as
# embedded data. Every file is processed by embed_file(), the lookup source is
# produced by generate_embed_source(), and the result is compiled into the
# helper target embed_lib_<EMBED_NAME>. Generated files live under
# ${CMAKE_CURRENT_BINARY_DIR}/embed/<EMBED_NAME>.
function(add_embed_library EMBED_NAME)
    set(options)
    set(oneValueArgs RELATIVE)
    set(multiValueArgs)
    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
    file(MAKE_DIRECTORY ${EMBED_DIR})
    message(STATUS "Embedding kernel files:")

    # Function scope starts as a copy of the caller's variables; reset the
    # accumulators so same-named caller variables cannot leak into the target.
    set(OUTPUT_FILES "")
    set(SYMBOLS "")
    foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
        # embed_file() leaves OUTPUT_FILE untouched in RC mode, so clear it
        # first to avoid appending a stale value.
        set(OUTPUT_FILE "")
        embed_file(${FILE} ${PARSE_RELATIVE})
        list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
        list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
    endforeach()

    message(STATUS "Generating embedding library '${EMBED_NAME}'")
    generate_embed_source(${EMBED_NAME} ${EMBED_DIR} "${PARSE_RELATIVE}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS})

    set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
    if(EMBED_USE STREQUAL "LD")
        # The ld-produced objects are archived together with the lookup source.
        add_library(${INTERNAL_EMBED_LIB} STATIC ${EMBED_FILES} ${OUTPUT_FILES})
    else()
        add_library(${INTERNAL_EMBED_LIB} OBJECT ${EMBED_FILES})
    endif()
    if(EMBED_USE STREQUAL "CArrays")
        target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
    endif()
    target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
    target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
    set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)

    # Public-facing target: consumers link against <EMBED_NAME> only.
    add_library(${EMBED_NAME} INTERFACE)
    if(EMBED_USE STREQUAL "RC")
        target_link_libraries(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
    elseif(EMBED_USE STREQUAL "LD")
        target_link_libraries(${EMBED_NAME} INTERFACE ${INTERNAL_EMBED_LIB})
    else()
        target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
    endif()
    target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction()
codegen/CMakeLists.txt
0 → 100644
View file @
ccaea50e
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Request C++17 through the standard variables instead of passing a raw
# -std flag via add_definitions(). They must be set before any target is
# created because they only initialise target properties at creation time.
# Extensions are disabled so the compiler still receives -std=c++17.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

# Root of the composable_kernel checkout (this directory's parent).
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)

find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)

list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)

# Embed every header below ${CK_ROOT}/include/ck; paths inside the embedded
# map are relative to ${CK_ROOT}/include.
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${CK_ROOT}/include/ck/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)

file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)

# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers)

set_target_properties(ck_host PROPERTIES
    LINKER_LANGUAGE CXX
    POSITION_INDEPENDENT_CODE ON)

target_include_directories(ck_host PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)

add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)

rocm_install(
    TARGETS ck_host ck_headers
    EXPORT ck_hostTargets)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

if(BUILD_TESTING)
    add_subdirectory(test)
endif()
codegen/driver/main.cpp
0 → 100644
View file @
ccaea50e
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp"
using
ck
::
host
::
Transform
;
/// Registry of named template emitters used by the driver.
struct Emitters
{
    /// Maps a template name to a callable that produces the rendered strings
    /// for that template.
    std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;

    /// Registers `name`. When invoked, the stored callable asks
    /// `T::CreateOperations()` for its groups of operations and renders each
    /// group with ToTuple(). An existing entry for `name` is overwritten.
    template <class T>
    void Register(const std::string& name)
    {
        auto emitter = [] {
            auto groups = T::CreateOperations();
            return Transform(groups, [](const auto& group) { return ToTuple(group); });
        };
        m[name] = emitter;
    }

    /// Renders a range of operations as a `std::tuple<...>` type string: one
    /// line per operation, each being `ToSolution().ToTemplateString()` with a
    /// leading space, joined by ",\n".
    template <class T>
    static std::string ToTuple(const T& ops)
    {
        const auto rendered = Transform(ops, [](const auto& op) {
            const std::string body = op.ToSolution().ToTemplateString();
            return " " + body;
        });
        std::string tuple = "std::tuple<\n";
        tuple += ck::host::JoinStrings(rendered, ",\n");
        tuple += ">";
        return tuple;
    }

    /// Runs the emitter registered under `name` and joins its output with
    /// newlines. Uses `at()`, so an unregistered name throws
    /// std::out_of_range.
    std::string Emit(const std::string& name)
    {
        const auto& emitter = m.at(name);
        return ck::host::JoinStrings(emitter(), "\n");
    }

    /// Returns the registered names in the map's iteration order.
    std::vector<std::string> List() const
    {
        std::vector<std::string> names;
        for(const auto& entry : m)
            names.push_back(entry.first);
        return names;
    }
};
/// Driver entry point.
///
/// With no arguments, or when any argument is `-h` / `--help`, prints usage
/// together with the list of registered templates and returns 0. Otherwise
/// every argument is treated as a template name and its rendered output is
/// printed to stdout. If any name is not registered, an error is written to
/// stderr and 1 is returned before anything is emitted.
int main(int argc, const char* argv[])
{
    // argv[0] is only dereferenced when the runtime actually provided it; the
    // fallback is the executable name used by the build.
    const std::string prog =
        (argc > 0 && argv[0] != nullptr) ? argv[0] : "ck-template-driver";
    std::vector<std::string> args;
    if(argc > 1)
        args.assign(argv + 1, argv + argc);

    Emitters e;
    e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
        "DeviceGemmMultipleD_Xdl_CShuffle");

    bool show_help = args.empty();
    for(const auto& arg : args)
    {
        if(arg == "-h" || arg == "--help")
        {
            show_help = true;
            break;
        }
    }

    if(show_help)
    {
        std::cout << "USAGE:" << std::endl;
        std::cout << " " << prog << " [TEMPLATE]" << std::endl;
        std::cout << std::endl;
        std::cout << "FLAGS:" << std::endl;
        std::cout << " -h, --help Show help" << std::endl;
        std::cout << std::endl;
        std::cout << "TEMPLATES:" << std::endl;
        for(const auto& x : e.List())
            std::cout << " " << x << std::endl;
        std::cout << std::endl;
        return 0;
    }

    // Validate every requested name up front: Emitters::Emit() would
    // otherwise throw std::out_of_range and terminate the program.
    for(const auto& name : args)
    {
        if(e.m.count(name) == 0)
        {
            std::cerr << prog << ": unknown template '" << name << "'" << std::endl;
            return 1;
        }
    }

    for(const auto& name : args)
        std::cout << e.Emit(name) << std::endl;

    return 0;
}
codegen/include/ck/host/device_gemm_multiple_d.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace ck::host::device_gemm_multiple_d {

/// Description of a GEMM-with-multiple-D problem for which solutions are
/// requested. Every field has a default, so a value-initialised Problem is
/// usable as a starting point.
///
/// NOTE(review): a struct with the same qualified name is also declared in
/// ck/host/device_gemm_multiple_d/problem.hpp, with a different default for
/// CDEElementOp. Including both headers in one translation unit would be a
/// redefinition — confirm which declaration is the intended one.
struct Problem
{
    // Problem sizes; all default to zero.
    std::size_t M{0};
    std::size_t N{0};
    std::size_t K{0};

    // Per-tensor transpose flags; all default to false / empty.
    bool TransA{false};
    bool TransB{false};
    bool TransE{false};
    std::vector<bool> DsTrans{};

    // Element types; default to Half, with no D tensors.
    DataType ADataType{DataType::Half};
    DataType BDataType{DataType::Half};
    DataType EDataType{DataType::Half};
    std::vector<DataType> DsDataType{};

    // Names of the element-wise operations, stored as strings.
    std::string AElementOp{"ck::tensor_operation::element_wise::PassThrough"};
    std::string BElementOp{"ck::tensor_operation::element_wise::PassThrough"};
    std::string CDEElementOp{"ck::Tuple<>"};

    /// Declared here, defined elsewhere.
    std::string GetIncludeHeader() const;

    /// Declared here, defined elsewhere; takes an architecture name.
    std::vector<Solution> GetSolutions(const std::string& arch) const;
};

} // namespace ck::host::device_gemm_multiple_d
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
struct
Operation_Xdl_CShuffle
{
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
();
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
);
TensorDesc
A
{};
TensorDesc
B
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
vector
<
TensorDesc
>
Ds
=
{};
TensorDesc
E
{};
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
cde_elem_op
=
Bilinear
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
operation
::
TileDesc
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
Solution
ToSolution
()
const
;
};
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
DsTrans
=
{};
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
CDEElementOp
=
PassThrough
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
};
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace ck
codegen/include/ck/host/headers.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <string_view>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck::host {

/// Returns a map whose keys and values are both std::string_view.
/// NOTE(review): presumably header path -> header contents for the embedded
/// CK headers; the definition is not visible from this header — confirm.
auto GetHeaders() -> std::unordered_map<std::string_view, std::string_view>;

} // namespace ck::host
codegen/include/ck/host/operation/gemm.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck::host::operation {

/// Tile parameters of a GEMM operation. Every field defaults to zero.
struct TileDesc
{
    int block_size{0};
    int m_per_block{0};
    int n_per_block{0};
    int k_per_block{0};
    int ak1{0};
    int bk1{0};
    int m_per_XDL{0};
    int n_per_XDL{0};
    int m_Xdl_per_wave{0};
    int n_Xdl_per_wave{0};
    int num_gemmk_prefetch_stage{0};
};

/// Block-transfer parameters. String fields default to empty, integer fields
/// to zero.
struct BlockTransferDesc
{
    std::string thread_cluster_length{};
    std::string thread_cluster_arrange_order{};
    std::string src_access_order{};
    int src_vec_dim{0};
    int src_scalar_per_vector{0};
    int dst_scalar_per_vector_k1{0};
    int lds_add_extra_dim{0};
};

/// C-shuffle parameters. Both fields default to zero.
struct CShuffleDesc
{
    int m_Xdl_per_wave_per_shuffle{0};
    int n_Xdl_per_wave_per_shuffle{0};
};

/// C block-transfer parameters. The string defaults to empty, the integer to
/// zero.
struct CBlockTransferDesc
{
    std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl{};
    int scalar_per_vector_n_wave_n_per_Xdl{0};
};

} // namespace ck::host::operation
codegen/include/ck/host/stringutils.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <cctype>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ck {
namespace host {

/// Returns a copy of `s` with the leading and trailing characters for which
/// `f` returns true removed. Returns an empty string when every character
/// matches.
template <class F>
std::string trim(const std::string& s, F f)
{
    auto start = std::find_if_not(s.begin(), s.end(), f);
    // The reverse scan stops at `start`, so it can never cross the front cut.
    auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
    return {start, last};
}

/// Returns a copy of `s` without leading and trailing whitespace, as
/// classified by std::isspace.
inline std::string trim(const std::string& s)
{
    // std::isspace requires a value representable as unsigned char.
    return trim(s, [](unsigned char c) { return std::isspace(c) != 0; });
}

/// Concatenates the elements of `strings`, inserting `delim` between
/// consecutive elements. Returns an empty string for an empty range.
///
/// The range is taken by const reference and the result is built by appending
/// in place, so the work is linear in the size of the output.
template <class Strings>
inline std::string JoinStrings(const Strings& strings, const std::string& delim)
{
    auto it         = strings.begin();
    const auto last = strings.end();
    if(it == last)
        return {};
    std::string result = *it;
    for(++it; it != last; ++it)
    {
        result += delim;
        result += *it;
    }
    return result;
}

/// Replaces every `start ... end` placeholder in `input` with the value
/// returned by `f`.
///
/// `f` is called with an iterator pair delimiting the text between the two
/// delimiters and must return a range of characters (anything providing
/// begin()/end()). Text outside placeholders is copied unchanged.
///
/// Throws std::runtime_error when a delimiter is empty or when a start
/// delimiter has no matching end delimiter.
template <class F>
inline std::string
InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
    // An empty delimiter matches everywhere and cannot delimit anything.
    if(start.empty() || end.empty())
        throw std::runtime_error("Empty delimiter");

    std::string result = "";
    result.reserve(input.size());
    auto it = input.begin();
    while(it != input.end())
    {
        auto next_start = std::search(it, input.end(), start.begin(), start.end());
        result.append(it, next_start);
        if(next_start == input.end())
            break;
        // Look for the end delimiter only after the start delimiter, so that
        // delimiters sharing characters cannot produce an inverted range.
        auto content_begin = next_start + start.size();
        auto next_end      = std::search(content_begin, input.end(), end.begin(), end.end());
        if(next_end == input.end())
        {
            throw std::runtime_error("Unbalanced brackets");
        }
        auto r = f(content_begin, next_end);
        result.append(r.begin(), r.end());
        it = next_end + end.size();
    }
    return result;
}

/// Replaces every `start key end` placeholder in `input` with `vars[key]`.
/// Whitespace around the key is ignored. Throws std::runtime_error for a key
/// that is not present in `vars`, for unbalanced delimiters, or for an empty
/// delimiter.
inline std::string InterpolateString(const std::string& input,
                                     const std::unordered_map<std::string, std::string>& vars,
                                     std::string start = "${",
                                     std::string end   = "}")
{
    return InterpolateString(
        input,
        [&](auto start_it, auto last_it) {
            auto key = trim({start_it, last_it});
            auto it  = vars.find(key);
            if(it == vars.end())
                throw std::runtime_error("Unknown key: " + key);
            return it->second;
        },
        std::move(start),
        std::move(end));
}

/// Applies `f` to every element of `r` and returns the results, in order, as
/// a std::vector.
template <class Range, class F>
inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin()))>
{
    std::vector<decltype(f(*r.begin()))> result;
    std::transform(r.begin(), r.end(), std::back_inserter(result), f);
    return result;
}

/// Applies `f` pairwise to the elements of `r1` and `r2` and returns the
/// results, in order, as a std::vector. Both ranges must have the same
/// length; this is checked with assert() only.
template <class Range1, class Range2, class F>
inline auto Transform(const Range1& r1, const Range2& r2, F f)
    -> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
{
    std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
    assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
    std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f);
    return result;
}

} // namespace host
} // namespace ck
codegen/include/ck/host/types.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {

/// A template string together with the values of its parameters. The
/// non-template member functions are defined elsewhere.
struct Solution
{
    Solution() = default;
    Solution(std::string str, std::unordered_map<std::string, std::string> values);

    std::string ToTemplateString() const;

    /// Returns the parameter called `name` as a string.
    std::string GetTemplateParameter(const std::string& name) const;

    /// Returns the parameter called `name` converted to `T` via stream
    /// extraction. The result is value-initialised first, so a failed
    /// extraction yields `T{}` rather than an indeterminate value.
    template <class T>
    T GetTemplateParameter(const std::string& name) const
    {
        T result{};
        std::stringstream ss(GetTemplateParameter(name));
        ss >> result;
        return result;
    }

    private:
    std::string template_str;
    std::unordered_map<std::string, std::string> template_values;
};

enum class DataType
{
    Half,
    Float,
    Int8,
    Int32
};

std::string ToString(DataType dt);

enum class Layout
{
    Row,
    Column
};

std::string ToString(Layout dl);

enum class GemmType
{
    Default
};

std::string ToString(GemmType gt);

/// Element type and layout of one tensor.
struct TensorDesc
{
    DataType element;
    Layout layout;
};

std::string SequenceStr(const std::vector<int>& v);

std::string MakeTuple(const std::vector<std::string>& v);

/// `S<xs...>` is the result of `SequenceStr({xs...})`.
template <int... xs>
const std::string S = SequenceStr({xs...});

// Fully qualified names of the element-wise operations used as defaults.
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear    = "ck::tensor_operation::element_wise::Bilinear";

} // namespace host
} // namespace ck
codegen/include/ck/host/utils.hpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdint>
#include <unordered_set>
namespace ck::host {

/// Takes two sizes and returns a size.
/// NOTE(review): by its name this is presumably x / y rounded up; the
/// definition is not visible here — confirm, including behaviour for y == 0.
auto integer_divide_ceil(std::size_t x, std::size_t y) -> std::size_t;

/// Returns a reference to a set of architecture-name strings.
/// NOTE(review): presumably the architectures with XDL-op support — confirm
/// against the definition.
auto get_xdlop_archs() -> const std::unordered_set<std::string>&;

} // namespace ck::host
codegen/src/device_gemm_multiple_d.cpp
0 → 100644
View file @
ccaea50e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
}
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
);
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
});
return
result
;
}
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace ck
\ No newline at end of file
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment