Commit 52426f84 authored by Mirza Halilcevic

Separate ck_host lib and gemm_softmax_gemm into a different PR.

parent f52c2a4d
@@ -26,23 +26,7 @@ set(version 1.1.0)
 project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
 include(CTest)
-# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
-# CK Codegen requires dataclass which is added in Python 3.7
-# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
-if(NOT CK_USE_ALTERNATIVE_PYTHON)
-    find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED)
-else()
-    message("Using alternative python version")
-    set(EXTRA_PYTHON_PATH)
-    # this is overly restrictive, we may need to be more flexible on the following
-    string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
-    message("alternative python path is: ${EXTRA_PYTHON_PATH}")
-    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
-    add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
-    set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
-    set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
-    set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}")
-endif()
+find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -78,14 +62,17 @@ if (DTYPES)
 endif()
 message("DTYPES macro set to ${DTYPES}")
 else()
-add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
+add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
 set(CK_ENABLE_INT8 "ON")
 set(CK_ENABLE_FP16 "ON")
 set(CK_ENABLE_FP32 "ON")
 set(CK_ENABLE_FP64 "ON")
 set(CK_ENABLE_BF16 "ON")
-set(CK_ENABLE_FP8 "ON")
-set(CK_ENABLE_BF8 "ON")
+if (GPU_TARGETS MATCHES "gfx94")
+    add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
+    set(CK_ENABLE_FP8 "ON")
+    set(CK_ENABLE_BF8 "ON")
+endif()
 endif()
 #for f8/bf8_t type
@@ -128,8 +115,6 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
 message("GPU_TARGETS= ${GPU_TARGETS}")
-option(CK_BUILD_HOST_LIB, "Only build the CK JIT Helper Library" OFF)
 find_package(hip)
 # No assumption that HIP kernels are launched with uniform block size for backward compatibility
 # SWDEV-413293 and https://reviews.llvm.org/D155213
@@ -206,18 +191,12 @@ endif()
 configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h)
 if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
-check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
-if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
-    message("Adding the fno-offload-uniform-block compiler flag")
-    add_compile_options(-fno-offload-uniform-block)
-endif()
+message("Adding the fno-offload-uniform-block compiler flag")
+add_compile_options(-fno-offload-uniform-block)
 endif()
 if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
-check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
-if(HAS_ENABLE_POST_MISCHED)
-    message("Adding the enable-post-misched=0 compiler flag")
-    add_compile_options("SHELL: -mllvm -enable-post-misched=0")
-endif()
+message("Adding the enable-post-misched=0 compiler flag")
+add_compile_options("SHELL: -mllvm -enable-post-misched=0")
 endif()
 set(check-coerce)
 check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
@@ -256,7 +235,6 @@ elseif(CK_PARALLEL_COMPILE_JOBS)
 message(WARNING "Job pooling is only available with Ninja generators.")
 endif()
-if (NOT CK_BUILD_HOST_LIB)
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
 option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
@@ -278,8 +256,6 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads REQUIRED)
 link_libraries(Threads::Threads)
-endif() # NOT CK_BUILD_HOST_LIB
 ## C++
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -296,8 +272,6 @@ if(USE_GLIBCXX_ASSERTIONS)
 add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS)
 endif()
-if (NOT CK_BUILD_HOST_LIB)
 ## HIP
 set(CMAKE_HIP_PLATFORM amd)
 set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
@@ -353,8 +327,6 @@ else()
 add_compile_definitions(__HIP_PLATFORM_HCC__=1)
 endif()
-endif() # NOT CK_BUILD_HOST_LIB
 ## tidy
 include(EnableCompilerWarnings)
 set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
@@ -508,8 +480,6 @@ include_directories(BEFORE
 ${HIP_INCLUDE_DIRS}
 )
-if (NOT CK_BUILD_HOST_LIB)
 SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
 if(BUILD_DEV)
 add_compile_options(-Werror)
@@ -517,8 +487,6 @@ if(BUILD_DEV)
 endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
-endif() # NOT CK_BUILD_HOST_LIB
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
 add_compile_options(-fcolor-diagnostics)
 endif()
@@ -528,8 +496,6 @@ endif()
 add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
-if (NOT CK_BUILD_HOST_LIB)
 file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
 file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
 set(CK_DEVICE_INSTANCES)
@@ -584,7 +550,12 @@ if(NOT DEFINED INSTANCES_ONLY)
 PACKAGE_NAME examples
 )
 add_subdirectory(example)
-add_subdirectory(test)
+if(GPU_TARGETS MATCHES "gfx9" AND NOT INSTANCES_ONLY)
+    add_subdirectory(codegen)
+endif()
+if(BUILD_TESTING)
+    add_subdirectory(test)
+endif()
 rocm_package_setup_component(profiler
 LIBRARY_NAME composablekernel
@@ -601,22 +572,6 @@ if(NOT DEFINED INSTANCES_ONLY)
 endif()
 endif()
-if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY))
-    add_subdirectory(codegen)
-endif()
-else() # NOT CK_BUILD_HOST_LIB
-    if(GPU_TARGETS MATCHES "gfx9")
-        rocm_package_setup_component(ck_host
-            LIBRARY_NAME composablekernel
-            PACKAGE_NAME ck_host
-        )
-        add_subdirectory(codegen)
-    endif()
-endif() # NOT CK_BUILD_HOST_LIB
 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)
@@ -654,4 +609,4 @@ rocm_create_package(
 MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
 LDCONFIG
 HEADER_ONLY
-)
\ No newline at end of file
+)
 @PACKAGE_INIT@
-set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility ck_host)
+set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)
 foreach(_comp ${composable_kernel_FIND_COMPONENTS})
 if(NOT _comp IN_LIST _composable_kernel_supported_components)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines all the values needed for an instance of the batched GEMM + softmax + GEMM operation
struct Operation_Xdl_CShuffle
{
// returns a vector of instances given only fusion operators; uses the default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmSoftmaxGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
bool mask_out_upper_triangle = false;
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
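Reviewer note: the deleted header above exposes two CreateOperations overloads plus ToSolution. The sketch below is illustrative driver code only (not part of this commit); it assumes Solution comes from ck/host/types.hpp, as used elsewhere in this diff.

// Hypothetical driver, illustrative only -- not part of this commit.
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"

using ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle;

void dump_default_instances()
{
    // Overload 1: no problem spec; the hard-coded tuning parameters are paired
    // with a default problem and the given (here empty) fusion operators.
    auto groups = Operation_Xdl_CShuffle::CreateOperations(/*prologue=*/"", /*epilogue=*/"");
    for(const auto& group : groups)
    {
        for(const auto& op : group)
        {
            // ToSolution() interpolates the tuning parameters into the
            // device op template string and returns it with its metadata.
            auto solution = op.ToSolution();
            (void)solution;
        }
    }
}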
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines the problem specification for a batched GEMM + softmax + GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
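For context, the deleted Problem struct was the public entry point; the sketch below is condensed from the test_gemm_softmax_gemm test case that this commit also removes (see the test diff at the end), with the "gfx90a" arch string and the 1024 dimensions taken from that test.

// Condensed from the removed test_gemm_softmax_gemm test case below; illustrative only.
#include <iostream>
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

int main()
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.TransB = true; // B is column-major; A, B1, and C stay row-major
    prob.M = 1024;
    prob.N = 1024;
    prob.K = 1024;
    prob.O = 1024;

    // Empty prologue/epilogue keep the default element-wise operators.
    auto solutions = prob.GetSolutions("gfx90a", /*prologue=*/"", /*epilogue=*/"");
    std::cout << "Num solutions: " << solutions.size() << std::endl;
    for(const auto& s : solutions)
        std::cout << s.ToTemplateString() << std::endl;
    return 0;
}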
@@ -41,8 +41,6 @@ struct Operation_Xdl_CShuffle
 operation::BlockTransferDesc b_block_transfer{};
 operation::CShuffleDesc cshuffle{};
 operation::CBlockTransferDesc c_block_transfer{};
-LoopScheduler loop_scheduler{};
-PipelineVersion pipeline_version{};
 // functions to update fusion operators if provided
 void update_prologue(const std::string& prologue);
...
@@ -23,26 +23,6 @@ struct TileDesc
 int n_Xdl_per_wave = 0;
 int num_gemmk_prefetch_stage = 0;
 };
-struct TileDescGemmSoftmaxGemm
-{
-    int block_size = 0;
-    int gemm01_m_per_block = 0;
-    int gemm0_n_per_block = 0;
-    int gemm0_k_per_block = 0;
-    int gemm1_n_per_block = 0;
-    int gemm1_k_per_block = 0;
-    int ak1 = 0;
-    int bk1 = 0;
-    int b1k1 = 0;
-    int m_per_XDL = 0;
-    int n_per_XDL = 0;
-    int gemm0_m_Xdl_per_wave = 0;
-    int gemm0_n_Xdl_per_wave = 0;
-    int gemm1_n_Xdl_per_wave = 0;
-    int num_gemmk_prefetch_stage = 0;
-};
 struct BlockTransferDesc
 {
 std::string thread_cluster_length = "";
...
@@ -66,20 +66,6 @@ enum class GemmType
 };
 std::string ToString(GemmType gt);
-enum class LoopScheduler
-{
-    Default,
-    Interwave,
-};
-std::string ToString(LoopScheduler ls);
-enum class PipelineVersion
-{
-    v1,
-    v2
-};
-std::string ToString(PipelineVersion pv);
 struct TensorDesc
 {
 DataType element;
@@ -98,7 +84,6 @@ const std::string S = SequenceStr({xs...});
 constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
 constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
-constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
 } // namespace host
 } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
@@ -62,13 +62,6 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
 // accounts for all possible combinations of Row/Col major
 static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
-// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
-// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 // Hard-code tuning parameters in modularized fashion, string them together into a vector of
 // instances
 std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
@@ -90,8 +83,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
 { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
 { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
-// Irregular tile
-{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
 // clang-format on
 };
@@ -109,8 +100,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
 { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
 { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
-// Irregular tile
-{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
 // clang-format on
 };
@@ -120,17 +109,15 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
 // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
 // | | | | | | |
-{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
-// Irregular tile
-{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
 // clang-format on
+{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
 };
 std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
@@ -147,8 +134,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
 { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
 { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-// Irregular tile
-{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
 // clang-format on
 };
@@ -166,8 +151,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
 { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
 { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
-// Irregular tile
-{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
 // clang-format on
 };
@@ -184,7 +167,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { 1, 1},
 { 1, 1},
 { 1, 1},
-{ 1, 1},
 { 1, 1},
 // clang-format on
 };
@@ -203,8 +185,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 { S<1, 16, 1, 8>, 8},
 { S<1, 32, 1, 8>, 8},
 { S<1, 32, 1, 8>, 8},
-// Irregular tile
-{ S<1, 16, 1, 4>, 1},
 // clang-format on
 };
@@ -219,44 +199,33 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
 assert(tile_descriptions.size() == cshuffle_descriptions.size());
 assert(tile_descriptions.size() == c_block_descriptions.size());
-const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
-{
-    {LoopScheduler::Default, PipelineVersion::v1},
-    {LoopScheduler::Interwave, PipelineVersion::v1},
-    {LoopScheduler::Default, PipelineVersion::v2},
-};
-for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
-{
-    // Put all values together into a single operation > store into the result vector
-    for(std::size_t i = 0; i < tile_descriptions.size(); i++)
-    {
-        Operation_Xdl_CShuffle x;
-        x.tile_desc = tile_descriptions[i];
-        x.a_block_transfer = a_block_descriptions[i];
-        x.b_block_transfer = b_block_descriptions[i];
-        x.cshuffle = cshuffle_descriptions[i];
-        x.c_block_transfer = c_block_descriptions[i];
-        x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
-        x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
-        x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
-        x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
-            return TensorDesc{dt, ToLayout(trans)};
-        });
-        x.a_elem_op = prob.AElementOp;
-        x.b_elem_op = prob.BElementOp;
-        x.cde_elem_op = prob.CDEElementOp;
-        x.gemm_specialization = GetGemmSpec(prob.M,
-                                            prob.N,
-                                            prob.K,
-                                            x.tile_desc.m_per_block,
-                                            x.tile_desc.n_per_block,
-                                            x.tile_desc.k_per_block);
-        x.loop_scheduler = loop_scheduler;
-        x.pipeline_version = pipeline_version;
-        x.update_prologue(prologue);
-        x.update_epilogue(epilogue);
-        result.push_back(x);
-    }
-}
+// Put all values together into a single operation > store into the result vector
+for(std::size_t i = 0; i < tile_descriptions.size(); i++)
+{
+    Operation_Xdl_CShuffle x;
+    x.tile_desc = tile_descriptions[i];
+    x.a_block_transfer = a_block_descriptions[i];
+    x.b_block_transfer = b_block_descriptions[i];
+    x.cshuffle = cshuffle_descriptions[i];
+    x.c_block_transfer = c_block_descriptions[i];
+    x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
+    x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
+    x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
+    x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
+        return TensorDesc{dt, ToLayout(trans)};
+    });
+    x.a_elem_op = prob.AElementOp;
+    x.b_elem_op = prob.BElementOp;
+    x.cde_elem_op = prob.CDEElementOp;
+    x.gemm_specialization = GetGemmSpec(prob.M,
+                                        prob.N,
+                                        prob.K,
+                                        x.tile_desc.m_per_block,
+                                        x.tile_desc.n_per_block,
+                                        x.tile_desc.k_per_block);
+    x.update_prologue(prologue);
+    x.update_epilogue(epilogue);
+    result.push_back(x);
+}
 return result;
 }
@@ -294,7 +263,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
 "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
 "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
 "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
-"${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
+"${CDEBlockTransferScalarPerVector_NPerBlock}>";
 // use hardcoded instances from vector of operations to substitute values into instance template
 Solution Operation_Xdl_CShuffle::ToSolution() const
@@ -367,8 +336,6 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
 this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
 {"CDEBlockTransferScalarPerVector_NPerBlock",
 std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
-{"LoopScheduler", ToString(this->loop_scheduler)},
-{"PipelineVersion", ToString(this->pipeline_version)},
 };
 return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
...
@@ -56,26 +56,6 @@ std::string ToString(GemmType gt)
 throw std::runtime_error("Incorrect gemm type");
 }
-std::string ToString(LoopScheduler ls)
-{
-    switch(ls)
-    {
-    case LoopScheduler::Default: return "ck::LoopScheduler::Default";
-    case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
-    }
-    throw std::runtime_error("Incorrect LoopScheduler type");
-}
-std::string ToString(PipelineVersion pv)
-{
-    switch(pv)
-    {
-    case PipelineVersion::v1: return "ck::PipelineVersion::v1";
-    case PipelineVersion::v2: return "ck::PipelineVersion::v2";
-    }
-    throw std::runtime_error("Incorrect PipelineVersion type");
-}
 std::string SequenceStr(const std::vector<int>& v)
 {
 return "ck::Sequence<" +
...
#include "common.hpp" #include "common.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
...@@ -87,34 +85,4 @@ TEST_CASE(test_problem_kernel) ...@@ -87,34 +85,4 @@ TEST_CASE(test_problem_kernel)
} }
} }
-TEST_CASE(test_gemm_softmax_gemm)
-{
-    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
-    prob.TransA = false;
-    prob.TransB = true;
-    prob.TransB1 = false;
-    prob.TransC = false;
-    prob.M = 1024;
-    prob.N = 1024;
-    prob.K = 1024;
-    prob.O = 1024;
-    check_all<half> check;
-    auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
-    auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
-    auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
-    auto c = to_gpu(generate_buffer<half>(1024 * 1024, 3));
-    std::string epilogue = "";
-    std::string prologue = "";
-    auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
-    std::cout << "Num solutions: " << solutions.size() << std::endl;
-    for(auto i = 0; i < solutions.size(); ++i) {
-        std::cout << "Solution " << i << std::endl;
-        std::cout << solutions[i].ToTemplateString() << std::endl;
-        std::cout << std::endl;
-    }
-}
 int main(int argc, const char* argv[]) { test::run(argc, argv); }