Commit 2dafca1f authored by danyao12's avatar danyao12
Browse files

mqa/gqa support for atomic f16 cases

parent 2a4c2316
...@@ -224,6 +224,52 @@ struct __attribute__((packed)) fmha_bwd_asm_args ...@@ -224,6 +224,52 @@ struct __attribute__((packed)) fmha_bwd_asm_args
p3 _p15; p3 _p15;
}}; }};
struct __attribute__((packed)) fmha_bwd_xqa_asm_args
{{
void* ptr_dq;
p2 _p0;
void* ptr_dk;
p2 _p1;
void* ptr_dv;
p2 _p2;
const void* ptr_q;
p2 _p3;
const void* ptr_k;
p2 _p4;
const void* ptr_v;
p2 _p5;
const void* ptr_do;
p2 _p6;
const void* ptr_lse;
p2 _p7;
const void* ptr_d;
p2 _p8;
float scalar;
p3 _p9;
float log2e;
p3 _p10;
unsigned int seq_len;
p3 _p11;
unsigned int Ts;
p3 _p12;
unsigned int Hs;
p3 _p13;
unsigned int BAs;
p3 _p14;
unsigned int Seqs;
p3 _p15;
unsigned int ratio;
p3 _p16;
unsigned int Hs_kv;
p3 _p17;
unsigned int BAs_kv;
p3 _p18;
unsigned int Seqs_kv;
p3 _p19;
unsigned int Seqs_dkv;
p3 _p20;
}};
struct fmha_bwd_ext_traits struct fmha_bwd_ext_traits
{{ {{
int b; int b;
...@@ -278,6 +324,38 @@ class fmha_bwd_ext_kernel ...@@ -278,6 +324,38 @@ class fmha_bwd_ext_kernel
reinterpret_cast<void**>(&config))); reinterpret_cast<void**>(&config)));
}} }}
void
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_xqa_asm_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
int gdy = fmha_ext_traits.h;
int gdz = fmha_ext_traits.b;
if(fmha_ext_traits.mask > 0)
{{
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}}
private: private:
hipModule_t module; hipModule_t module;
hipFunction_t kernel_func; hipFunction_t kernel_func;
...@@ -343,6 +421,69 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned ...@@ -343,6 +421,69 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
); );
}} }}
template <typename dot_do_o_trait_>
float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
fmha_bwd_xqa_asm_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = 128 * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_> template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm) float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{ {{
...@@ -398,11 +539,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -398,11 +539,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_ext_asm == true){{ if (t.uses_ext_asm == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.stride_q == a.stride_dq /*i_perm == o_perm*/) && (a.stride_k == a.stride_dk /*i_perm == o_perm*/) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
(a.stride_v == a.stride_dv /*i_perm == o_perm*/) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)) {{
if(t.data_type.compare("fp16") == 0){{ if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{ if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
...@@ -420,16 +561,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -420,16 +561,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r; return r;
}} }}
}} }}
else if(t.is_asm_atomic_fp32 == false){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a16"; const std::string bwd_ext_name = "bwd_ext_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm); r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm);
return r; return r;
}} }}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{ if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
...@@ -447,18 +589,19 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -447,18 +589,19 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r; return r;
}} }}
}} }}
else if(t.is_asm_atomic_fp32 == false){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16"; const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm); r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm);
return r; return r;
}} }}
}} }}
}} }}
else if(t.data_type.compare("bf16") == 0){{ else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{ if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{ if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
...@@ -476,16 +619,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -476,16 +619,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r; return r;
}} }}
}} }}
else if(t.is_asm_atomic_fp32 == false){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a16"; const std::string bwd_ext_name = "bwd_ext_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm); r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
return r; return r;
}} }}
}} }}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{ if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
...@@ -503,11 +647,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& ...@@ -503,11 +647,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r; return r;
}} }}
}} }}
else if(t.is_asm_atomic_fp32 == false){{ else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16"; const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q; bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm); r = fmha_ext_bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
return r; return r;
}} }}
}} }}
......
...@@ -15,8 +15,8 @@ for asm_atomic_fp32 in 0 1 ; do ...@@ -15,8 +15,8 @@ for asm_atomic_fp32 in 0 1 ; do
for asm_no_coex in 0 1 ; do for asm_no_coex in 0 1 ; do
for mask in 0 1 ; do for mask in 0 1 ; do
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS
done done
done done
......
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 128 ; do
for mask in 0 1 ; do
$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -mode=0 -kname=$KNAME $COMMON_ARGS
done
done
done
done
set +x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment