gaoqiong / composable_kernel_ROCM

Commit d4ad52d6, authored Aug 23, 2023 by Jun Liu

Merge branch 'develop' into amd-develop

Parents: 39002e9e, c8a8385f
Changes: 151 files in total; this page shows 20 changed files with 522 additions and 118 deletions (+522 −118, page 1 of 8).
.pre-commit-config.yaml  +1 −1
CMakeLists.txt  +70 −43
Dockerfile  +1 −1
Jenkinsfile  +1 −1
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp  +6 −0
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp  +6 −0
client_example/20_splitk_gemm/CMakeLists.txt  +2 −0
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp  +225 −0
client_example/CMakeLists.txt  +37 −16
example/35_splitK_gemm/CMakeLists.txt  +3 −3
example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp  +6 −5
example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp  +6 −5
include/ck/ck.hpp  +2 −3
include/ck/config.h.in  +102 −0
include/ck/tensor_description/multi_index_transform.hpp  +9 −9
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp  +2 −2
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp  +12 −8
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp  +10 −7
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp  +12 −8
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp  +9 −6
.pre-commit-config.yaml

@@ -3,7 +3,7 @@ repos:
     hooks:
     -   id: clang-format
         name: clang-format
-        entry: clang-format-10 -i --style=file
+        entry: clang-format-12 -i --style=file
         language: system
         types_or: [c++, inc]
     -   id: copyright-year-checker
CMakeLists.txt

 cmake_minimum_required(VERSION 3.14)
+set(version 1.1.0)
 # Check support for CUDA/HIP in Cmake
-project(composable_kernel)
+project(composable_kernel VERSION ${version})
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 if(DTYPES)
     add_definitions(-DDTYPES)
     if(DTYPES MATCHES "int8")
-        add_definitions(-D__int8__)
+        add_definitions(-DCK_ENABLE_INT8)
+        set(CK_ENABLE_INT8 "ON")
     endif()
     if(DTYPES MATCHES "fp8")
-        add_definitions(-D__fp8__)
+        add_definitions(-DCK_ENABLE_FP8)
+        set(CK_ENABLE_FP8 "ON")
     endif()
     if(DTYPES MATCHES "fp16")
-        add_definitions(-D__fp16__)
+        add_definitions(-DCK_ENABLE_FP16)
+        set(CK_ENABLE_FP16 "ON")
     endif()
     if(DTYPES MATCHES "fp32")
-        add_definitions(-D__fp32__)
+        add_definitions(-DCK_ENABLE_FP32)
+        set(CK_ENABLE_FP32 "ON")
     endif()
     if(DTYPES MATCHES "fp64")
-        add_definitions(-D__fp64__)
+        add_definitions(-DCK_ENABLE_FP64)
+        set(CK_ENABLE_FP64 "ON")
     endif()
     if(DTYPES MATCHES "bf16")
-        add_definitions(-D__bf16__)
+        add_definitions(-DCK_ENABLE_BF16)
+        set(CK_ENABLE_BF16 "ON")
     endif()
     message("DTYPES macro set to ${DTYPES}")
 else()
-    add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
+    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
+    set(CK_ENABLE_ALL_DTYPES "ON")
 endif()
 if(DL_KERNELS)
     add_definitions(-DDL_KERNELS)
+    set(CK_ENABLE_DL_KERNELS "ON")
 endif()
 if(INSTANCES_ONLY)
     add_definitions(-DINSTANCES_ONLY)
+    set(CK_ENABLE_INSTANCES_ONLY "ON")
 endif()
+# CK config file to record supported datatypes, etc.
+configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h")
 # CK version file to record release version as well as git commit hash
 find_package(Git REQUIRED)
 execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD
                 OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE)
 configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h")
 enable_testing()
 set(ROCM_SYMLINK_LIBS OFF)
@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks)
 include(ROCMCreatePackage)
 include(CheckCXXCompilerFlag)
 include(ROCMCheckTargetIds)
-rocm_setup_version(VERSION 0.2.1)
 include(TargetFlags)
+rocm_setup_version(VERSION ${version})
 list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
 message("GPU_TARGETS=${GPU_TARGETS}")
@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
 # set CK project include directories
 include_directories(BEFORE
+    ${PROJECT_BINARY_DIR}/include
     ${PROJECT_SOURCE_DIR}/include
     ${PROJECT_SOURCE_DIR}/library/include
     ${HIP_INCLUDE_DIRS}
 )
 SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
 if(BUILD_DEV)
     add_compile_options(-Werror)
@@ -341,35 +362,35 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
        file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
        set(add_inst 0)
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\"" AND DTYPES MATCHES "fp8")
            #message("fp8 instance found!")
            set(add_inst 1)
        endif()
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
            #message("fp16 instance found!")
            set(add_inst 1)
        endif()
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
            #message("fp32 instance found!")
            set(add_inst 1)
        endif()
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
            #message("fp64 instance found!")
            set(add_inst 1)
        endif()
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
            #message("bf16 instance found!")
            set(add_inst 1)
        endif()
        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
            #message("int8 instance found!")
            set(add_inst 1)
        endif()
        if(NOT "${cmake_instance}" MATCHES "DTYPES")
            #message("instance should be built for all types!")
            set(add_inst 1)
        endif()
        if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
            list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
        endif()
    ENDIF()
ENDFOREACH()
@@ -409,7 +430,6 @@ endif()
 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)
-set(version 1.0.0)
 write_basic_package_version_file(
     "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
     VERSION "${version}"
@@ -417,9 +437,9 @@ write_basic_package_version_file(
 )
 configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
     "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
     INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
     NO_CHECK_REQUIRED_COMPONENTS_MACRO
 )
 rocm_install(FILES

@@ -428,6 +448,13 @@ rocm_install(FILES
     DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
 )
+# Install CK version and configuration files
+install(FILES
+    ${PROJECT_BINARY_DIR}/include/ck/version.h
+    ${PROJECT_BINARY_DIR}/include/ck/config.h
+    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/)
 set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
 set(CPACK_RPM_PACKAGE_LICENSE "MIT")
Dockerfile

@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     nano \
     zlib1g-dev \
     openssh-server \
-    clang-format-10 \
+    clang-format-12 \
     kmod && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
Jenkinsfile

@@ -689,7 +689,7 @@ pipeline {
                        -o -iname \'*.cpp.in\' \
                        -o -iname \'*.cl\' \
                        | grep -v 'build/' \
-                       | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'"
+                       | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
            }
            steps {
                buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot: true)
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp

@@ -191,6 +191,12 @@ int main(int argc, char* argv[])
         if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+
+            SimpleDeviceMem workspace(workspace_sz);
+
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
             invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
         }
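The six added lines in each batchnorm client follow one sequence: query how much scratch memory the op needs, allocate it on the device, and attach it to the argument before running. A minimal sketch of that sequence under stated assumptions (the template below mirrors the shape of CK's op_ptr/argument_ptr/invoker_ptr objects; DeviceMem stands in for the SimpleDeviceMem hipMalloc/hipFree wrapper used by these examples and is not CK API):

```cpp
#include <cstddef>

// Sketch: query -> allocate -> attach -> run, as in the diff above.
template <typename DeviceMem, typename Op, typename Arg, typename Invoker, typename Stream>
void run_with_workspace(Op& op, Arg& arg, Invoker& invoker, Stream stream_config)
{
    if(!op.IsSupportedArgument(&arg))
        return;

    // Ops that need no scratch memory report 0 here.
    std::size_t workspace_sz = op.GetWorkSpaceSize(&arg);

    DeviceMem workspace(workspace_sz); // device-side allocation (RAII)

    // Without this call the kernel would see a null workspace pointer.
    op.SetWorkSpacePointer(&arg, workspace.GetDeviceBuffer());

    invoker.Run(&arg, stream_config);
}
```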
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp

@@ -187,6 +187,12 @@ int main(int argc, char* argv[])
         if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+
+            SimpleDeviceMem workspace(workspace_sz);
+
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
             invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
         }
client_example/20_splitk_gemm/CMakeLists.txt (new file)

add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <vector>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"

using F8  = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

using ADataType = F8;
using BDataType = F16;
using CDataType = F16;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main(int argc, char* argv[])
{
    // GEMM shape
    ck::index_t M = 3840;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    ck::index_t StrideA = 4096;
    ck::index_t StrideB = 4096;
    ck::index_t StrideC = 4096;
    ck::index_t KBatch  = 1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 8)
    {
        M = std::stoi(argv[1]);
        N = std::stoi(argv[2]);
        K = std::stoi(argv[3]);

        StrideA = std::stoi(argv[4]);
        StrideB = std::stoi(argv[5]);
        StrideC = std::stoi(argv[6]);
        KBatch  = std::stoi(argv[7]);
    }
    else
    {
        printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
        exit(0);
    }

    auto f_matrix_space_size =
        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
            using Layout = decltype(layout);

            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return (nRow - 1) * stride + nCol;
            }
            else
            {
                return (nCol - 1) * stride + nRow;
            }
        };

    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
    SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));

    using DeviceOp =
        ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
                                                       BLayout,
                                                       CLayout,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       ck::tensor_operation::element_wise::PassThrough,
                                                       ck::tensor_operation::element_wise::PassThrough,
                                                       ck::tensor_operation::element_wise::PassThrough>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
                                                        b_device_buf.GetDeviceBuffer(),
                                                        c_device_buf.GetDeviceBuffer(),
                                                        M,
                                                        N,
                                                        K,
                                                        StrideA,
                                                        StrideB,
                                                        StrideC,
                                                        a_element_op,
                                                        b_element_op,
                                                        c_element_op,
                                                        KBatch);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = std::size_t(2) * M * N * K;

            std::size_t num_btype =
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
                      << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];

        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
                                                        b_device_buf.GetDeviceBuffer(),
                                                        c_device_buf.GetDeviceBuffer(),
                                                        M,
                                                        N,
                                                        K,
                                                        StrideA,
                                                        StrideB,
                                                        StrideC,
                                                        a_element_op,
                                                        b_element_op,
                                                        c_element_op,
                                                        KBatch);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
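For readers new to split-K: the KBatch argument splits the reduction (K) dimension into KBatch independent partial GEMMs whose results are summed, which can help occupancy when M and N alone do not produce enough workgroups. A plain-CPU sketch of the semantics only, not CK's kernel implementation (layouts and stride convention match the example above; splitk_gemm_reference is an illustrative name):

```cpp
#include <algorithm>
#include <vector>

// Reference split-K GEMM: A is M x K row-major, B is K x N column-major,
// C is M x N row-major, all with packed strides for simplicity.
void splitk_gemm_reference(const std::vector<float>& A,
                           const std::vector<float>& B,
                           std::vector<float>& C,
                           int M, int N, int K, int KBatch)
{
    const int KPerBatch = (K + KBatch - 1) / KBatch;

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;

            for(int kb = 0; kb < KBatch; ++kb) // each chunk is independent work
            {
                const int k_begin = kb * KPerBatch;
                const int k_end   = std::min(K, k_begin + KPerBatch);

                float partial = 0.f;
                for(int k = k_begin; k < k_end; ++k)
                    partial += A[m * K + k] * B[n * K + k];

                acc += partial; // on the GPU this is the cross-workgroup reduction
            }

            C[m * N + n] = acc;
        }
}

int main()
{
    const int M = 2, N = 3, K = 8, KBatch = 2;
    std::vector<float> A(M * K, 1.f), B(K * N, 1.f), C(M * N, 0.f);

    splitk_gemm_reference(A, B, C, M, N, K, KBatch);

    // Every C entry equals K regardless of KBatch; only the work partition changes.
    return C[0] == static_cast<float>(K) ? 0 : 1;
}
```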
client_example/CMakeLists.txt

@@ -3,31 +3,52 @@ project(ck_app)
 add_compile_options(-std=c++17)
 if(DTYPES)
     add_definitions(-DDTYPES)
     if(DTYPES MATCHES "int8")
-        add_definitions(-D__int8__)
+        add_definitions(-DCK_ENABLE_INT8)
+        if(NOT DEFINED ${CK_ENABLE_INT8})
+            set(CK_ENABLE_INT8 "ON")
+        endif()
     endif()
     if(DTYPES MATCHES "fp8")
-        add_definitions(-D__fp8__)
+        add_definitions(-DCK_ENABLE_FP8)
+        if(NOT DEFINED ${CK_ENABLE_FP8})
+            set(CK_ENABLE_FP8 "ON")
+        endif()
     endif()
     if(DTYPES MATCHES "fp16")
-        add_definitions(-D__fp16__)
+        add_definitions(-DCK_ENABLE_FP16)
+        if(NOT DEFINED ${CK_ENABLE_FP16})
+            set(CK_ENABLE_FP16 "ON")
+        endif()
     endif()
     if(DTYPES MATCHES "fp32")
-        add_definitions(-D__fp32__)
+        add_definitions(-DCK_ENABLE_FP32)
+        if(NOT DEFINED ${CK_ENABLE_FP32})
+            set(CK_ENABLE_FP32 "ON")
+        endif()
     endif()
     if(DTYPES MATCHES "fp64")
-        add_definitions(-D__fp64__)
+        add_definitions(-DCK_ENABLE_FP64)
+        if(NOT DEFINED ${CK_ENABLE_FP64})
+            set(CK_ENABLE_FP64 "ON")
+        endif()
     endif()
     if(DTYPES MATCHES "bf16")
-        add_definitions(-D__bf16__)
+        add_definitions(-DCK_ENABLE_BF16)
+        if(NOT DEFINED ${CK_ENABLE_BF16})
+            set(CK_ENABLE_BF16 "ON")
+        endif()
     endif()
     message("DTYPES macro set to ${DTYPES}")
 else()
-    add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
+    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
+    if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
+        set(CK_ENABLE_ALL_DTYPES "ON")
+    endif()
 endif()
-find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
+find_package(composable_kernel COMPONENTS device_operations)
 find_package(hip REQUIRED PATHS /opt/rocm)
 message(STATUS "Build with HIP ${hip_VERSION}")
example/35_splitK_gemm/CMakeLists.txt

@@ -3,15 +3,15 @@ set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
         add_custom_target(example_splitK_gemm_xdl)
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+        if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
             add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
             add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
         endif()
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
             add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
             add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
         endif()
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
             add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
             add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16)
         endif()
example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp

@@ -33,6 +33,7 @@ using ADataType = BF16;
 using BDataType   = BF16;
 using AccDataType = F32;
 using CDataType   = F32;
+using ComputeType = BF16;

 using ALayout = Row;
 using BLayout = Col;
@@ -46,11 +47,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
 // clang-format off
 //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
+        <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
 // clang-format on

 #include "run_splitK_gemm_example.inc"
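The new trailing ComputeType template parameter names the type in which A and B values are presented to the math pipeline, kept separate from AccDataType, which holds the running sum. A scalar sketch of that separation under stated assumptions (an illustration of the concept only, not CK's kernel code; mac_step is a hypothetical name):

```cpp
// One multiply-accumulate step: inputs are converted to ComputeType before
// the multiply, while the running sum stays in the (typically wider) AccType.
template <typename AType, typename BType, typename AccType, typename ComputeType>
AccType mac_step(AType a, BType b, AccType acc)
{
    const ComputeType a_c = static_cast<ComputeType>(a);
    const ComputeType b_c = static_cast<ComputeType>(b);
    return acc + static_cast<AccType>(a_c * b_c);
}

int main()
{
    // With all types float the step is an ordinary fused multiply-add: 0 + 2*3.
    float acc = mac_step<float, float, float, float>(2.f, 3.f, 0.f);
    return acc == 6.f ? 0 : 1;
}
```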
example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp

@@ -30,6 +30,7 @@ using ADataType = int8_t;
 using BDataType   = int8_t;
 using AccDataType = int32_t;
 using CDataType   = int32_t;
+using ComputeType = int8_t;

 using ALayout = Row;
 using BLayout = Col;
@@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
 // clang-format off
 //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>;
+        <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
 // clang-format on

 #include "run_splitK_gemm_example.inc"
include/ck/ck.hpp

@@ -3,6 +3,8 @@
 #pragma once

+#include "ck/config.h"
+
 #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
@@ -200,9 +202,6 @@
 // workaround: compiler issue on gfx908
 #define CK_WORKAROUND_SWDEV_388832 1
-// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
-#define CK_WORKAROUND_GITHUB_ISSUE_824 1
-
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
include/ck/config.h.in (new file)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_CONFIG_H_IN
#define CK_CONFIG_H_IN
// clang-format off
//
// DataType supports in the current CK build
//
#ifndef DTYPES
#cmakedefine DTYPES "@DTYPES@"
#endif
// if DTYPES is not defined, enable all datatypes in headerfiles
#ifndef CK_ENABLE_ALL_DTYPES
#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@
#if defined(CK_ENABLE_ALL_DTYPES)
#ifndef CK_ENABLE_INT8
#define CK_ENABLE_INT8 "ON"
#endif
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
#ifndef CK_ENABLE_BF16
#define CK_ENABLE_BF16 "ON"
#endif
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
#endif
#endif
// if DTYPES are selectively enabled
#ifndef CK_ENABLE_INT8
#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@
#endif
#ifndef CK_ENABLE_FP8
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
#ifndef CK_ENABLE_BF16
#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@
#endif
#ifndef CK_ENABLE_FP32
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif
//
// Legacy DL kernel supports in the current CK build
// by default DL kernels are turned OFF
//
#ifndef CK_ENABLE_DL_KERNELS
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
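This template is filled in by the configure_file() call added to CMakeLists.txt above, so code that includes the installed headers can tell which data types the CK build enabled. A hypothetical downstream check, assuming an installed CK build that provides the generated ck/config.h (the program itself is illustrative, not CK API):

```cpp
#include "ck/config.h"

#include <iostream>

int main()
{
#ifdef CK_ENABLE_FP16
    std::cout << "this CK build was configured with fp16 kernels\n";
#else
    std::cout << "fp16 kernels were not built into this CK installation\n";
#endif
    return 0;
}
```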
include/ck/tensor_description/multi_index_transform.hpp

@@ -1042,13 +1042,13 @@ struct Merge_v2_magic_division
    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
        Number<NDimLow>{}));

    using LowLengthsMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;

@@ -1201,9 +1201,9 @@ struct Merge_v2r2_magic_division
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
        Number<NDimLow>{}));

    using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp

@@ -35,8 +35,8 @@ struct BlockwiseSoftmax
    static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
    static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));

    using ThreadwiseMaxReduce = typename conditional<
        IgnoreNaN,
        ...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp

@@ -588,14 +588,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
        LoopSched>;

    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

    // block-to-e-tile map
    using Block2ETileMap = ...
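The descriptor aliases being reflowed in these device implementations all follow one idiom: derive the type from a Make* factory call via decltype, then strip cv/reference qualifiers so the alias is a plain value type. A self-contained sketch of the idiom (CK supplies its own ck::remove_cvref_t; the Desc/MakeDesc names below are stand-ins, not CK API):

```cpp
#include <type_traits>

// C++17 spelling of what ck::remove_cvref_t provides (std::remove_cvref_t is C++20).
template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

struct Desc
{
    int length;
};

Desc MakeDesc() { return Desc{42}; } // stand-in for e.g. MakeDefaultAGridDescriptor_AK0_M_AK1

// decltype(MakeDesc()) is already a plain prvalue type here; remove_cvref_t keeps the
// idiom safe when a factory returns a const- or reference-qualified type instead.
using GridDesc = remove_cvref_t<decltype(MakeDesc())>;

static_assert(std::is_same_v<GridDesc, Desc>, "alias is the plain value type");
```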
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp

@@ -378,13 +378,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));

    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;

    // Argument
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp

@@ -368,14 +368,18 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
        LoopSched>;

    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

    // block-to-e-tile map
    using Block2ETileMap = ...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp

@@ -510,12 +510,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
        CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
    using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
    using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;

    // Argument
    struct Argument : public BaseArgument