Commit ccaea50e authored by Jing Zhang's avatar Jing Zhang
Browse files

merge navi31_rel

parents 0b914465 10127959
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <vector> #include <vector>
...@@ -88,7 +88,7 @@ int main(int argc, char* argv[]) ...@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
...@@ -79,7 +79,7 @@ int main() ...@@ -79,7 +79,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
...@@ -77,7 +77,7 @@ int main() ...@@ -77,7 +77,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
...@@ -76,7 +76,7 @@ int main() ...@@ -76,7 +76,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
...@@ -77,7 +77,7 @@ int main() ...@@ -77,7 +77,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
...@@ -77,7 +77,7 @@ int main() ...@@ -77,7 +77,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, Row>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include <iomanip> #include <iomanip>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip> #include <iomanip>
#include <vector> #include <vector>
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
if(WIN32)
set(EMBED_USE RC CACHE STRING "Use RC or CArrays to embed data files")
set_property(CACHE EMBED_USE PROPERTY STRINGS "RC;CArrays")
else()
if(BUILD_SHARED_LIBS)
set(EMBED_USE LD CACHE STRING "Use LD or CArrays to embed data files")
else()
set(EMBED_USE CArrays CACHE STRING "Use LD or CArrays to embed data files")
endif()
set_property(CACHE EMBED_USE PROPERTY STRINGS "LD;CArrays")
endif()
if(EMBED_USE STREQUAL "LD")
find_program(EMBED_LD ld REQUIRED)
find_program(EMBED_OBJCOPY objcopy REQUIRED)
endif()
function(embed_wrap_string)
set(options)
set(oneValueArgs VARIABLE AT_COLUMN)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(LENGTH ${${PARSE_VARIABLE}} string_length)
math(EXPR offset "0")
while(string_length GREATER 0)
if(string_length GREATER ${PARSE_AT_COLUMN})
math(EXPR length "${PARSE_AT_COLUMN}")
else()
math(EXPR length "${string_length}")
endif()
string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line)
set(lines "${lines}\n${line}")
math(EXPR string_length "${string_length} - ${length}")
math(EXPR offset "${offset} + ${length}")
endwhile()
set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()
function(generate_embed_source EMBED_NAME EMBED_DIR BASE_DIRECTORY)
set(options)
set(oneValueArgs)
set(multiValueArgs SYMBOLS FILES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(RESOURCE_ID 100)
list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN)
list(LENGTH PARSE_FILES FILES_LEN)
if(NOT ${SYMBOLS_LEN} EQUAL ${FILES_LEN})
message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${FILES_LEN}")
endif()
math(EXPR LEN "${SYMBOLS_LEN} - 1")
foreach(idx RANGE ${LEN})
list(GET PARSE_SYMBOLS ${idx} SYMBOL)
list(GET PARSE_FILES ${idx} FILE)
file(RELATIVE_PATH BASE_NAME "${BASE_DIRECTORY}" ${FILE})
if(EMBED_USE STREQUAL "RC")
string(TOUPPER "${SYMBOL}" SYMBOL)
string(APPEND FILE_IDS "#define IDR_${SYMBOL} ${RESOURCE_ID}\n")
file(TO_NATIVE_PATH "${FILE}" NATIVE_FILE)
string(REPLACE "\\" "\\\\" NATIVE_FILE "${NATIVE_FILE}")
string(APPEND RC_FILE_MAPPING "IDR_${SYMBOL} TEXTFILE \"${NATIVE_FILE}\"\n")
string(APPEND INIT_KERNELS "\n {\"${BASE_NAME}\", resource::read(IDR_${SYMBOL})},")
math(EXPR RESOURCE_ID "${RESOURCE_ID} + 1" OUTPUT_FORMAT DECIMAL)
else()
set(START_SYMBOL "_binary_${SYMBOL}_start")
set(LENGTH_SYMBOL "_binary_${SYMBOL}_length")
if(EMBED_USE STREQUAL "LD")
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t _binary_${SYMBOL}_size;
const auto ${LENGTH_SYMBOL} = reinterpret_cast<size_t>(&_binary_${SYMBOL}_size);
")
else()
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const size_t ${LENGTH_SYMBOL};
")
endif()
string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },")
endif()
endforeach()
if(EMBED_USE STREQUAL "RC")
file(WRITE "${EMBED_DIR}/include/resource.h" "
#define TEXTFILE 256
${FILE_IDS}
")
file(WRITE "${EMBED_DIR}/resource.rc" "
#include \"resource.h\"
${RC_FILE_MAPPING}
")
set(EXTERNS "
#include <Windows.h>
#include \"resource.h\"
namespace resource {
std::string_view read(int id)
{
HMODULE handle = GetModuleHandle(nullptr);
HRSRC rc = FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(TEXTFILE));
HGLOBAL data = LoadResource(handle, rc);
return {static_cast<const char*>(LockResource(data)), SizeofResource(handle, rc)};
}
}
")
set(EMBED_FILES ${EMBED_DIR}/include/resource.h ${EMBED_DIR}/resource.rc)
endif()
file(WRITE "${EMBED_DIR}/include/${EMBED_NAME}.hpp" "
#include <string_view>
#include <unordered_map>
#include <utility>
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}();
")
file(WRITE "${EMBED_DIR}/${EMBED_NAME}.cpp" "
#include <${EMBED_NAME}.hpp>
${EXTERNS}
std::unordered_map<std::string_view, std::string_view> ${EMBED_NAME}()
{
static std::unordered_map<std::string_view, std::string_view> result = {${INIT_KERNELS}
};
return result;
}
")
list(APPEND EMBED_FILES ${EMBED_DIR}/${EMBED_NAME}.cpp ${EMBED_DIR}/include/${EMBED_NAME}.hpp)
set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE)
endfunction()
function(embed_file FILE BASE_DIRECTORY)
message(STATUS " ${FILE}")
file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" ${FILE})
string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL)
get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
if(EMBED_USE STREQUAL "LD")
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
add_custom_command(
OUTPUT "${OUTPUT_FILE}"
COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}"
WORKING_DIRECTORY "${BASE_DIRECTORY}"
DEPENDS "${FILE}"
VERBATIM)
set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
elseif(EMBED_USE STREQUAL "CArrays")
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE})
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
# reads source file contents as hex string
file(READ ${FILE} HEX_STRING HEX)
# wraps the hex string into multiple lines
embed_wrap_string(VARIABLE HEX_STRING AT_COLUMN 80)
# adds '0x' prefix and comma suffix before and after every byte respectively
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING})
# removes trailing comma
string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES})
file(WRITE "${OUTPUT_FILE}" "
#include <cstddef>
extern const char _binary_${OUTPUT_SYMBOL}_start[] = { ${ARRAY_VALUES} };
extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SYMBOL}_start);
")
set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE)
endif()
set(OUTPUT_SYMBOL ${OUTPUT_SYMBOL} PARENT_SCOPE)
endfunction()
function(add_embed_library EMBED_NAME)
set(options)
set(oneValueArgs RELATIVE)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME})
file(MAKE_DIRECTORY ${EMBED_DIR})
message(STATUS "Embedding kernel files:")
foreach(FILE ${PARSE_UNPARSED_ARGUMENTS})
embed_file(${FILE} ${PARSE_RELATIVE})
list(APPEND OUTPUT_FILES ${OUTPUT_FILE})
list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
endforeach()
message(STATUS "Generating embedding library '${EMBED_NAME}'")
generate_embed_source(${EMBED_NAME} ${EMBED_DIR} "${PARSE_RELATIVE}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS})
set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME})
if(EMBED_USE STREQUAL "LD")
add_library(${INTERNAL_EMBED_LIB} STATIC ${EMBED_FILES} ${OUTPUT_FILES})
else()
add_library(${INTERNAL_EMBED_LIB} OBJECT ${EMBED_FILES})
endif()
if(EMBED_USE STREQUAL "CArrays")
target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES})
endif()
target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include")
target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(${EMBED_NAME} INTERFACE)
if(EMBED_USE STREQUAL "RC")
target_link_libraries(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
elseif(EMBED_USE STREQUAL "LD")
target_link_libraries(${EMBED_NAME} INTERFACE ${INTERNAL_EMBED_LIB})
else()
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
endif()
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
endfunction()
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
add_definitions(-std=c++17)
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
)
add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)
rocm_install(
TARGETS ck_host ck_headers
EXPORT ck_hostTargets
)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
if(BUILD_TESTING)
add_subdirectory(test)
endif()
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp"
using ck::host::Transform;
struct Emitters
{
std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;
template <class T>
void Register(const std::string& name)
{
m[name] = [] {
auto configs = T::CreateOperations();
return Transform(configs, [](const auto& ops) { return ToTuple(ops); });
};
}
template <class T>
static std::string ToTuple(const T& ops)
{
auto templates = Transform(
ops, [](const auto& op) { return " " + op.ToSolution().ToTemplateString(); });
return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">";
}
std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); }
std::vector<std::string> List() const
{
return Transform(m, [](auto&& p) { return p.first; });
}
};
int main(int argc, const char* argv[])
{
std::string prog = argv[0];
std::vector<std::string> args(argv + 1, argv + argc);
Emitters e;
e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
"DeviceGemmMultipleD_Xdl_CShuffle");
if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) {
return arg == "-h" or arg == "--help";
}))
{
std::cout << "USAGE:" << std::endl;
std::cout << " " << prog << " [TEMPLATE]" << std::endl;
std::cout << std::endl;
std::cout << "FLAGS:" << std::endl;
std::cout << " -h, --help Show help" << std::endl;
std::cout << std::endl;
std::cout << "TEMPLATES:" << std::endl;
for(auto x : e.List())
std::cout << " " << x << std::endl;
std::cout << std::endl;
return 0;
}
for(auto name : args)
std::cout << e.Emit(name) << std::endl;
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Operation_Xdl_CShuffle
{
static std::vector<std::vector<Operation_Xdl_CShuffle>> CreateOperations();
static std::vector<Operation_Xdl_CShuffle> CreateOperations(const Problem& prob);
TensorDesc A{};
TensorDesc B{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::vector<TensorDesc> Ds = {};
TensorDesc E{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string cde_elem_op = Bilinear;
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
operation::TileDesc tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
Solution ToSolution() const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string CDEElementOp = PassThrough;
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <string_view>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
std::unordered_map<std::string_view, std::string_view> GetHeaders();
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck {
namespace host {
namespace operation {
struct TileDesc
{
int block_size = 0;
int m_per_block = 0;
int n_per_block = 0;
int k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int m_Xdl_per_wave = 0;
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc
{
std::string thread_cluster_length = "";
std::string thread_cluster_arrange_order = "";
std::string src_access_order = "";
int src_vec_dim = 0;
int src_scalar_per_vector = 0;
int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0;
};
struct CShuffleDesc
{
int m_Xdl_per_wave_per_shuffle = 0;
int n_Xdl_per_wave_per_shuffle = 0;
};
struct CBlockTransferDesc
{
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
int scalar_per_vector_n_wave_n_per_Xdl = 0;
};
} // namespace operation
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <numeric>
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
template <class F>
std::string trim(const std::string& s, F f)
{
auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return {start, last};
}
inline std::string trim(const std::string& s)
{
return trim(s, [](unsigned char c) { return std::isspace(c); });
}
template <class Strings>
inline std::string JoinStrings(Strings strings, const std::string& delim)
{
auto it = strings.begin();
if(it == strings.end())
return "";
auto nit = std::next(it);
return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
return std::move(x) + delim + std::move(y);
});
}
template <class F>
inline std::string
InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
std::string result = "";
result.reserve(input.size());
auto it = input.begin();
while(it != input.end())
{
auto next_start = std::search(it, input.end(), start.begin(), start.end());
auto next_end = std::search(next_start, input.end(), end.begin(), end.end());
result.append(it, next_start);
if(next_start == input.end())
break;
if(next_end == input.end())
{
throw std::runtime_error("Unbalanced brackets");
}
auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end());
it = next_end + end.size();
}
return result;
}
inline std::string InterpolateString(const std::string& input,
const std::unordered_map<std::string, std::string>& vars,
std::string start = "${",
std::string end = "}")
{
return InterpolateString(
input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
}
template <class Range, class F>
inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin()))>
{
std::vector<decltype(f(*r.begin()))> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), f);
return result;
}
template <class Range1, class Range2, class F>
inline auto Transform(const Range1& r1, const Range2& r2, F f)
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
{
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f);
return result;
}
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
struct Solution
{
Solution() = default;
Solution(std::string str, std::unordered_map<std::string, std::string> values);
std::string ToTemplateString() const;
std::string GetTemplateParameter(const std::string& name) const;
template <class T>
T GetTemplateParameter(const std::string& name) const
{
T result;
std::stringstream ss(GetTemplateParameter(name));
ss >> result;
return result;
}
private:
std::string template_str;
std::unordered_map<std::string, std::string> template_values;
};
enum class DataType
{
Half,
Float,
Int8,
Int32
};
std::string ToString(DataType dt);
enum class Layout
{
Row,
Column
};
std::string ToString(Layout dl);
enum class GemmType
{
Default
};
std::string ToString(GemmType gt);
struct TensorDesc
{
DataType element;
Layout layout;
};
std::string SequenceStr(const std::vector<int>& v);
std::string MakeTuple(const std::vector<std::string>& v);
template <int... xs>
const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdint>
#include <unordered_set>
namespace ck {
namespace host {
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
const std::unordered_set<std::string>& get_xdlop_archs();
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this);
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution();
});
return result;
}
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment