Commit cd4d4629 authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents 21d12bb7 888317e6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::PassThrough out_element_op;
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::Power out_element_op(4.f, 1.f, 2.f);
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::Relu out_element_op;
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::Sigmoid out_element_op;
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::SoftRelu out_element_op;
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::Swish out_element_op(1.0f);
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_dynamic_unary_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
int main(int argc, char* argv[])
{
ck::tensor_operation::element_wise::TanH out_element_op;
return !run_convnd_example(argc, argv, out_element_op);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename OutElementOp>
bool run_convnd_example(int argc, char* argv[], const OutElementOp& out_element_op)
{
print_helper_msg();
bool do_verification = true;
// Use floats for SoftRelu by default to avoid overflow after e^x.
int init_method =
std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> ? 2 : 1;
bool time_kernel = false;
// Following shapes are selected to avoid overflow. Expect inf in case of
// size increase for some elementwise ops.
ck::utils::conv::ConvParam conv_param{
3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto run = [&]() {
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDActivInstance>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
};
if(conv_param.num_dim_spatial_ == 3)
{
return run();
}
return false;
}
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
\ No newline at end of file
...@@ -205,7 +205,6 @@ int main(int argc, char* argv[]) ...@@ -205,7 +205,6 @@ int main(int argc, char* argv[])
a1_device_buf.ToDevice(a1_m_k.mData.data()); a1_device_buf.ToDevice(a1_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data()); b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data()); b1_device_buf.ToDevice(b1_k_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -253,8 +252,6 @@ int main(int argc, char* argv[]) ...@@ -253,8 +252,6 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
Tensor<AccDataType> c_m_n({M, N}); Tensor<AccDataType> c_m_n({M, N});
......
...@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -54,6 +54,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
endforeach() endforeach()
#Do not build any DPP examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp")
message("removing dpp example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list #Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
...@@ -85,9 +92,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -85,9 +92,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
...@@ -169,9 +176,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -169,9 +176,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
......
[Back to the main page](../README.md)
# Composable Kernel examples
\ No newline at end of file
...@@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ...@@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd`
## kernel ## kernel
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
There are 3 template parameters for this kernel template. There are 2 template parameters for this kernel template.
* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose.
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. * `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support.
......
...@@ -2,10 +2,17 @@ ...@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation # generate kernel instances to speed up compilation
DTYPE_MAP = { FWD_DTYPE_MAP = {
"fp16": "ck_tile::fp16_t", "fp16" : "FmhaFwdFp16",
"bf16": "ck_tile::bf16_t", "bf16" : "FmhaFwdBf16",
"fp8" : "ck_tile::fp8_t" "fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
} }
MASK_IMPL = { MASK_IMPL = {
...@@ -112,6 +119,7 @@ PIPELINE_MAP = { ...@@ -112,6 +119,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP = { PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
} }
BOOL_MAP = { BOOL_MAP = {
......
...@@ -1038,7 +1038,7 @@ class FmhaBwdApiPool: ...@@ -1038,7 +1038,7 @@ class FmhaBwdApiPool:
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic]) F_deterministic=BOOL_MAP[trait.deterministic])
...@@ -1115,7 +1115,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -1115,7 +1115,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -1224,7 +1224,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -1224,7 +1224,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaBwdApiPool(mask_impl) api_pool = FmhaBwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -1340,7 +1340,7 @@ class FmhaBwdOGradDotOKernel: ...@@ -1340,7 +1340,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad], F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
...@@ -1371,7 +1371,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: ...@@ -1371,7 +1371,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -1471,7 +1471,7 @@ class FmhaBwdConvertQGradKernel: ...@@ -1471,7 +1471,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0, F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0, F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
...@@ -1506,7 +1506,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: ...@@ -1506,7 +1506,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl ...@@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>; fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
...@@ -181,7 +179,7 @@ class FmhaFwdAppendKVApiPool: ...@@ -181,7 +179,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -216,7 +214,7 @@ class FmhaFwdAppendKVKernel: ...@@ -216,7 +214,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY.format( FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs, F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk, F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd, F_bd = self.F_tile.F_bd,
...@@ -301,6 +299,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -301,6 +299,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported # rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -308,7 +309,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -308,7 +309,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl) api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -352,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im ...@@ -352,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels: for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment