Commit c8c016dd authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

parents e8ca3daf 4e731776
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
...@@ -156,7 +156,7 @@ int run(int argc, char* argv[]) ...@@ -156,7 +156,7 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
...@@ -173,7 +173,7 @@ int run(int argc, char* argv[]) ...@@ -173,7 +173,7 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
struct ProblemSize final struct ProblemSize final
...@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con ...@@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
......
...@@ -377,7 +377,7 @@ int main(int argc, char* argv[]) ...@@ -377,7 +377,7 @@ int main(int argc, char* argv[])
break; break;
default: default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1}); a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1}); d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1}); d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -41,7 +41,7 @@ struct ExecutionConfig final ...@@ -41,7 +41,7 @@ struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = true; bool time_kernel = false;
}; };
#define DefaultConvParams \ #define DefaultConvParams \
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -248,7 +248,7 @@ int main(int argc, char* argv[]) ...@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break; break;
default: default:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
} }
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5}); d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
......
...@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
break; break;
default: default:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
} }
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5}); d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
......
...@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); // values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 6});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
......
...@@ -205,7 +205,6 @@ int main(int argc, char* argv[]) ...@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
a1_device_buf.ToDevice(a1_m_k.mData.data()); a1_device_buf.ToDevice(a1_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data()); b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data()); b1_device_buf.ToDevice(b1_k_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -253,8 +252,6 @@ int main(int argc, char* argv[]) ...@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
Tensor<AccDataType> c_m_n({M, N}); Tensor<AccDataType> c_m_n({M, N});
......
...@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
endforeach() endforeach()
#Do not build any DPP examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
message("removing dpp example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list #Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
......
[Back to the main page](../README.md)
# Composable Kernel examples
\ No newline at end of file
...@@ -2,10 +2,17 @@ ...@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation # generate kernel instances to speed up compilation
DTYPE_MAP = { FWD_DTYPE_MAP = {
"fp16": "ck_tile::fp16_t", "fp16" : "FmhaFwdFp16",
"bf16": "ck_tile::bf16_t", "bf16" : "FmhaFwdBf16",
"fp8" : "ck_tile::fp8_t" "fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
} }
MASK_IMPL = { MASK_IMPL = {
......
...@@ -283,7 +283,7 @@ class FmhaBwdApiPool: ...@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic]) F_deterministic=BOOL_MAP[trait.deterministic])
...@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaBwdApiPool(mask_impl) api_pool = FmhaBwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel: ...@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad], F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
...@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: ...@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel: ...@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0, F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0, F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
...@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: ...@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -282,7 +282,7 @@ class FmhaFwdApiPool: ...@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -301,7 +301,7 @@ class FmhaFwdTileSize: ...@@ -301,7 +301,7 @@ class FmhaFwdTileSize:
F_bk1 : int # tile size along kv gemm unroll F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used) F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v F_rn1 : int # number of warps for gemm1 along head dim v
...@@ -339,7 +339,7 @@ class FmhaFwdKernel: ...@@ -339,7 +339,7 @@ class FmhaFwdKernel:
FMHA_FWD_KERNEL_BODY.format( FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# no need lse/dropout kernels # no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
gen = list() gen = list()
api_pool = FmhaFwdApiPool(mask_impl) api_pool = FmhaFwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype) d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool: ...@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel: ...@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY.format( FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs, F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk, F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd, F_bd = self.F_tile.F_bd,
...@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported # rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl) api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) ...@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}} }}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>; {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< ...@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim}, {F_hdim},
{F_bm0}, {F_bm0},
{F_bn1}, {F_bn1},
{F_mode}, {F_mode},
fmha_trait>; fmha_trait>;
...@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a ...@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>() << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>() << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush; << std::flush;
return ck_tile::launch_kernel(s, return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }} [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
); );
}} }}
...@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ...@@ -247,12 +247,22 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
}} }}
""" """
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
return fmha_fwd_splitkv_<traits_, traits2_>(s, a); return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} }}
""" """
...@@ -421,11 +431,11 @@ class FmhaFwdSplitKVApiPool: ...@@ -421,11 +431,11 @@ class FmhaFwdSplitKVApiPool:
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -462,7 +472,7 @@ class FmhaFwdSplitKVKernel: ...@@ -462,7 +472,7 @@ class FmhaFwdSplitKVKernel:
FMHA_FWD_SPLITKV_KERNEL_BODY.format( FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -482,7 +492,7 @@ class FmhaFwdSplitKVKernel: ...@@ -482,7 +492,7 @@ class FmhaFwdSplitKVKernel:
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
...@@ -542,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -542,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
...@@ -614,27 +624,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -614,27 +624,29 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f' squant = 't' if dtype == 'fp8' else 'f'
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
else: else:
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask))
if receipt == 1: if receipt == 1:
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# no need lse/paged-kv kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -642,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -642,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl) api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype) d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -655,9 +667,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -655,9 +667,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't': if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
if pipeline.F_pagedkv == 't':
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k = Kernel(F_idx=0, k = Kernel(F_idx=0,
F_hdim=hdim, F_hdim=hdim,
F_dtype=dtype, F_dtype=dtype,
...@@ -705,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis ...@@ -705,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[]) ...@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
} }
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
...@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) ...@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_ ...@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <typename DataType> template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
...@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
using TypeConfig = FmhaBwdTypeConfig<DataType>; using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType; using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType; using KDataType = typename TypeConfig::KDataType;
...@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v); auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result, bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref, dq_host_ref,
std::string("Error: QGrad Incorrect results!"), std::string("Error: QGrad Incorrect results!"),
...@@ -986,11 +986,11 @@ int main(int argc, char* argv[]) ...@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec"); const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run<ck_tile::half_t>(arg_parser) ? 0 : -2; return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2; return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
...@@ -14,11 +14,19 @@ ...@@ -14,11 +14,19 @@
#include <utility> #include <utility>
#include <variant> #include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType> template <typename DataType>
struct FmhaBwdTypeConfig; struct FmhaBwdTypeConfig;
template <> template <>
struct FmhaBwdTypeConfig<ck_tile::half_t> struct FmhaBwdTypeConfig<FmhaBwdFp16>
{ {
using QDataType = ck_tile::half_t; using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t; using KDataType = ck_tile::half_t;
...@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t> ...@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
}; };
template <> template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t> struct FmhaBwdTypeConfig<FmhaBwdBf16>
{ {
using QDataType = ck_tile::bf16_t; using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t;
...@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -150,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{ {
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_ptr, args.lse_ptr,
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr, args.dq_acc_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale, args.scale,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc, args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc, args.nhead_stride_dq_acc,
args.nhead_stride_dk, args.nhead_stride_dk,
args.nhead_stride_dv, args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.split_stride_dq_acc, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_ptr, args.lse_ptr,
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr, args.dq_acc_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale, args.scale,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc, args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc, args.nhead_stride_dq_acc,
args.nhead_stride_dk, args.nhead_stride_dk,
args.nhead_stride_dv, args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_do, args.batch_stride_do,
args.batch_stride_lsed, args.batch_stride_lsed,
args.batch_stride_dq_acc, args.batch_stride_dq_acc,
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
args.split_stride_dq_acc, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp" #include "fmha_fwd.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp" #include "mask.hpp"
#include "rotary.hpp" #include "rotary.hpp"
#include "utils.hpp" #include "utils.hpp"
...@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) ...@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not") arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
.insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size") .insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q") .insert("h", "8", "num of head, for q")
...@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[]) ...@@ -62,7 +63,7 @@ auto create_args(int argc, char* argv[])
"-1 to choose s_knew in [1, s] randomly.") "-1 to choose s_knew in [1, s] randomly.")
.insert("s_kpad", .insert("s_kpad",
"-1", "-1",
"seqlen_k stride between 2 tokens, currently used in group-mode only\n" "seqlen_k stride between 2 batches, currently used in group-mode only\n"
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
"along seqlen, instead of packed. same as xformer kv_padding") "along seqlen, instead of packed. same as xformer kv_padding")
.insert("d", "128", "head dim for q, k") .insert("d", "128", "head dim for q, k")
...@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[]) ...@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
} }
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/) auto get_elimit(std::string /*init_method*/)
{ {
double rtol = 1e-3; double rtol = 1e-3;
...@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/) ...@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/) auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/) ...@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method) auto get_elimit<FmhaFwdFp8>(std::string init_method)
{ {
if(init_method == "ui" || init_method == "ni") if(init_method == "ui" || init_method == "ni")
{ {
...@@ -261,7 +262,7 @@ int override_num_splits_if_necessary( ...@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
return num_splits; return num_splits;
} }
template <typename DataType> template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
...@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -294,7 +295,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
#if !CK_TILE_FMHA_FWD_APPENDKV_API #if !CK_TILE_FMHA_FWD_APPENDKV_API
if(seqlen_knew != 0) if(seqlen_knew != 0)
{ {
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
<< std::endl;
seqlen_knew = 0; seqlen_knew = 0;
} }
#endif #endif
...@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -304,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> || if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
std::is_same_v<DataType, ck_tile::bf16_t>)) std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
{ {
if(0 < rotary_dim) if(0 < rotary_dim)
{ {
...@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -321,6 +323,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
rotary_dim = 0; rotary_dim = 0;
} }
#endif #endif
// to use fmha_fwd_appendkv(), make sure it's in batch mode
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
if(need_append_kvcache && mode == mode_enum::group)
{
std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl;
mode = mode_enum::batch;
}
if(!(rotary_dim <= hdim_q)) if(!(rotary_dim <= hdim_q))
{ {
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
...@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -356,22 +365,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::endl; << std::endl;
use_cache_batch_idx = false; use_cache_batch_idx = false;
} }
#endif #else
if(0 < page_block_size && use_cache_batch_idx) if(use_cache_batch_idx)
{ {
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " if(0 < page_block_size)
"'cache_batch_idx' option" {
<< std::endl; std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
use_cache_batch_idx = false; "'cache_batch_idx' option"
<< std::endl;
use_cache_batch_idx = false;
}
else if(mode == mode_enum::group)
{
std::cerr << "group mode will not use cache_batch_idx. ignoring the "
"'cache_batch_idx' option"
<< std::endl;
use_cache_batch_idx = false;
}
} }
// the input tensor layout for kvcache is same as batch mode #endif
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
if(use_kvcache && mode != mode_enum::batch)
{
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
mode = mode_enum::batch;
}
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
decode_seqlen(mode, decode_seqlen(mode,
...@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -380,7 +393,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
arg_parser.get_str("s_k"), arg_parser.get_str("s_k"),
arg_parser.get_str("s_kpad"), arg_parser.get_str("s_kpad"),
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
use_kvcache); need_append_kvcache);
// compute kvcache seqlen_k (before appending knew/vnew) // compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks; auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(), std::transform(cache_seqlen_ks.begin(),
...@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -416,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
return atoi(squant_str.c_str()) != 0 ? true : false; return atoi(squant_str.c_str()) != 0 ? true : false;
}(); }();
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
scale_p = dtype_max / range_p;
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
scale_o = range_p * range_v / range_o / dtype_max;
}
std::string vlayout = arg_parser.get_str("vlayout"); std::string vlayout = arg_parser.get_str("vlayout");
bool lse = arg_parser.get_bool("lse"); bool lse = arg_parser.get_bool("lse");
...@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -454,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
bool s_randval = false; bool s_randval = false;
if(p_drop > 0.0f && do_validation) if(p_drop > 0.0f && do_validation != 0)
{ {
s_randval = true; s_randval = true;
} }
...@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -487,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
using TypeConfig = FmhaFwdTypeConfig<DataType>; using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType; using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType; using KDataType = typename TypeConfig::KDataType;
...@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -501,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
using OaccDataType = typename TypeConfig::OaccDataType; using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType; using ODataType = typename TypeConfig::ODataType;
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v is the same type
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
scale_p = p_dtype_max / range_p;
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
}
// accumulation numbers for performance evaluation // accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q = auto max_seqlen_q =
...@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -697,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(init_method == "ufq" || init_method == "uf:q" || else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization init_method == "3") // suitable for fp8 quantization
{ {
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host); ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host); ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32 // bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
// Assume bias is in [-1.f, 1.f] in original fp32 // Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
} }
...@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -741,8 +757,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf( ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); 0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf( ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
...@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -763,7 +781,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data()); : seqstart_k_with_padding_host.data());
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
? seqlen_ks.data()
: nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data());
...@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -976,8 +996,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr = args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); ? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
args.max_seqlen_q = max_seqlen_q; args.max_seqlen_q = max_seqlen_q;
...@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1029,6 +1050,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table; args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size; args.page_block_size = page_block_size;
args.is_gappy = false; // use 'false' for flash-attention integration
args.cache_batch_idx = args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
...@@ -1100,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1100,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush; << " GB/s" << std::flush;
if(!do_validation) if(do_validation == 0)
{ {
std::cout << std::flush << std::endl; std::cout << std::flush << std::endl;
return true; return true;
} }
if(do_validation == 2)
{
// NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
ck_tile::naive_attention_fwd_args naive_a;
naive_a.q_ptr = q_buf.GetDeviceBuffer();
naive_a.k_ptr = k_buf.GetDeviceBuffer();
naive_a.v_ptr = v_buf.GetDeviceBuffer();
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
naive_a.scale_s = scale_s;
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
naive_a.page_table_ptr =
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
naive_a.hdim = hdim_q;
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
naive_a.batch_q = batch;
naive_a.batch_kv = batch;
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
naive_a.seqlen_q = seqlen_qs[0];
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
naive_a.nhead_q = nhead;
naive_a.nhead_kv = nhead_k;
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
ck_tile::stream_config naive_s{};
naive_attention_fwd(naive_t, naive_a, naive_s);
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
return pass_;
}
o_buf.FromDevice(o_host.data()); o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data()); lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data()); randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() { auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::scales{scale_p}; return ck_tile::scales{scale_p};
else else
return ck_tile::identity{}; return ck_tile::identity{};
}(); }();
auto oacc_element_func = [&]() { auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o}); ck_tile::scales{scale_o});
else else
...@@ -1168,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1168,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] = auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding( ck_tile::reference_batched_rotary_position_embedding(
...@@ -1184,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1184,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
k_host_ref.ForEach([&](auto& self, auto i) { k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
}); });
} else { } else {
k_host_ref.ForEach([&](auto& self, auto i) { k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
}); });
} }
} else } else
#endif #endif
{ {
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
...@@ -1211,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1211,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] = auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
ck_tile::reference_batched_rotary_position_embedding( ck_tile::reference_batched_rotary_position_embedding(
...@@ -1233,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1233,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < page_block_size) { if(0 < page_block_size) {
if(is_v_rowmajor) { if(is_v_rowmajor) {
if(i_perm) { if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
}); });
} else { } else {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
}); });
} }
} }
else else
{ {
if(i_perm) { if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
}); });
} else { } else {
...@@ -1440,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1440,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method); auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err( bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass; pass &= cur_pass;
...@@ -1497,15 +1569,15 @@ int main(int argc, char* argv[]) ...@@ -1497,15 +1569,15 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec"); const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run<ck_tile::half_t>(arg_parser) ? 0 : -2; return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2; return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "fp8") else if(data_type == "fp8")
{ {
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2; return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment