gaoqiong / composable_kernel_ROCM / Commits / 52426f84

Commit 52426f84 authored Oct 02, 2024 by Mirza Halilcevic

    Separate ck_host lib and gemm_softmax_gemm into different PR.

parent f52c2a4d
Showing 12 changed files with 54 additions and 776 deletions.

    CMakeLists.txt                                                            +18  -63
    Config.cmake.in                                                            +1   -1
    codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp     +0  -58
    codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp       +0  -47
    codegen/include/ck/host/device_gemm_multiple_d/operation.hpp               +0   -2
    codegen/include/ck/host/operation/gemm.hpp                                 +0  -20
    codegen/include/ck/host/types.hpp                                          +0  -15
    codegen/src/device_batched_gemm_softmax_gemm.cpp                           +0  -38
    codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp    +0 -412
    codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp             +35  -68
    codegen/src/types.cpp                                                       +0  -20
    codegen/test/gemm_multiple_d.cpp                                            +0  -32
CMakeLists.txt  (+18 -63)

@@ -26,23 +26,7 @@ set(version 1.1.0)
 project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
 include(CTest)

-# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
-# CK Codegen requires dataclass which is added in Python 3.7
-# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
-if(NOT CK_USE_ALTERNATIVE_PYTHON)
-    find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED)
-else()
-    message("Using alternative python version")
-    set(EXTRA_PYTHON_PATH)
-    # this is overly restrictive, we may need to be more flexible on the following
-    string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
-    message("alternative python path is: ${EXTRA_PYTHON_PATH}")
-    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
-    add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
-    set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
-    set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
-    set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}")
-endif()
+find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)

 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

@@ -78,14 +62,17 @@ if (DTYPES)
     endif()
     message("DTYPES macro set to ${DTYPES}")
 else()
-    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
+    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
     set(CK_ENABLE_INT8 "ON")
     set(CK_ENABLE_FP16 "ON")
     set(CK_ENABLE_FP32 "ON")
     set(CK_ENABLE_FP64 "ON")
     set(CK_ENABLE_BF16 "ON")
-    set(CK_ENABLE_FP8 "ON")
-    set(CK_ENABLE_BF8 "ON")
+    if(GPU_TARGETS MATCHES "gfx94")
+        add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
+        set(CK_ENABLE_FP8 "ON")
+        set(CK_ENABLE_BF8 "ON")
+    endif()
 endif()

 #for f8/bf8_t type

@@ -128,8 +115,6 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
 message("GPU_TARGETS=${GPU_TARGETS}")

-option(CK_BUILD_HOST_LIB, "Only build the CK JIT Helper Library" OFF)
-
 find_package(hip)
 # No assumption that HIP kernels are launched with uniform block size for backward compatibility
 # SWDEV-413293 and https://reviews.llvm.org/D155213

@@ -206,18 +191,12 @@ endif()
 configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h)

 if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
-    check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
-    if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
-        message("Adding the fno-offload-uniform-block compiler flag")
-        add_compile_options(-fno-offload-uniform-block)
-    endif()
+    message("Adding the fno-offload-uniform-block compiler flag")
+    add_compile_options(-fno-offload-uniform-block)
 endif()
 if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
-    check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
-    if(HAS_ENABLE_POST_MISCHED)
-        message("Adding the enable-post-misched=0 compiler flag")
-        add_compile_options("SHELL: -mllvm -enable-post-misched=0")
-    endif()
+    message("Adding the enable-post-misched=0 compiler flag")
+    add_compile_options("SHELL: -mllvm -enable-post-misched=0")
 endif()
 set(check-coerce)
 check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)

@@ -256,7 +235,6 @@ elseif(CK_PARALLEL_COMPILE_JOBS)
     message(WARNING "Job pooling is only available with Ninja generators.")
 endif()

-if(NOT CK_BUILD_HOST_LIB)
 option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
 option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)

@@ -278,8 +256,6 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads REQUIRED)
 link_libraries(Threads::Threads)
-endif() # NOT CK_BUILD_HOST_LIB

 ## C++
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

@@ -296,8 +272,6 @@ if(USE_GLIBCXX_ASSERTIONS)
     add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS)
 endif()

-if(NOT CK_BUILD_HOST_LIB)
 ## HIP
 set(CMAKE_HIP_PLATFORM amd)
 set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})

@@ -353,8 +327,6 @@ else()
     add_compile_definitions(__HIP_PLATFORM_HCC__=1)
 endif()
-endif() # NOT CK_BUILD_HOST_LIB

 ## tidy
 include(EnableCompilerWarnings)
 set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)

@@ -508,8 +480,6 @@ include_directories(BEFORE
     ${HIP_INCLUDE_DIRS}
 )

-if(NOT CK_BUILD_HOST_LIB)
 SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
 if(BUILD_DEV)
     add_compile_options(-Werror)

@@ -517,8 +487,6 @@ if(BUILD_DEV)
 endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
-endif() # NOT CK_BUILD_HOST_LIB

 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
     add_compile_options(-fcolor-diagnostics)
 endif()

@@ -528,8 +496,6 @@ endif()
 add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})

-if(NOT CK_BUILD_HOST_LIB)
 file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
 file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
 set(CK_DEVICE_INSTANCES)

@@ -584,7 +550,12 @@ if(NOT DEFINED INSTANCES_ONLY)
         PACKAGE_NAME examples
     )
     add_subdirectory(example)
-    add_subdirectory(test)
+    if(GPU_TARGETS MATCHES "gfx9" AND NOT INSTANCES_ONLY)
+        add_subdirectory(codegen)
+    endif()
+    if(BUILD_TESTING)
+        add_subdirectory(test)
+    endif()
     rocm_package_setup_component(profiler
         LIBRARY_NAME composablekernel

@@ -601,22 +572,6 @@ if(NOT DEFINED INSTANCES_ONLY)
         endif()
     endif()
-    if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY))
-        add_subdirectory(codegen)
-    endif()
-else() # NOT CK_BUILD_HOST_LIB
-    if(GPU_TARGETS MATCHES "gfx9")
-        rocm_package_setup_component(ck_host
-            LIBRARY_NAME composablekernel
-            PACKAGE_NAME ck_host
-        )
-        add_subdirectory(codegen)
-    endif()
-endif() # NOT CK_BUILD_HOST_LIB

 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)

@@ -654,4 +609,4 @@ rocm_create_package(
     MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
     LDCONFIG
     HEADER_ONLY
 )
\ No newline at end of file
Config.cmake.in  (+1 -1)

 @PACKAGE_INIT@

-set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility ck_host)
+set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)

 foreach(_comp ${composable_kernel_FIND_COMPONENTS})
     if(NOT _comp IN_LIST _composable_kernel_supported_components)
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp   deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
    // returns a vector of instances, only given fusion operators: will use default problem spec
    static std::vector<std::vector<Operation_Xdl_CShuffle>>
    CreateOperations(const std::string& prologue, const std::string& epilogue);
    // returns a vector of instances, given a problem spec and fusion operators
    static std::vector<Operation_Xdl_CShuffle>
    CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);

    TensorDesc A{};
    TensorDesc B{};
    TensorDesc B1{};
    TensorDesc C{};
    std::string a_elem_op           = PassThrough;
    std::string b_elem_op           = PassThrough;
    std::string b1_elem_op          = PassThrough;
    std::string c_elem_op           = PassThrough;
    std::string acc_elem_op         = Scale;
    std::string prologue            = "";
    std::string epilogue            = "";
    std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";

    // tuning parameters
    operation::TileDescGemmSoftmaxGemm tile_desc{};
    operation::BlockTransferDesc a_block_transfer{};
    operation::BlockTransferDesc b0_block_transfer{};
    operation::BlockTransferDesc b1_block_transfer{};
    operation::CShuffleDesc cshuffle{};
    operation::CBlockTransferDesc c_block_transfer{};

    bool mask_out_upper_triangle = false;

    // functions to update fusion operators if provided
    void update_prologue(const std::string& prologue);
    void update_epilogue(const std::string& epilogue);

    /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);

    // returns a templated instance
    Solution ToSolution() const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp   deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// defines the problem specification for a GEMM operation
struct Problem
{
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    std::size_t O = 0;
    bool TransA   = false;
    bool TransB   = false;
    bool TransB1  = false;
    bool TransC   = false;
    DataType ADataType  = DataType::Half;
    DataType BDataType  = DataType::Half;
    DataType B1DataType = DataType::Half;
    DataType CDataType  = DataType::Half;
    std::string AElementOp   = PassThrough;
    std::string BElementOp   = PassThrough;
    std::string B1ElementOp  = PassThrough;
    std::string CElementOp   = PassThrough;
    std::string AccElementOp = Scale;

    // returns the correct device op file for the operation
    std::string GetIncludeHeader() const;

    // returns a list of instances based on the problem spec and provided fusion operations
    std::vector<Solution> GetSolutions(const std::string& arch,
                                       const std::string& prologue,
                                       const std::string& epilogue) const;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
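The deleted Problem/Operation pair above was driven entirely from host code: fill in a Problem, ask it for Solutions for a target architecture, and print or compile the resulting template strings. A minimal sketch of that flow, mirroring the test case removed from codegen/test/gemm_multiple_d.cpp further down (the 1024-sized problem and the "gfx90a" target are just the values used in that test, not requirements of the API):

    // sketch.cpp -- illustration only; uses the ck::host API that this commit removes
    #include <iostream>
    #include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

    int main()
    {
        ck::host::device_batched_gemm_softmax_gemm::Problem prob;
        prob.M = 1024; // GEMM0 rows
        prob.N = 1024; // GEMM0 columns / softmax width
        prob.K = 1024; // GEMM0 reduction length
        prob.O = 1024; // GEMM1 output width
        prob.TransB = true;

        // Each Solution wraps one fully substituted device-op template string
        // for the batched gemm + softmax + gemm kernel.
        auto solutions = prob.GetSolutions("gfx90a", /*prologue=*/"", /*epilogue=*/"");
        std::cout << "Num solutions: " << solutions.size() << "\n";
        for(const auto& s : solutions)
            std::cout << s.ToTemplateString() << "\n";
    }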
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp  (+0 -2)

@@ -41,8 +41,6 @@ struct Operation_Xdl_CShuffle
     operation::BlockTransferDesc b_block_transfer{};
     operation::CShuffleDesc cshuffle{};
     operation::CBlockTransferDesc c_block_transfer{};
-    LoopScheduler loop_scheduler{};
-    PipelineVersion pipeline_version{};

     // functions to update fusion operators if provided
     void update_prologue(const std::string& prologue);
codegen/include/ck/host/operation/gemm.hpp  (+0 -20)

@@ -23,26 +23,6 @@ struct TileDesc
     int n_Xdl_per_wave           = 0;
     int num_gemmk_prefetch_stage = 0;
 };
-struct TileDescGemmSoftmaxGemm
-{
-    int block_size               = 0;
-    int gemm01_m_per_block       = 0;
-    int gemm0_n_per_block        = 0;
-    int gemm0_k_per_block        = 0;
-    int gemm1_n_per_block        = 0;
-    int gemm1_k_per_block        = 0;
-    int ak1                      = 0;
-    int bk1                      = 0;
-    int b1k1                     = 0;
-    int m_per_XDL                = 0;
-    int n_per_XDL                = 0;
-    int gemm0_m_Xdl_per_wave     = 0;
-    int gemm0_n_Xdl_per_wave     = 0;
-    int gemm1_n_Xdl_per_wave     = 0;
-    int num_gemmk_prefetch_stage = 0;
-};
 struct BlockTransferDesc
 {
     std::string thread_cluster_length = "";
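The removed TileDescGemmSoftmaxGemm is a plain aggregate of tuning parameters; in the codegen sources such descriptors are filled from hard-coded brace-initializer rows whose positions map one-to-one onto the fields in declaration order (the TileDesc tables in device_gemm_multiple_d_operation_xdl_cshuffle.cpp below do exactly this). A hedged sketch of how one row would map onto the removed struct — the numeric values here are invented for illustration, not taken from the deleted gemm_softmax_gemm tables:

    // illustration only: positional aggregate initialization of the removed descriptor
    #include "ck/host/operation/gemm.hpp"

    ck::host::operation::TileDescGemmSoftmaxGemm tile{
        256, // block_size
        128, // gemm01_m_per_block
        128, // gemm0_n_per_block
        32,  // gemm0_k_per_block
        64,  // gemm1_n_per_block
        32,  // gemm1_k_per_block
        8,   // ak1
        8,   // bk1
        2,   // b1k1
        32,  // m_per_XDL
        32,  // n_per_XDL
        1,   // gemm0_m_Xdl_per_wave
        4,   // gemm0_n_Xdl_per_wave
        2,   // gemm1_n_Xdl_per_wave
        1};  // num_gemmk_prefetch_stage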
codegen/include/ck/host/types.hpp  (+0 -15)

@@ -66,20 +66,6 @@ enum class GemmType
 };

 std::string ToString(GemmType gt);

-enum class LoopScheduler
-{
-    Default,
-    Interwave,
-};
-std::string ToString(LoopScheduler ls);
-enum class PipelineVersion
-{
-    v1,
-    v2
-};
-std::string ToString(PipelineVersion pv);
 struct TensorDesc
 {
     DataType element;

@@ -98,7 +84,6 @@ const std::string S = SequenceStr({xs...});
 constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
 constexpr const char* Bilinear    = "ck::tensor_operation::element_wise::Bilinear";
-constexpr const char* Scale       = "ck::tensor_operation::element_wise::Scale";

 } // namespace host
 } // namespace ck
codegen/src/device_batched_gemm_softmax_gemm.cpp   deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
    return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}

// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
                                            const std::string& prologue,
                                            const std::string& epilogue) const
{
    if(get_xdlop_archs().count(arch) == 0)
        return {};
    auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
        *this, prologue, epilogue); // obtains vector of instances
    std::vector<Solution> result;
    std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
        return op.ToSolution(); // template instance with correct values
    });
    return result;
}

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp   deleted 100644 → 0

This diff is collapsed in the web view (412 lines deleted).
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp  (+35 -68)

@@ -62,13 +62,6 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
 // accounts for all possible combinations of Row/Col major
 static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }

-// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
-// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 // Hard-code tuning parameters in modularized fashion, string them together into a vector of
 // instances
 std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(

@@ -90,8 +83,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {128,  64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
         {256, 128,  64, 32, 8, 8, 32, 32, 2, 1, 1},
         {256,  64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
-        // Irregular tile
-        { 64,  16,  16, 32, 8, 8, 16, 16, 1, 1, 1},
         // clang-format on
     };

@@ -109,8 +100,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
-        // Irregular tile
-        {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
         // clang-format on
     };

@@ -120,17 +109,15 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         // ThreadCluster|  ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
         // Lengths_K0_M_K1|  ArrangeOrder|               |             | PerVector| PerVector_K1|      |
         //                |              |               |             |          |             |      |
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
         {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
         {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
         {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
-        // Irregular tile
-        {S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         // clang-format on
     };
     std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {

@@ -147,8 +134,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-        // Irregular tile
-        {S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         // clang-format on
     };

@@ -166,8 +151,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         {S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
-        // Irregular tile
-        {S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
         // clang-format on
     };

@@ -184,7 +167,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {1, 1},
         {1, 1},
         {1, 1},
         {1, 1},
-        {1, 1},
         // clang-format on
     };

@@ -203,8 +185,6 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         {S<1, 16, 1, 8>, 8},
         {S<1, 32, 1, 8>, 8},
         {S<1, 32, 1, 8>, 8},
-        // Irregular tile
-        {S<1, 16, 1, 4>, 1},
         // clang-format on
     };

@@ -219,44 +199,33 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
     assert(tile_descriptions.size() == cshuffle_descriptions.size());
     assert(tile_descriptions.size() == c_block_descriptions.size());

-    const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
-        {
-            {LoopScheduler::Default, PipelineVersion::v1},
-            {LoopScheduler::Interwave, PipelineVersion::v1},
-            {LoopScheduler::Default, PipelineVersion::v2},
-        };
-
-    for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
+    // Put all values together into a single operation > store into the result vector
+    for(std::size_t i = 0; i < tile_descriptions.size(); i++)
     {
-        // Put all values together into a single operation > store into the result vector
-        for(std::size_t i = 0; i < tile_descriptions.size(); i++)
-        {
-            Operation_Xdl_CShuffle x;
-            x.tile_desc        = tile_descriptions[i];
-            x.a_block_transfer = a_block_descriptions[i];
-            x.b_block_transfer = b_block_descriptions[i];
-            x.cshuffle         = cshuffle_descriptions[i];
-            x.c_block_transfer = c_block_descriptions[i];
-            x.A  = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
-            x.B  = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
-            x.E  = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
-            x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
-                return TensorDesc{dt, ToLayout(trans)};
-            });
-            x.a_elem_op   = prob.AElementOp;
-            x.b_elem_op   = prob.BElementOp;
-            x.cde_elem_op = prob.CDEElementOp;
-            x.gemm_specialization = GetGemmSpec(prob.M,
-                                                prob.N,
-                                                prob.K,
-                                                x.tile_desc.m_per_block,
-                                                x.tile_desc.n_per_block,
-                                                x.tile_desc.k_per_block);
-            x.loop_scheduler   = loop_scheduler;
-            x.pipeline_version = pipeline_version;
-            x.update_prologue(prologue);
-            x.update_epilogue(epilogue);
-            result.push_back(x);
-        }
+        Operation_Xdl_CShuffle x;
+        x.tile_desc        = tile_descriptions[i];
+        x.a_block_transfer = a_block_descriptions[i];
+        x.b_block_transfer = b_block_descriptions[i];
+        x.cshuffle         = cshuffle_descriptions[i];
+        x.c_block_transfer = c_block_descriptions[i];
+        x.A  = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
+        x.B  = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
+        x.E  = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
+        x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
+            return TensorDesc{dt, ToLayout(trans)};
+        });
+        x.a_elem_op   = prob.AElementOp;
+        x.b_elem_op   = prob.BElementOp;
+        x.cde_elem_op = prob.CDEElementOp;
+        x.gemm_specialization = GetGemmSpec(prob.M,
+                                            prob.N,
+                                            prob.K,
+                                            x.tile_desc.m_per_block,
+                                            x.tile_desc.n_per_block,
+                                            x.tile_desc.k_per_block);
+        x.update_prologue(prologue);
+        x.update_epilogue(epilogue);
+        result.push_back(x);
     }
     return result;
 }

@@ -294,7 +263,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
     "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
     "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
     "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
-    "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
+    "${CDEBlockTransferScalarPerVector_NPerBlock}>";

 // use hardcoded instances from vector of operations to substitute values into instance template
 Solution Operation_Xdl_CShuffle::ToSolution() const

@@ -367,8 +336,6 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
         this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
        {"CDEBlockTransferScalarPerVector_NPerBlock",
         std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
-       {"LoopScheduler", ToString(this->loop_scheduler)},
-       {"PipelineVersion", ToString(this->pipeline_version)},
     };
     return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
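Before this commit, CreateOperations emitted three variants of every tile configuration (Default/v1, Interwave/v1, Default/v2), and ToSolution() spliced those two extra knobs into the instance template through the "${LoopScheduler}" and "${PipelineVersion}" placeholders removed above. A minimal sketch of that substitution step in isolation; the template tail and the ToString result strings come from the diffs on this page, while the small Interpolate helper and std::map are a stand-in for the library's own InterpolateString/values machinery, written here only for illustration:

    // illustration only: how the removed placeholders were filled in
    #include <iostream>
    #include <map>
    #include <string>

    // toy stand-in for the codegen's InterpolateString: replaces ${Key} with values.at(Key)
    static std::string Interpolate(std::string tmpl, const std::map<std::string, std::string>& values)
    {
        for(const auto& [key, value] : values)
        {
            const std::string token = "${" + key + "}";
            for(std::size_t pos = tmpl.find(token); pos != std::string::npos; pos = tmpl.find(token))
                tmpl.replace(pos, token.size(), value);
        }
        return tmpl;
    }

    int main()
    {
        // tail of the old DeviceGemmMultipleD_Xdl_CShuffleTemplate (before this commit)
        const std::string tail =
            "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
        const std::map<std::string, std::string> values = {
            {"CDEBlockTransferScalarPerVector_NPerBlock", "8"},
            {"LoopScheduler", "ck::LoopScheduler::Interwave"},
            {"PipelineVersion", "ck::PipelineVersion::v1"},
        };
        // prints: 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>
        std::cout << Interpolate(tail, values) << "\n";
    }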
codegen/src/types.cpp  (+0 -20)

@@ -56,26 +56,6 @@ std::string ToString(GemmType gt)
     throw std::runtime_error("Incorrect gemm type");
 }

-std::string ToString(LoopScheduler ls)
-{
-    switch(ls)
-    {
-    case LoopScheduler::Default: return "ck::LoopScheduler::Default";
-    case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
-    }
-    throw std::runtime_error("Incorrect LoopScheduler type");
-}
-
-std::string ToString(PipelineVersion pv)
-{
-    switch(pv)
-    {
-    case PipelineVersion::v1: return "ck::PipelineVersion::v1";
-    case PipelineVersion::v2: return "ck::PipelineVersion::v2";
-    }
-    throw std::runtime_error("Incorrect PipelineVersion type");
-}
-
 std::string SequenceStr(const std::vector<int>& v)
 {
     return "ck::Sequence<" +
codegen/test/gemm_multiple_d.cpp  (+0 -32)

 #include "common.hpp"
 #include "ck/host/device_gemm_multiple_d/problem.hpp"
 #include "ck/host/device_gemm_multiple_d/operation.hpp"
-#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
-#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
 #include "ck/host/headers.hpp"
 #include "ck/host/stringutils.hpp"
 #include "ck/host/utils.hpp"

@@ -87,34 +85,4 @@ TEST_CASE(test_problem_kernel)
     }
 }

-TEST_CASE(test_gemm_softmax_gemm)
-{
-    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
-    prob.TransA  = false;
-    prob.TransB  = true;
-    prob.TransB1 = false;
-    prob.TransC  = false;
-    prob.M       = 1024;
-    prob.N       = 1024;
-    prob.K       = 1024;
-    prob.O       = 1024;
-
-    check_all<half> check;
-    auto a  = to_gpu(generate_buffer<half>(1024 * 1024, 0));
-    auto b  = to_gpu(generate_buffer<half>(1024 * 1024, 1));
-    auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
-    auto c  = to_gpu(generate_buffer<half>(1024 * 1024, 3));
-
-    std::string epilogue = "";
-    std::string prologue = "";
-
-    auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
-    std::cout << "Num solutions: " << solutions.size() << std::endl;
-    for(auto i = 0; i < solutions.size(); ++i)
-    {
-        std::cout << "Solution " << i << std::endl;
-        std::cout << solutions[i].ToTemplateString() << std::endl;
-        std::cout << std::endl;
-    }
-}
-
 int main(int argc, const char* argv[]) { test::run(argc, argv); }