Unverified Commit ee0654c4 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

use cmakelists as the single source of truth for score_mod function definition

parent 61108fdf
...@@ -6,6 +6,9 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all") ...@@ -6,6 +6,9 @@ if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif() endif()
variable_watch(FMHA_SCORE_MOD_F)
set(FMHA_SCORE_MOD_F [[s + static_cast<decltype(s)>(q_idx - v_idx)]])
foreach(api ${FMHA_FWD_ENABLE_APIS}) foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
...@@ -21,11 +24,12 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") ...@@ -21,11 +24,12 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta # generate a list of kernels, but not actually emit files at config sta
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --api ${FMHA_FWD_APIS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret RESULT_VARIABLE ret
) )
if(ret AND NOT ret EQUAL 0) if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
endif() endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory
...@@ -35,7 +39,10 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) ...@@ -35,7 +39,10 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api ${FMHA_FWD_APIS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
"--score_mod_expr=${FMHA_SCORE_MOD_F}"
VERBATIM
) )
set(EXAMPLE_FMHA_FWD "tile_example_flexattn_fwd") set(EXAMPLE_FMHA_FWD "tile_example_flexattn_fwd")
...@@ -80,6 +87,8 @@ endif() ...@@ -80,6 +87,8 @@ endif()
# Allow comparing floating points directly in order to check sentinel values # Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}")
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated # TODO: we have to turn off this global prop, otherwise the progress bar generated
......
...@@ -1375,9 +1375,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1375,9 +1375,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{}, ck_tile::identity{},
ck_tile::identity{}); ck_tile::identity{});
#ifndef CK_TILE_SCORE_MOD_F
#error "must be defined"
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F)
#endif
auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) { auto score_mod = [] (auto s, ck_tile::index_t b, ck_tile::index_t h, ck_tile::index_t q_idx, ck_tile::index_t v_idx) {
(void) s; (void) b; (void) h; (void) q_idx; (void) v_idx; (void) s; (void) b; (void) h; (void) q_idx; (void) v_idx;
return s + static_cast<decltype(s)>(q_idx - v_idx); return CK_TILE_SCORE_MOD_F;
}; };
s_host_ref.ForEach([&](auto& self, auto i) { s_host_ref.ForEach([&](auto& self, auto i) {
......
...@@ -108,9 +108,7 @@ if __name__ == "__main__": ...@@ -108,9 +108,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--score_mod_expr", "--score_mod_expr",
# default="s", default="s",
# test with
default="s + static_cast<decltype(s)>(q_idx - v_idx)",
required=False, required=False,
help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables" help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment