Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
49c41176
Unverified
Commit
49c41176
authored
Oct 16, 2024
by
Harisankar Sadasivan
Committed by
GitHub
Oct 16, 2024
Browse files
Merge branch 'develop' into lwpck-2374_2
parents
ae0c3649
14c3cfb1
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
212 additions
and
194 deletions
+212
-194
Jenkinsfile
Jenkinsfile
+38
-11
codegen/CMakeLists.txt
codegen/CMakeLists.txt
+29
-31
codegen/test/CMakeLists.txt
codegen/test/CMakeLists.txt
+20
-18
codegen/test/include/common.hpp
codegen/test/include/common.hpp
+0
-0
codegen/test/rtc/CMakeLists.txt
codegen/test/rtc/CMakeLists.txt
+2
-0
codegen/test/rtc/include/rtc/compile_kernel.hpp
codegen/test/rtc/include/rtc/compile_kernel.hpp
+2
-2
codegen/test/rtc/include/rtc/filesystem.hpp
codegen/test/rtc/include/rtc/filesystem.hpp
+60
-0
codegen/test/rtc/include/rtc/tmp_dir.hpp
codegen/test/rtc/include/rtc/tmp_dir.hpp
+2
-2
codegen/test/rtc/src/compile_kernel.cpp
codegen/test/rtc/src/compile_kernel.cpp
+5
-5
codegen/test/rtc/src/tmp_dir.cpp
codegen/test/rtc/src/tmp_dir.cpp
+3
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+27
-92
No files found.
Jenkinsfile
View file @
49c41176
...
@@ -735,11 +735,11 @@ def process_results(Map conf=[:]){
...
@@ -735,11 +735,11 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS
=
BRANCH_NAME
==
"develop"
?
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
CRON_SETTINGS
=
BRANCH_NAME
==
"develop"
?
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
;RUN_CODEGEN_TESTS=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_
CODEGEN_TESTS=false;RUN_
PERFORMANCE_TESTS=false;USE_SCCACHE=false
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
0 13 * * * % BUILD_LEGACY_OS=true
'''
:
""
0 13 * * * % BUILD_LEGACY_OS=true'''
:
""
pipeline
{
pipeline
{
agent
none
agent
none
...
@@ -806,6 +806,10 @@ pipeline {
...
@@ -806,6 +806,10 @@ pipeline {
name:
"RUN_GROUPED_CONV_LARGE_CASES_TESTS"
,
name:
"RUN_GROUPED_CONV_LARGE_CASES_TESTS"
,
defaultValue:
false
,
defaultValue:
false
,
description:
"Run the grouped conv large cases tests (default: OFF)"
)
description:
"Run the grouped conv large cases tests (default: OFF)"
)
booleanParam
(
name:
"RUN_CODEGEN_TESTS"
,
defaultValue:
false
,
description:
"Run codegen tests (default: OFF)"
)
booleanParam
(
booleanParam
(
name:
"RUN_CK_TILE_FMHA_TESTS"
,
name:
"RUN_CK_TILE_FMHA_TESTS"
,
defaultValue:
false
,
defaultValue:
false
,
...
@@ -934,6 +938,29 @@ pipeline {
...
@@ -934,6 +938,29 @@ pipeline {
}
}
}
}
}
}
stage
(
"Run Codegen Tests"
)
{
parallel
{
stage
(
"Run Codegen Tests on gfx90a"
)
{
when
{
beforeAgent
true
expression
{
params
.
RUN_CODEGEN_TESTS
.
toBoolean
()
}
}
agent
{
label
rocmnode
(
"gfx90a"
)}
environment
{
setup_args
=
"NO_CK_BUILD"
execute_args
=
""" CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \
make -j64 check"""
}
steps
{
buildHipClangJobAndReboot
(
setup_args:
setup_args
,
no_reboot:
true
,
build_type:
'Release'
,
execute_cmd:
execute_args
)
cleanWs
()
}
}
}
}
stage
(
"Run CK_TILE_FMHA Tests"
)
stage
(
"Run CK_TILE_FMHA Tests"
)
{
{
parallel
parallel
...
...
codegen/CMakeLists.txt
View file @
49c41176
cmake_minimum_required
(
VERSION 3.16
)
project
(
composable_kernel_host
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/lib
)
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/lib
)
...
@@ -5,30 +8,24 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
...
@@ -5,30 +8,24 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/bin
)
set
(
CMAKE_RUNTIME_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/bin
)
set
(
CK_ROOT
${
CMAKE_CURRENT_SOURCE_DIR
}
/..
)
set
(
CK_ROOT
${
CMAKE_CURRENT_SOURCE_DIR
}
/..
)
add_compile_options
(
-std=c++17
)
find_package
(
ROCM
)
f
in
d_package
(
hip
)
in
clude
(
ROCMInstallTargets
)
add_custom_target
(
codegen
)
include
(
ROCMTest
)
# add include directories
rocm_setup_version
(
VERSION 1.0
)
include_directories
(
BEFORE
${
PROJECT_BINARY_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
HIP_INCLUDE_DIRS
}
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
include
(
Embed
)
include
(
Embed
)
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${
CK_ROOT
}
/include/ck/*.hpp
)
${
CK_ROOT
}
/include/ck/*.hpp
)
#printouts fot debug purposes
#
printouts fot debug purposes
#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
#
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
#
message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
CK_ROOT
}
/include
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
CK_ROOT
}
/include
)
file
(
GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp
)
add_compile_options
(
-std=c++17
)
##message(STATUS "SOURCE_FILES: ${SOURCES}"
)
file
(
GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp
)
# TODO: Use object library
# TODO: Use object library
add_library
(
ck_host STATIC
${
SOURCES
}
)
add_library
(
ck_host STATIC
${
SOURCES
}
)
target_link_libraries
(
ck_host PRIVATE ck_headers
)
target_link_libraries
(
ck_host PRIVATE ck_headers
)
...
@@ -37,24 +34,25 @@ set_target_properties(ck_host PROPERTIES
...
@@ -37,24 +34,25 @@ set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON
)
POSITION_INDEPENDENT_CODE ON
)
target_include_directories
(
ck_host PUBLIC
# target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
# $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
# )
)
add_executable
(
ck-template-driver driver/main.cpp
)
add_executable
(
ck-template-driver driver/main.cpp
)
target_link_libraries
(
ck-template-driver ck_host
)
target_link_libraries
(
ck-template-driver ck_host
)
rocm_install
(
rocm_install
_targets
(
TARGETS ck_host ck_headers
TARGETS ck_host ck_headers
EXPORT ck_hostTargets
EXPORT ck_host_targets
INCLUDE include
PRIVATE
)
)
rocm_
install
(
EXPORT ck_hostT
argets
rocm_
export_t
argets
(
FILE composable_kernel
ck_host
T
argets
.cmake
EXPORT
ck_host
_t
argets
NAMESPACE composable_kernel::
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
)
rocm_install
(
DIRECTORY include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
if
(
BUILD_TESTING
)
if
(
BUILD_TESTING
)
add_subdirectory
(
test
)
add_subdirectory
(
test
)
endif
()
endif
()
codegen/test/CMakeLists.txt
View file @
49c41176
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
add_subdirectory
(
rtc
)
add_subdirectory
(
rtc
)
file
(
GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp
)
file
(
GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp
)
# do not build the tests when we build the library for various targets
if
(
NOT GPU_ARCHS
)
# TODO: These tests need to be refactored to remove dependency on main ck
foreach
(
TEST_SRC
${
TEST_SRCS
}
)
# headers and device compilation.
set_source_files_properties
(
${
TEST_SRC
}
PROPERTIES LANGUAGE HIP
)
set
(
TESTS_REQUIRE_DEVICE_COMPILE
grouped_conv_fwd_multiple_d_v1
grouped_conv_fwd_multiple_d_v2
grouped_conv_fwd_multiple_d_v3
grouped_conv_fwd_multiple_d_v4
)
find_package
(
hip
)
foreach
(
TEST_SRC
${
TEST_SRCS
}
)
get_filename_component
(
BASE_NAME
${
TEST_SRC
}
NAME_WE
)
get_filename_component
(
BASE_NAME
${
TEST_SRC
}
NAME_WE
)
add_executable
(
codegen_test_
${
BASE_NAME
}
${
TEST_SRC
}
)
rocm_add_test_executable
(
codegen_test_
${
BASE_NAME
}
${
TEST_SRC
}
)
if
(
CK_USE_ALTERNATIVE_PYTHON
)
target_link_options
(
codegen_test_
${
BASE_NAME
}
PRIVATE -lstdc++fs
)
endif
()
add_dependencies
(
codegen codegen_test_
${
BASE_NAME
}
)
add_dependencies
(
tests codegen_test_
${
BASE_NAME
}
)
add_dependencies
(
check codegen_test_
${
BASE_NAME
}
)
add_test
(
NAME codegen_test_
${
BASE_NAME
}
COMMAND codegen_test_
${
BASE_NAME
}
)
message
(
"adding test codegen_test_
${
BASE_NAME
}
"
)
target_link_libraries
(
codegen_test_
${
BASE_NAME
}
ck_rtc ck_host
)
target_link_libraries
(
codegen_test_
${
BASE_NAME
}
ck_rtc ck_host
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/codegen/test/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC include
)
if
(
BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE
)
target_link_libraries
(
codegen_test_
${
BASE_NAME
}
hip::device
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/library/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/library/include
)
endf
oreach
()
end
i
f
()
end
i
f
()
endf
oreach
()
codegen/test/common.hpp
→
codegen/test/
include/
common.hpp
View file @
49c41176
File moved
codegen/test/rtc/CMakeLists.txt
View file @
49c41176
find_package
(
hip
)
file
(
GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp
)
file
(
GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp
)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
target_include_directories
(
ck_rtc PUBLIC include
)
target_include_directories
(
ck_rtc PUBLIC include
)
target_link_libraries
(
ck_rtc PUBLIC hip::host
)
target_link_libraries
(
ck_rtc PUBLIC hip::host
)
target_link_libraries
(
ck_rtc PUBLIC -lstdc++fs
)
codegen/test/rtc/include/rtc/compile_kernel.hpp
View file @
49c41176
...
@@ -2,14 +2,14 @@
...
@@ -2,14 +2,14 @@
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#include <rtc/kernel.hpp>
#include <rtc/kernel.hpp>
#include <c
k
/filesystem.hpp>
#include <
rt
c/filesystem.hpp>
#include <string>
#include <string>
namespace
rtc
{
namespace
rtc
{
struct
src_file
struct
src_file
{
{
CK
::
fs
::
path
path
;
fs
::
path
path
;
std
::
string_view
content
;
std
::
string_view
content
;
};
};
...
...
codegen/test/rtc/include/rtc/filesystem.hpp
0 → 100644
View file @
49c41176
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
#include <string>
#include <string_view>
// clang-format off
#if defined(CPPCHECK)
#define RTC_HAS_FILESYSTEM 1
#define RTC_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define RTC_HAS_FILESYSTEM 1
#define RTC_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 1
#else
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define RTC_HAS_FILESYSTEM 1
#else
#define RTC_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define RTC_HAS_FILESYSTEM_TS 1
#else
#define RTC_HAS_FILESYSTEM_TS 0
#endif
#else
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 0
#endif
// clang-format on
#if RTC_HAS_FILESYSTEM
#include <filesystem>
#elif RTC_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif
namespace
rtc
{
#if RTC_HAS_FILESYSTEM
namespace
fs
=
::
std
::
filesystem
;
#elif RTC_HAS_FILESYSTEM_TS
namespace
fs
=
::
std
::
experimental
::
filesystem
;
#endif
}
// namespace rtc
#endif // GUARD_RTC_FILESYSTEM_HPP_
codegen/test/rtc/include/rtc/tmp_dir.hpp
View file @
49c41176
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
#include <string>
#include <string>
#include <c
k
/filesystem.hpp>
#include <
rt
c/filesystem.hpp>
namespace
rtc
{
namespace
rtc
{
struct
tmp_dir
struct
tmp_dir
{
{
CK
::
fs
::
path
path
;
fs
::
path
path
;
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
void
execute
(
const
std
::
string
&
cmd
)
const
;
void
execute
(
const
std
::
string
&
cmd
)
const
;
...
...
codegen/test/rtc/src/compile_kernel.cpp
View file @
49c41176
#include
"
rtc/hip.hpp
"
#include
<
rtc/hip.hpp
>
#include <rtc/compile_kernel.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/tmp_dir.hpp>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <stdexcept>
...
@@ -70,9 +70,9 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
...
@@ -70,9 +70,9 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
for
(
const
auto
&
src
:
srcs
)
for
(
const
auto
&
src
:
srcs
)
{
{
CK
::
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
CK
::
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
path
parent_path
=
full_path
.
parent_path
();
CK
::
fs
::
create_directories
(
parent_path
);
fs
::
create_directories
(
parent_path
);
write_string
(
full_path
.
string
(),
src
.
content
);
write_string
(
full_path
.
string
(),
src
.
content
);
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
{
...
@@ -86,7 +86,7 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
...
@@ -86,7 +86,7 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
td
.
execute
(
compiler
()
+
options
.
flags
);
td
.
execute
(
compiler
()
+
options
.
flags
);
auto
out_path
=
td
.
path
/
out
;
auto
out_path
=
td
.
path
/
out
;
if
(
not
CK
::
fs
::
exists
(
out_path
))
if
(
not
fs
::
exists
(
out_path
))
throw
std
::
runtime_error
(
"Output file missing: "
+
out
);
throw
std
::
runtime_error
(
"Output file missing: "
+
out
);
auto
obj
=
read_buffer
(
out_path
.
string
());
auto
obj
=
read_buffer
(
out_path
.
string
());
...
...
codegen/test/rtc/src/tmp_dir.cpp
View file @
49c41176
...
@@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix)
...
@@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix)
}
}
tmp_dir
::
tmp_dir
(
const
std
::
string
&
prefix
)
tmp_dir
::
tmp_dir
(
const
std
::
string
&
prefix
)
:
path
(
CK
::
fs
::
temp_directory_path
()
/
:
path
(
fs
::
temp_directory_path
()
/
unique_string
(
prefix
.
empty
()
?
"ck-rtc"
:
"ck-rtc-"
+
prefix
))
unique_string
(
prefix
.
empty
()
?
"ck-rtc"
:
"ck-rtc-"
+
prefix
))
{
{
CK
::
fs
::
create_directories
(
this
->
path
);
fs
::
create_directories
(
this
->
path
);
}
}
void
tmp_dir
::
execute
(
const
std
::
string
&
cmd
)
const
void
tmp_dir
::
execute
(
const
std
::
string
&
cmd
)
const
...
@@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const
...
@@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const
std
::
system
(
s
.
c_str
());
std
::
system
(
s
.
c_str
());
}
}
tmp_dir
::~
tmp_dir
()
{
CK
::
fs
::
remove_all
(
this
->
path
);
}
tmp_dir
::~
tmp_dir
()
{
fs
::
remove_all
(
this
->
path
);
}
}
// namespace rtc
}
// namespace rtc
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
49c41176
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
//------------------------------------------------------------------
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds
();
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
// Q: HBM ->Reg ->LDS
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t
i_total_loops
=
0
;
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
49c41176
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
//------------------------------------------------------------------
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds
();
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
auto
q_dram_window
=
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t
i_total_loops
=
0
;
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
49c41176
...
@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
...
@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
...
@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
...
@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
...
@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
{
{
...
@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
{
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
{
...
@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
...
@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
// Hold full block data
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
...
@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK0
=
Problem
::
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK2
=
Problem
::
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
static
constexpr
index_t
WarpGemmM
=
...
@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
// Compute
static
constexpr
index_t
Gemm0MFMA
=
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
kM0
*
kN0
*
kK0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kM0
*
kN0
*
kK2
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
...
@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
k
VHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
k
K2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
// LDS Write
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment