Commit 871c7556 authored by danyao12's avatar danyao12
Browse files

Add bf16+a16 RTZ (round-toward-zero float-to-bf16 conversion) kernels

parent 2dafca1f
...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") ...@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check" # to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}") message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp) add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
......
...@@ -620,11 +620,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -620,11 +620,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
}} }}
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; if(t.is_asm_rtz_cvt == true){{
const std::string bwd_ext_name = "bwd_ext_bf16_a16"; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q; const std::string bwd_ext_name = "bwd_ext_bf16_a16_rtz";
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm); bool io_perm = a.nhead_stride_q > a.stride_q;
return r; r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
return r;
}}
}} }}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
...@@ -648,11 +657,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -648,11 +657,20 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}} }}
}} }}
else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; if(t.is_asm_rtz_cvt == true){{
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16"; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
bool io_perm = a.nhead_stride_q > a.stride_q; const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16_rtz";
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm); bool io_perm = a.nhead_stride_q > a.stride_q;
return r; r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}} }}
}} }}
}} }}
......
...@@ -99,7 +99,10 @@ auto create_args(int argc, char* argv[]) ...@@ -99,7 +99,10 @@ auto create_args(int argc, char* argv[])
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1") "if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1")
.insert("asm_no_coex", .insert("asm_no_coex",
"0", "0",
"if set to 1 will use non-coexectuion kernel when ext_asm is set to 1"); "if set to 1 will use non-coexectuion kernel when ext_asm is set to 1")
.insert("asm_rtz_cvt",
"0",
"if set to 1 will use float to bf16 RTZ convert when ext_asm is set to 1");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
...@@ -191,6 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -191,6 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool ext_asm = arg_parser.get_bool("ext_asm"); bool ext_asm = arg_parser.get_bool("ext_asm");
bool asm_atomic_fp32 = arg_parser.get_bool("asm_atomic_fp32"); bool asm_atomic_fp32 = arg_parser.get_bool("asm_atomic_fp32");
bool asm_no_coex = arg_parser.get_bool("asm_no_coex"); bool asm_no_coex = arg_parser.get_bool("asm_no_coex");
bool asm_rtz_cvt = arg_parser.get_bool("asm_rtz_cvt");
ck_tile::stream_config stream_config{nullptr, ck_tile::stream_config stream_config{nullptr,
true, true,
...@@ -428,7 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -428,7 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
deterministic, deterministic,
ext_asm, ext_asm,
asm_atomic_fp32, asm_atomic_fp32,
asm_no_coex}; asm_no_coex,
asm_rtz_cvt};
auto fmha_args = [&]() { auto fmha_args = [&]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
......
...@@ -441,6 +441,7 @@ struct fmha_bwd_traits ...@@ -441,6 +441,7 @@ struct fmha_bwd_traits
bool uses_ext_asm; bool uses_ext_asm;
bool is_asm_atomic_fp32; bool is_asm_atomic_fp32;
bool is_asm_no_coex; bool is_asm_no_coex;
bool is_asm_rtz_cvt;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
#pragma once #pragma once
extern unsigned char bwd_bf16_a16[]; extern unsigned char bwd_bf16_a16[];
extern unsigned char bwd_bf16_a16_rtz[];
extern unsigned char bwd_bf16_a32[]; extern unsigned char bwd_bf16_a32[];
extern unsigned char bwd_bf16_causal_a16[]; extern unsigned char bwd_bf16_causal_a16[];
extern unsigned char bwd_bf16_causal_a16_rtz[];
extern unsigned char bwd_bf16_causal_a32[]; extern unsigned char bwd_bf16_causal_a32[];
extern unsigned char bwd_bf16_nocoex_a32[]; extern unsigned char bwd_bf16_nocoex_a32[];
extern unsigned char bwd_bf16_nocoex_causal_a32[]; extern unsigned char bwd_bf16_nocoex_causal_a32[];
......
...@@ -13,8 +13,8 @@ for perm in 0 1 ; do ...@@ -13,8 +13,8 @@ for perm in 0 1 ; do
for hdim in 128 ; do for hdim in 128 ; do
for mask in 0 1 ; do for mask in 0 1 ; do
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
done done
done done
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment