gaoqiong / composable_kernel_ROCM / Commits

Commit 909f519c (unverified)
Authored Jun 27, 2024 by Harisankar Sadasivan; committed by GitHub, Jun 27, 2024

    Merge branch 'develop' into universal_streamk

Parents: 406fa265, 3bb0fe6c
Changes: 82. Showing 20 changed files with 2087 additions and 73 deletions (+2087 -73).
.azuredevops/rocm-ci.yml                                                          +1   -14
CMakeLists.txt                                                                    +4   -2
Jenkinsfile                                                                       +1   -3
cmake/EnableCompilerWarnings.cmake                                                +1   -1
example/01_gemm/gemm_wmma_fp16.cpp                                                +27  -27
example/01_gemm/run_gemm_example.inc                                              +1   -1
example/04_gemm_add_add_fastgelu/CMakeLists.txt                                   +1   -1
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp  +2   -2
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp  +3   -3
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp   +3   -3
example/CMakeLists.txt                                                            +2   -2
example/ck_tile/01_fmha/CMakeLists.txt                                            +4   -4
example/ck_tile/01_fmha/codegen/__init__.py                                       +0   -0
example/ck_tile/01_fmha/codegen/cmake_config.py                                   +5   -0
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py                                 +92  -0
example/ck_tile/01_fmha/codegen/ops/__init__.py                                   +0   -0
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py                                   +611 -0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py                                   +498 -0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py                           +671 -0
example/ck_tile/01_fmha/fmha_fwd.cpp                                              +160 -10
.azuredevops/rocm-ci.yml

@@ -23,20 +23,7 @@ trigger:
   - Jenkinsfile
   - LICENSE
-pr:
-  autoCancel: true
-  branches:
-    include:
-    - develop
-  paths:
-    exclude:
-    - .github
-    - docs
-    - '.*.y*ml'
-    - '*.md'
-    - Jenkinsfile
-    - LICENSE
-  drafts: false
+pr: none

 jobs:
 - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
CMakeLists.txt

@@ -117,7 +117,7 @@ else()
     add_definitions(-DPROFILER_ONLY)
     set(GPU_TARGETS "" CACHE STRING "" FORCE)
     if(GPU_TARGETS)
-        message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
+        message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
     endif()
     if(GPU_ARCH MATCHES "gfx90")
         rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
@@ -127,8 +127,10 @@ else()
         rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
     elseif(GPU_ARCH MATCHES "gfx11")
         rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
+    elseif(GPU_ARCH MATCHES "gfx12")
+        rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
     else()
-        message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
+        message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
     endif()
     set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
 endif()
Jenkinsfile

@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
     def variant = env.STAGE_NAME
     def retimage
     gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         try {
             (retimage, image) = getDockerImage(conf)
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
 pipeline {
     agent none
-    triggers {
-        parameterizedCron(CRON_SETTINGS)
-    }
     options {
         parallelsAlwaysFailFast()
     }
cmake/EnableCompilerWarnings.cmake
example/01_gemm/gemm_wmma_fp16.cpp

@@ -40,7 +40,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         64,  // MPerBlock
         128, // NPerBlock
         64,  // KPerBlock
-        8,   // K1
+        2,   // K1
         16,  // MPerWmma
         16,  // NPerWmma
         2,   // M-Repeat // M-PerWmma / M-Repeat = M-Wave
@@ -49,15 +49,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        8,
-        8,
+        2,
+        2,
         true,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        8,
-        8,
+        2,
+        2,
         true,
         1, // C shuffle (M Repeat) Per store
         1, // C shuffle (N Repeat) Per store
example/01_gemm/run_gemm_example.inc

@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
     case 4:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
         break;
     case 5:
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp

@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
         2,
         4,
         4,
-        true,
+        false,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
         4,
         4,
-        true,
+        false,
         1,
         1,
         S<1, 64, 1, 2>,
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp

@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 #define CK_MHA_USE_WAVE_1
 #define CK_MHA_USE_WAVE_2
 #define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+// #define CK_MHA_USE_WAVE_8
 using DeviceMHAFactory =
     std::tuple<
 #ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
-        MaskingSpec>,
+        MaskingSpec>
 #endif
 #ifdef CK_MHA_USE_WAVE_8
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+    , ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
         AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
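The comma move in the second hunk is the subtle part: each #ifdef-guarded factory entry now carries its own leading comma, so commenting out CK_MHA_USE_WAVE_8 cannot leave a dangling trailing comma after the last enabled tuple element. A minimal standalone C++ sketch of the pattern (Wave4Kernel and Wave8Kernel are hypothetical stand-ins, not types from this commit):

#include <tuple>

struct Wave4Kernel {}; // stand-in for the WAVE_4 device-op instance
struct Wave8Kernel {}; // stand-in for the WAVE_8 device-op instance

// Each conditionally compiled element brings its own leading comma, so the
// list stays well-formed whether or not CK_MHA_USE_WAVE_8 is defined.
using DeviceMHAFactoryLike = std::tuple<
    Wave4Kernel
#ifdef CK_MHA_USE_WAVE_8
    , Wave8Kernel
#endif
    >;

int main()
{
    DeviceMHAFactoryLike factory{};
    (void)factory;
    return 0;
}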
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp

@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 #define CK_MHA_USE_WAVE_1
 #define CK_MHA_USE_WAVE_2
 #define CK_MHA_USE_WAVE_4
-#define CK_MHA_USE_WAVE_8
+// #define CK_MHA_USE_WAVE_8
 using DeviceMHAFactory =
     std::tuple<
 #ifdef CK_MHA_USE_WAVE_1
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
-        MaskingSpec>,
+        MaskingSpec>
 #endif
 #ifdef CK_MHA_USE_WAVE_8
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+    , ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
         AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
example/CMakeLists.txt

@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     endforeach()
     #Do not build any WMMA examples if gfx11 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
-        if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+        if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
             message("removing wmma example ${source}")
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     endforeach()
     #Do not build any WMMA examples if gfx11 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
-        if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+        if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
             message("removing wmma example ${source}")
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
example/ck_tile/01_fmha/CMakeLists.txt

 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
+  --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
 )
 # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory

@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 add_custom_command(
   OUTPUT ${FMHA_BWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
example/ck_tile/01_fmha/codegen/__init__.py (new file, mode 100644, empty)
example/ck_tile/01_fmha/codegen/cmake_config.py (new file, mode 100644)

# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

GEN_DIR = ""  # in Cmake, have to generate files in same folder
\ No newline at end of file
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py (new file, mode 100644)

# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

DTYPE_MAP = {
    "fp16": "ck_tile::fp16_t",
    "bf16": "ck_tile::bf16_t",
    "fp8":  "ck_tile::fp8_t",
}

MASK_IMPL = {
    "generic":    "ck_tile::GenericAttentionMask",
    "simplified": "ck_tile::SimplifiedGenericAttentionMask",
}

_MASK_SIMPLIFIED_MAP = {
    "s_no":   "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

_MASK_MAP = {
    "no":      "FmhaMasks::NoMask",
    "causal":  "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}

def get_mask_map(mask: str):
    if mask == "generic":
        return _MASK_MAP
    elif mask == "simplified":
        return _MASK_SIMPLIFIED_MAP
    else:
        assert False
        return None

_MASK_CHECK_MAP = {
    "no":      "t.mask_type == mask_enum::no_mask",
    "causal":  "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}

_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no":   "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}

def get_mask_check_map(mask: str):
    if mask == "generic":
        return _MASK_CHECK_MAP
    elif mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    else:
        assert False
        return None

BIAS_MAP = {
    "no":    "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
    "bias":  "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
    "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}

# TODO: this is ugly
BIAS_CHECK_MAP = {
    "no":    "bias_enum::no_bias",
    "bias":  "bias_enum::elementwise_bias",
    "alibi": "bias_enum::alibi",
}

MODE_MAP = {
    "batch": "false",
    "group": "true",
}

LAYOUT_MAP = {
    "row": "true",
    "col": "false",
}

PIPELINE_MAP = {
    "qr":       "ck_tile::BlockFmhaPipelineQRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}

PIPELINE_ENUM_MAP = {
    "qr":       "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

BOOL_MAP = {
    "t": "true",
    "f": "false",
}
\ No newline at end of file
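These tables map short option keys to C++ symbol strings that the generator pastes verbatim into emitted source. As a hedged illustration of where the strings end up (the enum and traits struct below are simplified stand-ins, not the example's real declarations; only the if-condition is _MASK_CHECK_MAP["causal"] verbatim), a generated dispatch branch could look like this C++ sketch:

#include <cstdio>
#include <string>

// Stand-in for the mask_enum used by the fmha example's runtime traits.
enum class mask_enum { no_mask, mask_top_left, mask_bottom_right, window_generic };

// Stand-in for the runtime trait struct the generated checks are written against.
struct fmha_fwd_traits_stub
{
    std::string data_type;
    mask_enum mask_type;
};

int main()
{
    fmha_fwd_traits_stub t{"fp16", mask_enum::mask_top_left};
    if(t.data_type == "fp16" && // key of DTYPE_MAP -> ck_tile::fp16_t
       (t.mask_type == mask_enum::mask_top_left ||
        t.mask_type == mask_enum::mask_bottom_right)) // _MASK_CHECK_MAP["causal"]
    {
        std::printf("dispatch kernel instantiated with FmhaMasks::CausalMask\n");
    }
    return 0;
}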
example/ck_tile/01_fmha/codegen/ops/__init__.py (new file, mode 100644, empty)

example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py (new file, mode 100644; diff collapsed)

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py (new file, mode 100644; diff collapsed)

example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py (new file, mode 100644; diff collapsed)
example/ck_tile/01_fmha/fmha_fwd.cpp

@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("num_splits",
+                "1",
+                "# of splits for key/value. 0 to determine actual number by heuristic")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
     }
 }

+int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
+{
+    // If we have enough to almost fill the SMs, then just use 1 split
+    if(batch_nhead_mblocks >= 0.8f * num_SMs)
+    {
+        return 1;
+    }
+    max_splits           = std::min({max_splits, num_SMs, num_n_blocks});
+    float max_efficiency = 0.f;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+
+    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+
+    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+    // (i.e. it's 11 splits anyway).
+    // So we check if the number of blocks per split is the same as the previous num_splits.
+    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+        return num_splits == 1 ||
+               ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+    };
+
+    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
+    {
+        if(!is_split_eligible(num_splits))
+        {
+            efficiency.push_back(0.f);
+        }
+        else
+        {
+            float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
+            float eff     = n_waves / ceil(n_waves);
+            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+            if(eff > max_efficiency)
+            {
+                max_efficiency = eff;
+            }
+            efficiency.push_back(eff);
+        }
+    }
+    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
+    {
+        if(!is_split_eligible(num_splits))
+        {
+            continue;
+        }
+        if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
+        {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
+int override_num_splits_if_necessary(
+    int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
+{
+    int device;
+    auto status = hipGetDevice(&device);
+    if(status != hipSuccess)
+    {
+        return num_splits;
+    }
+
+    hipDeviceProp_t props{};
+    status = hipGetDeviceProperties(&props, device);
+    if(status != hipSuccess)
+    {
+        return num_splits;
+    }
+
+    // tile size should match the generate.py
+    const int kM0 = 64;
+    const int kN1 = hdim_v;
+
+    const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
+    const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
+
+    if(num_splits < 1 && p_drop == 0.0f)
+    {
+        return num_splits_heuristic(
+            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
+    }
+
+    return num_splits;
+}
+
+float fmha_fwd_dispatch(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config)
+{
+    if(1 < args.num_splits)
+    {
+        return fmha_fwd_splitkv(traits, args, config);
+    }
+    else
+    {
+        return fmha_fwd(traits, args, config);
+    }
+}
+
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
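The split-eligibility rule above can be checked with the 64-block example from its own comment. A standalone C++ sketch (illustrative only, not part of this commit) that reproduces just the ceildiv check:

#include <cstdio>

int main()
{
    // same ceildiv as in num_splits_heuristic
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    const int num_n_blocks = 64;
    for(int num_splits = 10; num_splits <= 12; num_splits++)
    {
        // a split count is eligible only if it changes the blocks-per-split
        const bool eligible = num_splits == 1 ||
                              ceildiv(num_n_blocks, num_splits) !=
                                  ceildiv(num_n_blocks, num_splits - 1);
        std::printf("num_splits = %2d -> %d blocks per split, eligible = %s\n",
                    num_splits,
                    ceildiv(num_n_blocks, num_splits),
                    eligible ? "true" : "false");
    }
    // prints eligible = true, true, false: 12 splits is "11 splits anyway"
    return 0;
}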
@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         seed.reset();
     }

+    int num_splits    = arg_parser.get_int("num_splits");
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
     bool kname        = arg_parser.get_bool("kname");
@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }

+    // legalize num_splits according to other options
+    if(num_splits < 1)
+    {
+        num_splits = override_num_splits_if_necessary(
+            batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
+    }
+    if(128 < num_splits)
+    {
+        std::cerr << "num_splits greater than 128 is not supported" << std::endl;
+        return false;
+    }
+
     auto get_lengths = [&](bool permute,
                            ck_tile::index_t b /*batch*/,
                            ck_tile::index_t h /*nhead*/,
@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
             : std::array<ck_tile::index_t, 2>{batch, nhead})
             : std::array<ck_tile::index_t, 2>{1, 1});

+    // self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
+    ck_tile::HostTensor<LSEDataType> lse_acc_host(
+        1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
+                       : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
+    ck_tile::HostTensor<OaccDataType> o_acc_host(
+        1 < num_splits ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
+                       : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
+    // self define lse data layout as [batch, nhead, max_seqlen_q]
     ck_tile::HostTensor<LSEDataType> lse_host(
         lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
             : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
                      : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
               << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
               << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
-              << ", mask:" << mask << ", v:" << vlayout << std::flush;
+              << ", mask:" << mask << ", v:" << vlayout;
+    if(1 < num_splits)
+    {
+        std::cout << ", num_splits:" << num_splits;
+    }
+    std::cout << std::flush;

     auto fmha_traits = fmha_fwd_traits{hdim_q,
                                        hdim_v,
@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }();

     const ck_tile::index_t stride_bias    = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
     const ck_tile::index_t stride_randval = (max_seqlen_k);
+    const ck_tile::index_t stride_o_acc   = hdim_v;
     const ck_tile::index_t stride_o       = (o_perm ? hdim_v : nhead * hdim_v);

     // setup nhead_stride_* arguments
     const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
     const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
     const ck_tile::index_t nhead_stride_lse     = max_seqlen_q;
+    const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
+    const ck_tile::index_t nhead_stride_o_acc   = (max_seqlen_q * hdim_v);
     const ck_tile::index_t nhead_stride_o       = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);

     // setup batch_stride_* arguments
     const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const ck_tile::index_t batch_stride_bias    = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
     const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
     const ck_tile::index_t batch_stride_lse     = (nhead * max_seqlen_q);
+    const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
+    const ck_tile::index_t batch_stride_o_acc   = (nhead * max_seqlen_q * hdim_v);
     const ck_tile::index_t batch_stride_o       = (nhead * shape_seqlen_q * hdim_v);

+    // setup split_stride_* arguments (only used in split-kv kernel)
+    const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
+    const ck_tile::index_t split_stride_o_acc   = (batch * nhead * max_seqlen_q * hdim_v);
+
     return fmha_fwd_args{q_buf.GetDeviceBuffer(),
                          k_buf.GetDeviceBuffer(),
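The split_stride_* values added above linearize the split-kv accumulators, whose host layouts were declared earlier as [num_splits, batch, nhead, max_seqlen_q] for lse_acc and [num_splits, batch, nhead, max_seqlen_q, hdim_v] for o_acc. A standalone C++ sketch (illustrative only, not commit code) showing that consecutive splits of the same element sit exactly one whole [batch, nhead, max_seqlen_q] volume apart:

#include <cstdio>

int main()
{
    const int batch = 2, nhead = 3, max_seqlen_q = 128;
    // same definitions as the lse_acc strides above
    const int nhead_stride_lse_acc = max_seqlen_q;
    const int batch_stride_lse_acc = nhead * max_seqlen_q;
    const int split_stride_lse_acc = batch * nhead * max_seqlen_q;

    // linear offset of element (split, b, h, s) in lse_acc
    auto offset = [&](int split, int b, int h, int s) {
        return split * split_stride_lse_acc + b * batch_stride_lse_acc +
               h * nhead_stride_lse_acc + s;
    };

    std::printf("offset(1,0,0,0) - offset(0,0,0,0) = %d (expected %d)\n",
                offset(1, 0, 0, 0) - offset(0, 0, 0, 0),
                batch * nhead * max_seqlen_q);
    return 0;
}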
@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
                                                        : bias_buf.GetDeviceBuffer(),
                          randval_buf.GetDeviceBuffer(),
+                         lse_acc_buf.GetDeviceBuffer(),
+                         o_acc_buf.GetDeviceBuffer(),
                          lse_buf.GetDeviceBuffer(),
                          o_buf.GetDeviceBuffer(),
                          seqstart_q.GetDeviceBuffer(),
@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          hdim_v,
                          nhead,
                          nhead_k,
+                         num_splits,
                          scale_s,
                          scale_p,
                          scale_o,
@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
                                                        : stride_bias,
                          stride_randval,
+                         stride_o_acc,
                          stride_o,
                          nhead_stride_q,
                          nhead_stride_k,
@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          nhead_stride_bias,
                          nhead_stride_randval,
                          nhead_stride_lse,
+                         nhead_stride_lse_acc,
+                         nhead_stride_o_acc,
                          nhead_stride_o,
                          batch_stride_q,
                          batch_stride_k,
@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          batch_stride_bias,
                          batch_stride_randval,
                          batch_stride_lse,
+                         batch_stride_lse_acc,
+                         batch_stride_o_acc,
                          batch_stride_o,
+                         split_stride_lse_acc,
+                         split_stride_o_acc,
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          {drop_seed, drop_offset}};
     }();

-    float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
+    float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);

     if(ave_time < 0)
     {
@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
             lse_host_result.ForEach(
                 [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });

-            bool lse_pass = ck_tile::check_err(lse_host_result,
-                                               lse_host_ref,
-                                               "LSE Error: Incorrect results!",
-                                               rtol,
-                                               atol,
-                                               /* allow_infinity_ref = */ true);
+            cur_pass = ck_tile::check_err(lse_host_result,
+                                          lse_host_ref,
+                                          "LSE Error: Incorrect results!",
+                                          rtol,
+                                          atol,
+                                          /* allow_infinity_ref = */ true);

-            pass &= lse_pass;
+            pass &= cur_pass;
             if(!cur_pass)
             {
                 std::cerr << "LSE mismatch found at batch: " << wb << std::endl