gaoqiong / composable_kernel_ROCM · Commits

Commit c8c016dd, authored Dec 13, 2024 by aska-0096

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

Parents: e8ca3daf, 4e731776
Changes: 399 files in this commit; showing 20 changed files with 373 additions and 261 deletions (+373, -261).
Changed files shown:

example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc (+2, -2)
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc (+2, -2)
example/35_splitK_gemm/run_splitK_gemm_example.inc (+5, -2)
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp (+1, -1)
example/38_grouped_conv_bwd_data_multiple_d/common.hpp (+2, -2)
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp (+2, -2)
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp (+4, -4)
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp (+3, -3)
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp (+5, -4)
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp (+0, -3)
example/CMakeLists.txt (+7, -0)
example/README.md (+2, -0)
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py (+11, -4)
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py (+7, -7)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py (+7, -4)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py (+6, -3)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py (+42, -33)
example/ck_tile/01_fmha/fmha_bwd.cpp (+7, -7)
example/ck_tile/01_fmha/fmha_bwd.hpp (+114, -106)
example/ck_tile/01_fmha/fmha_fwd.cpp (+144, -72)
example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
...
@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
...
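A pattern that repeats throughout these example diffs: GeneratorTensor_Sequential now takes the tensor's element type as an explicit first template parameter (GeneratorTensor_Sequential<ADataType, 2> instead of GeneratorTensor_Sequential<2>), so the ramp of values is produced directly in the destination type. A minimal sketch of that shape, with an assumed functor body for illustration only, not the CK implementation:

#include <cstddef>

// Hypothetical stand-in: T is the tensor element type, Dim selects which
// index drives the sequential ramp. Requires at least one index argument.
template <typename T, std::size_t Dim>
struct GeneratorTensor_Sequential
{
    template <typename... Is>
    T operator()(Is... is) const
    {
        const std::size_t idx[] = {static_cast<std::size_t>(is)...};
        return static_cast<T>(idx[Dim]); // element (g, m, k) -> T(k) when Dim == 2
    }
};

Under this sketch, GeneratorTensor_Sequential<float, 2>{}(0, 3, 7) yields 7.0f.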
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 int run(int argc, char* argv[])
 {
...
@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
...
example/35_splitK_gemm/run_splitK_gemm_example.inc

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 struct ProblemSize final
...
@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
...
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp

...
@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
         break;
     default:
         a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
         d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
         d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
...
example/38_grouped_conv_bwd_data_multiple_d/common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -41,7 +41,7 @@ struct ExecutionConfig final
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = true;
+    bool time_kernel     = false;
 };

 #define DefaultConvParams \
...
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <vector>
...
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
         break;
     default:
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
...
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <numeric>
...
@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
         break;
     default:
-        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+        b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
+        b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
     }

     d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
...
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

...
@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
         break;
     default:
-        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
+        a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
+        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
     }

     d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
...
example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp

...
@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
     {
     case 0: break;
     case 1:
-        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
-        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+        // values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
+        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 6});
+        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
         break;
     default:
-        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
-        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
+        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
     }

     DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
...
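The {-5, 5} to {-5, 6} bump here, read together with the comment added in the diff ("values generated: -5 ... 5"), indicates that GeneratorTensor_2{min, max} draws integers from the half-open range [min, max), so the old upper bound silently excluded 5. A minimal sketch of that contract, assuming a uniform integer draw; the real CK functor may differ in details:

#include <random>

// Hypothetical stand-in: uniform integers in [min_, max_), cast to the
// tensor element type T. With {-5, 6} this yields -5..5 inclusive.
template <typename T>
struct GeneratorTensor_2
{
    int min_;
    int max_; // exclusive upper bound

    template <typename... Is>
    T operator()(Is...) const
    {
        static std::mt19937 gen{0};
        std::uniform_int_distribution<int> dist(min_, max_ - 1);
        return static_cast<T>(dist(gen));
    }
};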
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp

...
@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
     a1_device_buf.ToDevice(a1_m_k.mData.data());
     b0_device_buf.ToDevice(b0_k_n.mData.data());
     b1_device_buf.ToDevice(b1_k_n.mData.data());
-    e_device_buf.ToDevice(e_m_n_device_result.mData.data());

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
...
@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;

-    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
-
     if(do_verification)
     {
         Tensor<AccDataType> c_m_n({M, N});
...
example/CMakeLists.txt

...
@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
     endforeach()
+    #Do not build any DPP examples if DL_KERNELS not set
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
+            message("removing dpp example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
     #Do not build any XDL examples if gfx9 targets are not on the list
     foreach(source IN LISTS FILE_NAME)
         if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
...
example/README.md (new file, mode 100644)

+[Back to the main page](../README.md)
+# Composable Kernel examples
\ No newline at end of file
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

...
@@ -2,10 +2,17 @@
 # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 # generate kernel instances to speed up compilation

-DTYPE_MAP = {
-    "fp16": "ck_tile::fp16_t",
-    "bf16": "ck_tile::bf16_t",
-    "fp8" : "ck_tile::fp8_t"
+FWD_DTYPE_MAP = {
+    "fp16"    : "FmhaFwdFp16",
+    "bf16"    : "FmhaFwdBf16",
+    "fp8"     : "FmhaFwdFp8",
+    "fp8fp16" : "FmhaFwdFp8Fp16",
+    "fp8bf16" : "FmhaFwdFp8Bf16"
+}
+
+BWD_DTYPE_MAP = {
+    "fp16": "FmhaBwdFp16",
+    "bf16": "FmhaBwdBf16"
 }

 MASK_IMPL = {
...
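The single DTYPE_MAP is split into FWD_DTYPE_MAP and BWD_DTYPE_MAP, and the values change from raw ck_tile element types to named tag types, with two new mixed-precision forward entries (fp8fp16, fp8bf16). Downstream, the generated code resolves element types through the tag, as the splitkv templates later in this diff do. An illustrative fragment of what the generator would emit under the new maps; the shape is inferred from the templates in this commit, not copied from it:

// Generated symbol now names a tag type rather than an element type:
using fmha_dtype_0 = FmhaFwdFp16; // previously: ck_tile::fp16_t

// Element types are then pulled out of the per-tag config:
using ODataType_0 = typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType;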
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py

...
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
                 inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
                                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
                                F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
-                               F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
+                               F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
                                F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                F_deterministic=BOOL_MAP[trait.deterministic])
...
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
         FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=BWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_tile.F_bm0,
             F_bn0=self.F_tile.F_bn0,
             F_bk0=self.F_tile.F_bk0,
...
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     gen = list()
     api_pool = FmhaBwdApiPool(mask_impl)

-    for dtype in DTYPE_MAP.keys():
+    for dtype in BWD_DTYPE_MAP.keys():
         d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
         if d == None:
             continue
...
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
         FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=BWD_DTYPE_MAP[self.F_dtype],
             F_spad=BOOL_MAP[self.F_spad],
             F_dvpad=BOOL_MAP[self.F_dvpad],
             F_mode=MODE_MAP[self.F_mode],
...
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
     gen = list()

-    for dtype in DTYPE_MAP.keys():
+    for dtype in BWD_DTYPE_MAP.keys():
         d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
         if d == None:
             continue
...
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
         FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=BWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_bm0,
             F_bn0=self.F_bn0,
             F_spad=BOOL_MAP[self.F_spad],
...
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
     gen = list()

-    for dtype in DTYPE_MAP.keys():
+    for dtype in BWD_DTYPE_MAP.keys():
         d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
         if d == None:
             continue
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

...
@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
                                F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                 if_j = 'if' if j == 0 else 'else if'
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
...
@@ -301,7 +301,7 @@ class FmhaFwdTileSize:
     F_bk1    : int  # tile size along kv gemm unroll
     F_bk0max : int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
     F_rm0    : int  # number of warps for gemm0 along q seqlen
     F_rn0    : int  # number of warps for gemm0 along k seqlen
     F_rk0    : int  # number of warps for gemm0 along head dim q (not used)
     F_rm1    : int  # number of warps for gemm1 along q seqlen
     F_rn1    : int  # number of warps for gemm1 along head dim v
...
@@ -339,7 +339,7 @@ class FmhaFwdKernel:
         FMHA_FWD_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_tile.F_bm0,
             F_bn0=self.F_tile.F_bn0,
             F_bk0=self.F_tile.F_bk0,
...
@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
             # no need lse/dropout kernels
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                 pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+        elif dtype in ['fp8fp16', 'fp8bf16']:
+            # TODO
+            None
         else:
             assert False
         return pipelines
...
@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
     gen = list()
     api_pool = FmhaFwdApiPool(mask_impl)

-    for dtype in DTYPE_MAP.keys():
+    for dtype in FWD_DTYPE_MAP.keys():
         d = get_fmha_fwd_tile_dict_from_dtype(dtype)
         if d == None:
             continue
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py

...
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
                 inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
                                F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
                                F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                               F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+                               F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                 if_j = 'if' if j == 0 else 'else if'
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
...
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
         FMHA_FWD_APPENDKV_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
             F_bs=self.F_tile.F_bs,
             F_bsk=self.F_tile.F_bsk,
             F_bd=self.F_tile.F_bd,
...
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         elif dtype in ['fp8', 'bf8']:
             # rope/paged-kv is not supported
             pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
+        elif dtype in ['fp8fp16', 'fp8bf16']:
+            # TODO
+            None
         else:
             assert False
         return pipelines
...
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     gen = list()
     api_pool = FmhaFwdAppendKVApiPool(mask_impl)

-    for dtype in DTYPE_MAP.keys():
+    for dtype in FWD_DTYPE_MAP.keys():
         d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
         if d == None:
             continue
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

...
@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 }}

 using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                                                {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
                                                {F_dvpad}>;

 #include <iostream>
...
@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
     {F_hdim},
     {F_bm0},
     {F_bn1},
     {F_mode},
     fmha_trait>;
...
@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
     if(s.log_level_ > 0)
         std::cout
             << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
             << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
             << std::flush;
     return ck_tile::launch_kernel(s,
         [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
         [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
     );
 }}
...
@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
 }}
 """

-FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) &&
-                    (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                     ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
-        using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
-        return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+        if (t.has_lse) {{
+            if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
+                return -1;
+            }} else {{
+                using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
+                return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+            }}
+        }} else {{
+            using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
+            return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
+        }}
     }}
 """
...
@@ -421,11 +431,11 @@ class FmhaFwdSplitKVApiPool:
                 inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                                F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                                F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
+                               F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                 if_j = 'if' if j == 0 else 'else if'
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
...
@@ -462,7 +472,7 @@ class FmhaFwdSplitKVKernel:
         FMHA_FWD_SPLITKV_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_tile.F_bm0,
             F_bn0=self.F_tile.F_bn0,
             F_bk0=self.F_tile.F_bk0,
...
@@ -482,7 +492,7 @@ class FmhaFwdSplitKVKernel:
             F_spad=BOOL_MAP[self.F_pipeline.F_spad],
             F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
             F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
             F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
             F_bias=BIAS_MAP[self.F_pipeline.F_bias],
             F_lse=BOOL_MAP[self.F_pipeline.F_lse],
             F_squant=BOOL_MAP[self.F_pipeline.F_squant],
...
@@ -542,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
         FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
             F_idx=self.F_idx,
             F_hdim=self.F_hdim,
-            F_dtype=DTYPE_MAP[self.F_dtype],
+            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
             F_bm0=self.F_tile.F_bm0,
             F_bn1=self.F_tile.F_bn1,
             F_spad=BOOL_MAP[self.F_pipeline.F_spad],
...
@@ -614,27 +624,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+            for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
                 # TODO: use async pipeline when compiler is more stable
                 if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/paged-kv kernels
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
+                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
+        elif dtype in ['fp8fp16', 'fp8bf16']:
+            # TODO
+            None
         else:
             assert False
         return pipelines
...
@@ -642,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     gen = list()
     api_pool = FmhaFwdSplitKVApiPool(mask_impl)

-    for dtype in DTYPE_MAP.keys():
+    for dtype in FWD_DTYPE_MAP.keys():
         d = get_fmha_fwd_tile_dict_from_dtype(dtype)
         if d == None:
             continue
...
@@ -655,9 +667,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
             if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                 # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                 continue
-            if pipeline.F_pagedkv == 't':
-                # we only use batch mode kernels to handle (paged-) kvcache problems
-                continue
             k = Kernel(F_idx=0,
                        F_hdim=hdim,
                        F_dtype=dtype,
...
@@ -705,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
     gen = list()

-    for dtype in DTYPE_MAP.keys():
+    for dtype in FWD_DTYPE_MAP.keys():
         d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
         if d == None:
             continue
...
View file @
c8c016dd
...
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
...
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
}
}
// different threshold for different dtype
// different threshold for different dtype
template
<
typename
DataType
>
template
<
typename
DataType
Config
>
auto
get_elimit
(
ck_tile
::
index_t
/*hdim_q*/
,
ck_tile
::
index_t
/*hdim_v*/
)
auto
get_elimit
(
ck_tile
::
index_t
/*hdim_q*/
,
ck_tile
::
index_t
/*hdim_v*/
)
{
{
double
rtol
=
1e-2
;
double
rtol
=
1e-2
;
...
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
...
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
}
template
<
>
template
<
>
auto
get_elimit
<
ck_tile
::
b
f16
_t
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
auto
get_elimit
<
FmhaBwdB
f16
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
{
{
double
rtol
=
1e-2
;
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
double
atol
=
1e-2
;
...
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
...
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
typename
DataType
>
template
<
typename
DataType
Config
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
...
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
auto
seqstart_q_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_q
);
const
auto
seqstart_q_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_q
);
const
auto
seqstart_k_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_k
);
const
auto
seqstart_k_host
=
generate_seqstarts
(
mode
,
batch
,
seqlen_k
);
using
TypeConfig
=
FmhaBwdTypeConfig
<
DataType
>
;
using
TypeConfig
=
FmhaBwdTypeConfig
<
DataType
Config
>
;
using
QDataType
=
typename
TypeConfig
::
QDataType
;
using
QDataType
=
typename
TypeConfig
::
QDataType
;
using
KDataType
=
typename
TypeConfig
::
KDataType
;
using
KDataType
=
typename
TypeConfig
::
KDataType
;
...
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// clang-format on
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
Config
>
(
hdim_q
,
hdim_v
);
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
dq_host_ref
,
dq_host_ref
,
std
::
string
(
"Error: QGrad Incorrect results!"
),
std
::
string
(
"Error: QGrad Incorrect results!"
),
...
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
...
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
if
(
data_type
==
"fp16"
)
if
(
data_type
==
"fp16"
)
{
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
FmhaBwdFp16
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
)
else
if
(
data_type
==
"bf16"
)
{
{
return
run
<
ck_tile
::
b
f16
_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
FmhaBwdB
f16
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
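fmha_bwd.cpp now keys run and get_elimit on the empty tag types (FmhaBwdFp16, FmhaBwdBf16) rather than on raw element types, matching the BWD_DTYPE_MAP strings in the codegen; fmha_fwd.cpp below does the same with the forward tags. A condensed sketch of the whole pattern, with simplified members and illustrative tolerance values (the real definitions live in fmha_bwd.hpp in this commit):

#include <string>
#include <tuple>

struct FmhaBwdFp16 {};
struct FmhaBwdBf16 {};

// One specialization per tag bundles all element types together.
template <typename DataTypeConfig>
struct FmhaBwdTypeConfig;

template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16> { /* using QDataType = ck_tile::half_t; ... */ };

template <>
struct FmhaBwdTypeConfig<FmhaBwdBf16> { /* using QDataType = ck_tile::bf16_t; ... */ };

// Error thresholds specialize on the same tag (values illustrative only).
template <typename DataTypeConfig>
std::tuple<double, double> get_elimit() { return {1e-3, 1e-3}; }

template <>
std::tuple<double, double> get_elimit<FmhaBwdBf16>() { return {1e-2, 1e-2}; }

template <typename DataTypeConfig>
bool run()
{
    using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>; // element types resolved here
    (void)sizeof(TypeConfig);
    return true;
}

int dispatch(const std::string& prec)
{
    if(prec == "fp16") return run<FmhaBwdFp16>() ? 0 : -2;
    if(prec == "bf16") return run<FmhaBwdBf16>() ? 0 : -2;
    return -3;
}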
example/ck_tile/01_fmha/fmha_bwd.hpp

...
@@ -14,11 +14,19 @@
 #include <utility>
 #include <variant>

+struct FmhaBwdFp16
+{
+};
+
+struct FmhaBwdBf16
+{
+};
+
 template <typename DataType>
 struct FmhaBwdTypeConfig;

 template <>
-struct FmhaBwdTypeConfig<ck_tile::half_t>
+struct FmhaBwdTypeConfig<FmhaBwdFp16>
 {
     using QDataType = ck_tile::half_t;
     using KDataType = ck_tile::half_t;
...
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
 };

 template <>
-struct FmhaBwdTypeConfig<ck_tile::bf16_t>
+struct FmhaBwdTypeConfig<FmhaBwdBf16>
 {
     using QDataType = ck_tile::bf16_t;
     using KDataType = ck_tile::bf16_t;
...
@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
         // create group mode kernel arguments
         if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
         {
-            return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
+            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                       args.k_ptr,
                                                       args.v_ptr,
                                                       args.bias_ptr,
                                                       args.lse_ptr,
                                                       args.do_ptr,
                                                       args.d_ptr,
                                                       args.rand_val_ptr,
                                                       args.dk_ptr,
                                                       args.dv_ptr,
                                                       args.dbias_ptr,
                                                       args.dq_acc_ptr,
                                                       args.seqstart_q_ptr,
                                                       args.seqstart_k_ptr,
                                                       args.seqlen_k_ptr,
                                                       args.hdim_q,
                                                       args.hdim_v,
                                                       args.nhead_q,
                                                       args.nhead_q / args.nhead_k,
                                                       args.scale,
                                                       args.stride_q,
                                                       args.stride_k,
                                                       args.stride_v,
                                                       args.stride_bias,
                                                       args.stride_randval,
                                                       args.stride_do,
                                                       args.stride_dq_acc,
                                                       args.stride_dk,
                                                       args.stride_dv,
                                                       args.stride_dbias,
                                                       args.nhead_stride_q,
                                                       args.nhead_stride_k,
                                                       args.nhead_stride_v,
                                                       args.nhead_stride_bias,
                                                       args.nhead_stride_randval,
                                                       args.nhead_stride_do,
                                                       args.nhead_stride_lsed,
                                                       args.nhead_stride_dq_acc,
                                                       args.nhead_stride_dk,
                                                       args.nhead_stride_dv,
                                                       args.nhead_stride_dbias,
                                                       args.split_stride_dq_acc,
                                                       args.window_size_left,
                                                       args.window_size_right,
                                                       args.mask_type,
                                                       args.p_drop,
                                                       args.drop_seed_offset);
         }
         else
         {
             // create batch mode kernel arguments
-            return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
+            return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
                                                       args.k_ptr,
                                                       args.v_ptr,
                                                       args.bias_ptr,
                                                       args.lse_ptr,
                                                       args.do_ptr,
                                                       args.d_ptr,
                                                       args.rand_val_ptr,
                                                       args.dk_ptr,
                                                       args.dv_ptr,
                                                       args.dbias_ptr,
                                                       args.dq_acc_ptr,
                                                       args.seqlen_q,
                                                       args.seqlen_k,
                                                       args.hdim_q,
                                                       args.hdim_v,
                                                       args.nhead_q,
                                                       args.nhead_q / args.nhead_k,
                                                       args.scale,
                                                       args.stride_q,
                                                       args.stride_k,
                                                       args.stride_v,
                                                       args.stride_bias,
                                                       args.stride_randval,
                                                       args.stride_do,
                                                       args.stride_dq_acc,
                                                       args.stride_dk,
                                                       args.stride_dv,
                                                       args.stride_dbias,
                                                       args.nhead_stride_q,
                                                       args.nhead_stride_k,
                                                       args.nhead_stride_v,
                                                       args.nhead_stride_bias,
                                                       args.nhead_stride_randval,
                                                       args.nhead_stride_do,
                                                       args.nhead_stride_lsed,
                                                       args.nhead_stride_dq_acc,
                                                       args.nhead_stride_dk,
                                                       args.nhead_stride_dv,
                                                       args.nhead_stride_dbias,
                                                       args.batch_stride_q,
                                                       args.batch_stride_k,
                                                       args.batch_stride_v,
                                                       args.batch_stride_bias,
                                                       args.batch_stride_randval,
                                                       args.batch_stride_do,
                                                       args.batch_stride_lsed,
                                                       args.batch_stride_dq_acc,
                                                       args.batch_stride_dk,
                                                       args.batch_stride_dv,
                                                       args.batch_stride_dbias,
                                                       args.split_stride_dq_acc,
                                                       args.window_size_left,
                                                       args.window_size_right,
                                                       args.mask_type,
                                                       args.p_drop,
                                                       args.drop_seed_offset);
         }
     }();
...
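Every call site in this header switches from MakeKargs to MakeKargsImpl; the diff shows only the renamed calls, so the split between a public entry point and the Impl function is an assumption here. A sketch of a plausible shape, with hypothetical names beyond those visible in the diff:

// Assumed wrapper/impl split; only the MakeKargsImpl name is confirmed by
// the diff above.
struct KernelSketch
{
    struct Kargs
    {
        const void* q_ptr = nullptr; // real kargs carry many more fields
    };

    template <typename... Args>
    static Kargs MakeKargsImpl(Args&&... /*args*/)
    {
        return Kargs{}; // real code populates every field from the argument list
    }
};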
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
c8c016dd
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp"
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "rotary.hpp"
#include "utils.hpp"
#include "utils.hpp"
...
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
...
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"v"
,
"1"
,
"
weather do CPU validation or not
"
)
arg_parser
.
insert
(
"v"
,
"1"
,
"
0:no validation, 2:cpu validation, 2:gpu validation(experimental)
"
)
.
insert
(
"mode"
,
"0"
,
"kernel mode. 0:batch, 1:group"
)
.
insert
(
"mode"
,
"0"
,
"kernel mode. 0:batch, 1:group"
)
.
insert
(
"b"
,
"2"
,
"batch size"
)
.
insert
(
"b"
,
"2"
,
"batch size"
)
.
insert
(
"h"
,
"8"
,
"num of head, for q"
)
.
insert
(
"h"
,
"8"
,
"num of head, for q"
)
...
@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
...
@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
"-1 to choose s_knew in [1, s] randomly."
)
"-1 to choose s_knew in [1, s] randomly."
)
.
insert
(
"s_kpad"
,
.
insert
(
"s_kpad"
,
"-1"
,
"-1"
,
"seqlen_k stride between 2
token
s, currently used in group-mode only
\n
"
"seqlen_k stride between 2
batche
s, currently used in group-mode only
\n
"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
\n
"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
\n
"
"along seqlen, instead of packed. same as xformer kv_padding"
)
"along seqlen, instead of packed. same as xformer kv_padding"
)
.
insert
(
"d"
,
"128"
,
"head dim for q, k"
)
.
insert
(
"d"
,
"128"
,
"head dim for q, k"
)
...
@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
 }
 // different threshold for different dtype
-template <typename DataType>
+template <typename DataTypeConfig>
 auto get_elimit(std::string /*init_method*/)
 {
     double rtol = 1e-3;
...
@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
 }
 template <>
-auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
+auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
 {
     double rtol = 1e-2;
     double atol = 1e-2;
...
@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
 }
 template <>
-auto get_elimit<ck_tile::fp8_t>(std::string init_method)
+auto get_elimit<FmhaFwdFp8>(std::string init_method)
 {
     if(init_method == "ui" || init_method == "ni")
     {
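The hunks above re-key `get_elimit` from raw element types (`ck_tile::bf16_t`, `ck_tile::fp8_t`) to data-type config tags (`FmhaFwdBf16`, `FmhaFwdFp8`). A minimal sketch of that tag-dispatch pattern, with stand-in tag structs and the tolerance values visible in the diff (the `atol` of the base template is assumed here, not shown in the hunk):

// Sketch of tag-based tolerance dispatch; FmhaFwdFp16/FmhaFwdBf16 below are
// stand-ins for the real config types declared in fmha_fwd.hpp.
#include <cstdio>
#include <string>
#include <tuple>

struct FmhaFwdFp16 {};
struct FmhaFwdBf16 {};

template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3; // default threshold (fp16 and friends)
    double atol = 1e-3; // assumed default; the real value lives in fmha_fwd.cpp
    return std::make_tuple(rtol, atol);
}

template <>
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
    double rtol = 1e-2; // bf16 has fewer mantissa bits, so the check is looser
    double atol = 1e-2;
    return std::make_tuple(rtol, atol);
}

int main()
{
    auto [rtol, atol] = get_elimit<FmhaFwdBf16>("uf");
    std::printf("bf16 rtol=%g atol=%g\n", rtol, atol);
    return 0;
}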
...
@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
     return num_splits;
 }
-template <typename DataType>
+template <typename DataTypeConfig>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     std::string data_type = arg_parser.get_str("prec");
...
@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
 #if !CK_TILE_FMHA_FWD_APPENDKV_API
     if(seqlen_knew != 0)
     {
-        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
+                  << std::endl;
         seqlen_knew = 0;
     }
 #endif
...
@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
     ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
-    if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
-                   std::is_same_v<DataType, ck_tile::bf16_t>))
+    if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
+                   std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
     {
         if(0 < rotary_dim)
         {
...
@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
         rotary_dim = 0;
     }
 #endif
+    // to use fmha_fwd_appendkv(), make sure it's in batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    if(need_append_kvcache && mode == mode_enum::group)
+    {
+        std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
     if(!(rotary_dim <= hdim_q))
     {
         std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
...
@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
                   << std::endl;
         use_cache_batch_idx = false;
     }
-#endif
-    if(0 < page_block_size && use_cache_batch_idx)
+#else
+    if(use_cache_batch_idx)
     {
-        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
-                     "'cache_batch_idx' option"
-                  << std::endl;
-        use_cache_batch_idx = false;
+        if(0 < page_block_size)
+        {
+            std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
+        else if(mode == mode_enum::group)
+        {
+            std::cerr << "group mode will not use cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
     }
+    // the input tensor layout for kvcache is same as batch mode
 #endif
-    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
     const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
-    if(use_kvcache && mode != mode_enum::batch)
-    {
-        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
-        mode = mode_enum::batch;
-    }
     auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
         decode_seqlen(mode,
...
@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                       arg_parser.get_str("s_k"),
                       arg_parser.get_str("s_kpad"),
                       /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
-                      use_kvcache);
+                      need_append_kvcache);
     // compute kvcache seqlen_k (before appending knew/vnew)
     auto cache_seqlen_ks = seqlen_ks;
     std::transform(cache_seqlen_ks.begin(),
...
@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return atoi(squant_str.c_str()) != 0 ? true : false;
     }();
-    float range_q = arg_parser.get_float("range_q");
-    float range_k = arg_parser.get_float("range_k");
-    float range_v = arg_parser.get_float("range_v");
-    float range_p = arg_parser.get_float("range_p");
-    float range_o = arg_parser.get_float("range_o");
-    float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
-    float scale_p = 1.f;
-    float scale_o = 1.f;
-    if(squant)
-    {
-        scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
-        scale_p = dtype_max / range_p;
-        // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
-        scale_o = range_p * range_v / range_o / dtype_max;
-    }
     std::string vlayout = arg_parser.get_str("vlayout");
     bool lse = arg_parser.get_bool("lse");
...
@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
     bool s_randval = false;
-    if(p_drop > 0.0f && do_validation)
+    if(p_drop > 0.0f && do_validation != 0)
     {
         s_randval = true;
     }
...
@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const auto seqstart_k_host = to_seqstarts(seqlen_ks);
     const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
-    using TypeConfig = FmhaFwdTypeConfig<DataType>;
+    using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
     using QDataType = typename TypeConfig::QDataType;
     using KDataType = typename TypeConfig::KDataType;
...
@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using OaccDataType = typename TypeConfig::OaccDataType;
     using ODataType = typename TypeConfig::ODataType;
+    float range_q = arg_parser.get_float("range_q");
+    float range_k = arg_parser.get_float("range_k");
+    float range_v = arg_parser.get_float("range_v");
+    float range_p = arg_parser.get_float("range_p");
+    float range_o = arg_parser.get_float("range_o");
+    float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
+    float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
+    float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
+    float p_dtype_max = v_dtype_max; // assume p and v is the same type
+    float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
+    float scale_p = 1.f;
+    float scale_o = 1.f;
+    if(squant)
+    {
+        scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
+        scale_p = p_dtype_max / range_p;
+        scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
+    }
     // accumulation numbers for performance evaluation
     std::size_t flop = 0, num_byte = 0;
     auto max_seqlen_q =
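The block added here replaces the single `dtype_max` of the removed `@@ -416` block with one maximum per tensor dtype, so Q, K, V, P and O may quantize differently: `scale_s` folds the Q/K dequantization into the softmax scale, `scale_p` requantizes probabilities into P's dtype, and `scale_o` maps the accumulated output back into O's range. A hedged numeric sketch of the same arithmetic, assuming all tensors are fp8 e4m3 (max 448) and purely illustrative value ranges:

// Numeric sketch mirroring the scale computation above; the numbers are made up.
#include <cstdio>

int main()
{
    const float dtype_max = 448.f; // q/k/v/p/o all fp8 e4m3 in this example
    const float range_q = 16.f, range_k = 16.f, range_v = 16.f;
    const float range_p = 1.f, range_o = 64.f;

    float scale_s = 1.f / 8.f; // e.g. 1/sqrt(hdim) for hdim = 64
    // fold Q/K dequantization into the softmax scale:
    scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
    // requantize softmax probabilities (bounded by range_p) into fp8:
    const float scale_p = dtype_max / range_p;
    // undo the p/v quantization and map the result into O's range:
    const float scale_o =
        (dtype_max / range_o) * (range_p / dtype_max) * (range_v / dtype_max);

    std::printf("scale_s=%g scale_p=%g scale_o=%g\n", scale_s, scale_p, scale_o);
    return 0;
}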
...
@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         else if(init_method == "ufq" || init_method == "uf:q" ||
                 init_method == "3") // suitable for fp8 quantization
         {
-            ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
-            ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
-            ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
-            ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
-            ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
+            ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
+            ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
+            ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
+            ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
+            ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
             // bias_fp8 = qscale_bias * bias_fp32
-            float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
+            float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
             // Assume bias is in [-1.f, 1.f] in original fp32
             ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
         }
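With this change, the quantization-friendly init fills each tensor uniformly over its own dtype's representable range rather than one shared range. A small stand-in sketch of that kind of fill, assuming plain floats in place of the quantized element types:

// Illustrative stand-in for a FillUniformDistribution-style initializer:
// values drawn uniformly from [-max, max] of each tensor's own dtype.
#include <cstdio>
#include <random>
#include <vector>

void fill_uniform(std::vector<float>& t, float dtype_max, unsigned seed)
{
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-dtype_max, dtype_max);
    for(auto& x : t)
        x = dist(gen);
}

int main()
{
    std::vector<float> q(8), v(8);
    fill_uniform(q, 448.f, 11939);   // e.g. fp8 e4m3 max for Q
    fill_uniform(v, 57344.f, 11939); // e.g. fp8 e5m2 max for V (hypothetical choice)
    std::printf("q[0]=%f v[0]=%f\n", q[0], v[0]);
    return 0;
}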
...
@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(
-        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
+                                            0 <= seqlen_kpads[0]
+                                        ? seqlen_ks.size() * sizeof(int32_t)
+                                        : 0);
     ck_tile::DeviceMem cache_seqlen_k_buf(
         need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
     ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
...
@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+    seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
+                              ? seqlen_ks.data()
+                              : nullptr);
     cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
     rotary_cos_buf.ToDevice(rotary_cos_host.data());
     rotary_sin_buf.ToDevice(rotary_sin_host.data());
...
@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
             (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
         args.seqstart_k_ptr =
             (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
         args.seqlen_k_ptr =
-            (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+            ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
+                 ? seqlen_k_buf.GetDeviceBuffer()
+                 : nullptr);
         args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
         args.max_seqlen_q = max_seqlen_q;
...
@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
         args.batch_stride_block_table = batch_stride_block_table;
         args.page_block_size = page_block_size;
+        args.is_gappy = false; // use 'false' for flash-attention integration
         args.cache_batch_idx =
             (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
...
@@ -1100,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << std::setprecision(2) << tflops << " TFlops, "
               << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush;
-    if(!do_validation)
+    if(do_validation == 0)
     {
         std::cout << std::flush << std::endl;
         return true;
     }
+    if(do_validation == 2)
+    {
+        // NOTE: use gpu to do validation
+        ck_tile::naive_attention_fwd_traits naive_t;
+        naive_t.q_type = data_type;
+        naive_t.k_type = data_type;
+        naive_t.v_type = data_type;
+        naive_t.o_type = data_type;
+        naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
+        naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
+        naive_t.variation = 0; // TODO?
+        ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
+        ck_tile::naive_attention_fwd_args naive_a;
+        naive_a.q_ptr = q_buf.GetDeviceBuffer();
+        naive_a.k_ptr = k_buf.GetDeviceBuffer();
+        naive_a.v_ptr = v_buf.GetDeviceBuffer();
+        naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
+        naive_a.scale_s = scale_s;
+        naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
+        naive_a.page_table_ptr = nullptr;  // [batch, num_blocks] seqlen_kv is in different block(paged attn)
+        naive_a.hdim = hdim_q;
+        naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
+        naive_a.batch_q = batch;
+        naive_a.batch_kv = batch;
+        naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
+        naive_a.seqlen_q = seqlen_qs[0];
+        naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
+        naive_a.nhead_q = nhead;
+        naive_a.nhead_kv = nhead_k;
+        naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
+        naive_a.page_size = 0; // if paged, the seqlen-kv for each block
+        ck_tile::stream_config naive_s{};
+        naive_attention_fwd(naive_t, naive_a, naive_s);
+        auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
+        o_buf.FromDevice(o_host.data()); // TODO: ugly
+        auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
+        bool pass_ = ck_tile::check_err(
+            o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
+        std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
+        return pass_;
+    }
     o_buf.FromDevice(o_host.data());
     lse_buf.FromDevice(lse_host.data());
     randval_buf.FromDevice(randval_host.data());
     auto p_compute_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
             return ck_tile::scales{scale_p};
         else
             return ck_tile::identity{};
     }();
     auto oacc_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+        if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
             return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
                                      ck_tile::scales{scale_o});
         else
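With `-v 2`, the tiled kernel is now checked against a naive attention kernel launched on the GPU (`naive_attention_fwd` from the newly included `ck_tile/ref/naive_attention.hpp`) instead of the CPU reference. For orientation, here is a self-contained host-side sketch of what such a naive forward computes for one batch and one head, softmax(scale * Q K^T) V with a numerically stable softmax; this is illustrative only, not the actual reference implementation:

// Naive single-head attention on the host: o = softmax(scale * q k^T) v.
// q: [s_q, d], k: [s_k, d], v: [s_k, d_v], o: [s_q, d_v], all row-major.
#include <cmath>
#include <cstdio>
#include <vector>

void naive_attention_fwd_host(const std::vector<float>& q,
                              const std::vector<float>& k,
                              const std::vector<float>& v,
                              std::vector<float>& o,
                              int s_q, int s_k, int d, int d_v, float scale_s)
{
    std::vector<float> s_row(s_k);
    for(int i = 0; i < s_q; ++i)
    {
        // scores for row i, tracking the max for a stable softmax
        float m = -INFINITY;
        for(int j = 0; j < s_k; ++j)
        {
            float acc = 0.f;
            for(int t = 0; t < d; ++t)
                acc += q[i * d + t] * k[j * d + t];
            s_row[j] = acc * scale_s;
            m = std::fmax(m, s_row[j]);
        }
        float sum = 0.f;
        for(int j = 0; j < s_k; ++j)
        {
            s_row[j] = std::exp(s_row[j] - m);
            sum += s_row[j];
        }
        // o_i = sum_j p_ij * v_j
        for(int t = 0; t < d_v; ++t)
        {
            float acc = 0.f;
            for(int j = 0; j < s_k; ++j)
                acc += (s_row[j] / sum) * v[j * d_v + t];
            o[i * d_v + t] = acc;
        }
    }
}

int main()
{
    const int s_q = 2, s_k = 3, d = 4, d_v = 4;
    std::vector<float> q(s_q * d, 0.1f), k(s_k * d, 0.2f), v(s_k * d_v, 0.3f);
    std::vector<float> o(s_q * d_v, 0.f);
    naive_attention_fwd_host(q, k, v, o, s_q, s_k, d, d_v, 1.f / std::sqrt(float(d)));
    std::printf("o[0]=%f\n", o[0]);
    return 0;
}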
...
@@ -1168,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
             auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                 rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
             ck_tile::reference_batched_rotary_position_embedding(
...
@@ -1184,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
             k_host_ref.ForEach([&](auto& self, auto i) {
                 self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                  i[0] / nr,
                                  i[1] % page_block_size,
                                  i[2]);
             });
         }
         else
         {
             k_host_ref.ForEach([&](auto& self, auto i) {
                 self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
                                  i[1] % page_block_size,
                                  i[0] / nr,
                                  i[2]);
             });
         }
     }
     else
 #endif
     {
         if(i_perm)
             k_host_ref.ForEach([&](auto& self, auto i) {
                 self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]);
             });
         else
             k_host_ref.ForEach([&](auto& self, auto i) {
                 self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]);
             });
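The paged-KV reference copies above resolve a logical token position into a physical page in two steps: the block table gives the physical block for `pos / page_block_size`, and `pos % page_block_size` is the slot inside that block. A tiny sketch of the mapping with made-up table contents:

// Minimal paged-KV lookup: logical position -> (physical block, in-block slot).
#include <cstdio>
#include <vector>

int main()
{
    const int page_block_size = 4; // tokens per physical block
    // block_table[b][n] = physical block id holding logical block n of batch b
    std::vector<std::vector<int>> block_table = {{7, 2, 5}}; // batch 0 only
    const int batch = 0, pos = 9;                            // logical token 9

    const int phys_block = block_table[batch][pos / page_block_size]; // -> 5
    const int in_block   = pos % page_block_size;                     // -> 1
    std::printf("token %d lives in block %d, slot %d\n", pos, phys_block, in_block);
    return 0;
}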
...
@@ -1211,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
             auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
                 rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
             ck_tile::reference_batched_rotary_position_embedding(
...
@@ -1233,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(0 < page_block_size)
     {
         if(is_v_rowmajor)
         {
             if(i_perm)
             {
                 v_host_ref.ForEach([&](auto& self, auto i) {
                     self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                      i[0] / nr,
                                      i[2] % page_block_size,
                                      i[1]);
                 });
             }
             else
             {
                 v_host_ref.ForEach([&](auto& self, auto i) {
                     self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                      i[2] % page_block_size,
                                      i[0] / nr,
                                      i[1]);
                 });
             }
         }
         else
         {
             if(i_perm)
             {
                 v_host_ref.ForEach([&](auto& self, auto i) {
                     self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
                                      i[0] / nr,
                                      i[1],
                                      i[2] % page_block_size);
                 });
             }
             else
             {
...
@@ -1440,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         else
             o_host_result.ForEach([&](auto& self, auto idx) {
                 self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]);
             });
         // clang-format on
-        auto [rtol, atol] = get_elimit<DataType>(init_method);
+        auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
         bool cur_pass = ck_tile::check_err(
             o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
         pass &= cur_pass;
...
@@ -1497,15 +1569,15 @@ int main(int argc, char* argv[])
     const std::string data_type = arg_parser.get_str("prec");
     if(data_type == "fp16")
     {
-        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
     }
     else if(data_type == "bf16")
     {
-        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
     }
     else if(data_type == "fp8")
     {
-        return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
+        return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
     }
     return -3;
...