Unverified Commit ee0654c4 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

use cmakelists as the single source of truth for score_mod function definition

parent 61108fdf
......@@ -6,6 +6,9 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>(q_idx - v_idx)]])
foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
......@@ -21,11 +24,12 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory
......@@ -35,7 +39,10 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
"--score_mod_expr=${FMHA_SCORE_MOD_F}"
VERBATIM
)
set(EXAMPLE_FMHA_FWD "tile_example_flexattn_fwd")
......@@ -80,6 +87,8 @@ endif()
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}")
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
......
......@@ -1375,9 +1375,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::identity{});
#ifndef CK_TILE_SCORE_MOD_F
#error "must be defined"
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F)
#endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) s; (void) b; (void) h; (void) q_idx; (void) v_idx;
return s + static_cast<decltype(s)>(q_idx - v_idx);
return CK_TILE_SCORE_MOD_F;
};
s_host_ref.ForEach([&](auto& self, auto i) {
......
......@@ -108,9 +108,7 @@ if __name__ == "__main__":
parser.add_argument(
"--score_mod_expr",
# default="s",
# test with
default="s + static_cast<decltype(s)>(q_idx - v_idx)",
default="s",
required=False,
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment