Commit 7b12d9b7 authored by danyao12
Browse files

some kernels and related api update

parent d4de8495
......@@ -303,12 +303,12 @@ class fmha_bwd_v3_kernel
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
int gdy = fmha_v3_traits.h;
int gdz = fmha_v3_traits.b;
if(fmha_v3_traits.mask > 0)
{{
int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
int num_tg = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
......@@ -335,12 +335,12 @@ class fmha_bwd_v3_kernel
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
int gdy = fmha_v3_traits.h;
int gdz = fmha_v3_traits.b;
if(fmha_v3_traits.mask > 0)
{{
int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv;
int num_tg = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
......@@ -374,11 +374,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
template <typename dot_do_o_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << std::flush;
fmha_bwd_v3_args args;
fmha_bwd_xqa_v3_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -392,28 +392,43 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = 128 * a.hdim_q * 2;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_v3_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
......@@ -421,13 +436,13 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
);
}}
template <typename dot_do_o_trait_>
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << std::flush;
fmha_bwd_xqa_v3_args args;
args.ptr_dq = a.dq_ptr;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
......@@ -440,56 +455,42 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsign
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = 128 * a.hdim_q * 2;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_v3_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm)
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm, int ts_qo, int ts_kv)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_v3_args args;
fmha_bwd_xqa_v3_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
......@@ -503,28 +504,43 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned c
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = 128 * a.hdim_q * 2;
int stride_tg = ts_kv * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = ts_kv * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_v3_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
ts_qo,
ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
......@@ -542,22 +558,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_v3_spec == true){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
......@@ -565,27 +580,26 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_v3_spec == true){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
......@@ -593,29 +607,28 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_v3_spec == true){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
......@@ -624,35 +637,34 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_v3_spec == true){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_v3_spec == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
else if((t.is_v3_spec == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
......@@ -661,14 +673,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm, 32, 128);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_v3_name, io_perm);
r = fmha_bwd_v3_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_v3_name, io_perm, 16, 192);
return r;
}}
}}
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -11,13 +11,17 @@ set -x
# Sweep the binary in $EXE over precision, input/output permutation, head dim,
# the two v3 switches (v3_atomic_fp32, v3_rtz_cvt) and the mask setting.
# Every invocation passes -bwd_v3=1 -mode=0 plus $KNAME and $COMMON_ARGS.
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 128 ; do
for v3_atomic_fp32 in 0 1 ; do
for v3_rtz_cvt in 0 1 ; do
for mask in 0 1 ; do
# NOTE(review): the next two invocations hard-code -v3_atomic_fp32=0 -v3_rtz_cvt=1
# and do not read the v3_atomic_fp32 / v3_rtz_cvt loop variables, so they repeat
# unchanged for every combination of those two loops. They look like the
# pre-change lines of this diff shown next to their replacements -- confirm.
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS
# Same two shapes (b=2,h=4,h_k=2,s=512 and b=1,h=3,h_k=1,s=768), with both v3
# switches taken from the enclosing loops. Both shapes use h != h_k.
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_rtz_cvt=$v3_rtz_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_rtz_cvt=$v3_rtz_cvt -mode=0 -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
set +x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment