"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "0c0779d667a511b14cc50590184c6e5181622cea"
Commit 67b160c5 authored by danyao12's avatar danyao12
Browse files

enable bwd_fp16_a16

parent c3b406d6
...@@ -293,6 +293,42 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ...@@ -293,6 +293,42 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
); );
}} }}
template <typename dot_do_o_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
fmha_bwd_asm_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
args.Ts = 128 * a.hdim_q * 2;
args.Hs = a.seqlen_q * a.hdim_q * 2;
args.BAs = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name) float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name)
{{ {{
...@@ -338,11 +374,19 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -338,11 +374,19 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{ (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false)) {{
if(t.data_type.compare("fp16") == 0){{ if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; if(t.is_asm_atomic_fp32 == true){{
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32"; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name); const std::string bwd_ext_name = "bwd_ext_fp16_a32";
return r; r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a16";
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name);
return r;
}}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment