Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2dafca1f
Commit
2dafca1f
authored
Sep 27, 2024
by
danyao12
Browse files
mqa/gqa support for atomic f16 cases
parent
2a4c2316
Changes
7
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1480 additions
and
917 deletions
+1480
-917
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+158
-14
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16.cpp
+291
-192
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16.cpp
+380
-281
example/ck_tile/01_fmha/hsaco/bwd_fp16_a16.cpp
example/ck_tile/01_fmha/hsaco/bwd_fp16_a16.cpp
+285
-186
example/ck_tile/01_fmha/hsaco/bwd_fp16_causal_a16.cpp
example/ck_tile/01_fmha/hsaco/bwd_fp16_causal_a16.cpp
+341
-242
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
+2
-2
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
+23
-0
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
2dafca1f
...
...
@@ -224,6 +224,52 @@ struct __attribute__((packed)) fmha_bwd_asm_args
p3 _p15;
}};
struct __attribute__((packed)) fmha_bwd_xqa_asm_args
{{
void* ptr_dq;
p2 _p0;
void* ptr_dk;
p2 _p1;
void* ptr_dv;
p2 _p2;
const void* ptr_q;
p2 _p3;
const void* ptr_k;
p2 _p4;
const void* ptr_v;
p2 _p5;
const void* ptr_do;
p2 _p6;
const void* ptr_lse;
p2 _p7;
const void* ptr_d;
p2 _p8;
float scalar;
p3 _p9;
float log2e;
p3 _p10;
unsigned int seq_len;
p3 _p11;
unsigned int Ts;
p3 _p12;
unsigned int Hs;
p3 _p13;
unsigned int BAs;
p3 _p14;
unsigned int Seqs;
p3 _p15;
unsigned int ratio;
p3 _p16;
unsigned int Hs_kv;
p3 _p17;
unsigned int BAs_kv;
p3 _p18;
unsigned int Seqs_kv;
p3 _p19;
unsigned int Seqs_dkv;
p3 _p20;
}};
struct fmha_bwd_ext_traits
{{
int b;
...
...
@@ -278,6 +324,38 @@ class fmha_bwd_ext_kernel
reinterpret_cast<void**>(&config)));
}}
void
launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_xqa_asm_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
int gdy = fmha_ext_traits.h;
int gdz = fmha_ext_traits.b;
if(fmha_ext_traits.mask > 0)
{{
int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
gdy,
gdz,
bdx,
1,
1,
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}}
private:
hipModule_t module;
hipFunction_t kernel_func;
...
...
@@ -343,6 +421,69 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
);
}}
template <typename dot_do_o_trait_>
float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_ext_name << std::flush;
fmha_bwd_xqa_asm_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.seq_len = a.seqlen_q;
int stride_tg = 128 * a.hdim_q * 2;
int stride_head = a.seqlen_q * a.hdim_q * 2;
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen = a.hdim_q * 2;
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
int stride_seqlen_kv = a.hdim_q * 2;
int stride_seqlen_dkv = a.hdim_q * 2;
if(io_perm == 0) //BSHD
{{
stride_seqlen = a.nhead_q * a.hdim_q * 2;
stride_head = a.hdim_q * 2;
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
stride_tg = 128 * stride_seqlen_kv;
stride_head_kv = a.hdim_q * 2;
}}
args.Ts = stride_tg;
args.Hs = stride_head;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.ratio = a.nhead_q / a.nhead_k;
args.Hs_kv = stride_head_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_ext_traits{{a.batch,
a.nhead_q,
a.seqlen_q,
a.hdim_q,
1,
a.mask_type,
32,
128}};
fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm);
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
);
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm)
{{
...
...
@@ -398,11 +539,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if (t.uses_ext_asm == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.stride_q == a.stride_dq /*i_perm == o_perm*/) && (a.stride_k == a.stride_dk /*i_perm == o_perm*/) &&
(a.stride_v == a.stride_dv /*i_perm == o_perm*/) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)) {{
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
...
...
@@ -420,16 +561,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
else if(
(
t.is_asm_atomic_fp32 == false)
&& (a.nhead_q % a.nhead_k == 0))
{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm);
r = fmha_ext_bwd_
xqa_
<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_ext_name, io_perm);
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
...
...
@@ -447,18 +589,19 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
else if(
(
t.is_asm_atomic_fp32 == false)
&& (a.nhead_q % a.nhead_k == 0))
{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm);
r = fmha_ext_bwd_
xqa_
<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
...
...
@@ -476,16 +619,17 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
else if(
(
t.is_asm_atomic_fp32 == false)
&& (a.nhead_q % a.nhead_k == 0))
{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
r = fmha_ext_bwd_
xqa_
<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_ext_name, io_perm);
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
...
...
@@ -503,11 +647,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
else if(
(
t.is_asm_atomic_fp32 == false)
&& (a.nhead_q % a.nhead_k == 0))
{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
r = fmha_ext_bwd_
xqa_
<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm);
return r;
}}
}}
...
...
example/ck_tile/01_fmha/hsaco/bwd_bf16_a16.cpp
View file @
2dafca1f
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/hsaco/bwd_bf16_causal_a16.cpp
View file @
2dafca1f
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/hsaco/bwd_fp16_a16.cpp
View file @
2dafca1f
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/hsaco/bwd_fp16_causal_a16.cpp
View file @
2dafca1f
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
View file @
2dafca1f
...
...
@@ -15,8 +15,8 @@ for asm_atomic_fp32 in 0 1 ; do
for
asm_no_coex
in
0 1
;
do
for
mask
in
0 1
;
do
$EXE
-prec
=
$prec
-b
=
4
-h
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
$asm_atomic_fp32
-asm_no_coex
=
$asm_no_coex
-v
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
$asm_atomic_fp32
-asm_no_coex
=
$asm_no_coex
-v
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
4
-h
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
$asm_atomic_fp32
-asm_no_coex
=
$asm_no_coex
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
$asm_atomic_fp32
-asm_no_coex
=
$asm_no_coex
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
done
done
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
0 → 100644
View file @
2dafca1f
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_REPEAT
=
1
COMMON_ARGS
=
'-v=1'
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
hdim
in
128
;
do
for
mask
in
0 1
;
do
$EXE
-prec
=
$prec
-b
=
2
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-ext_asm
=
1
-asm_atomic_fp32
=
0
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
set
+x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment