Commit 224a7b02 authored by danyao12
Browse files

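dq_acc stride: add separate stride_dq_acc, nhead_stride_dq_acc and batch_stride_dq_acc arguments to fmha_bwd_args and pass them to the kargs builders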

parent 99ed2c1a
...@@ -495,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -495,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o, stride_o,
stride_randval, stride_randval,
stride_do, stride_do,
stride_q, // stride_dq_acc
stride_dk, stride_dk,
stride_dv, stride_dv,
stride_dbias, stride_dbias,
...@@ -506,6 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -506,6 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval, nhead_stride_randval,
nhead_stride_do, nhead_stride_do,
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_dbias, nhead_stride_dbias,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
...@@ -515,6 +517,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -515,6 +517,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval, batch_stride_randval,
batch_stride_do, batch_stride_do,
batch_stride_lsed, batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_dk, batch_stride_dk,
batch_stride_dv, batch_stride_dv,
batch_stride_dbias, batch_stride_dbias,
......
...@@ -98,6 +98,7 @@ struct fmha_bwd_args ...@@ -98,6 +98,7 @@ struct fmha_bwd_args
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t stride_randval; ck_tile::index_t stride_randval;
ck_tile::index_t stride_do; ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dk; ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv; ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias; ck_tile::index_t stride_dbias;
...@@ -109,6 +110,7 @@ struct fmha_bwd_args ...@@ -109,6 +110,7 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dbias; ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -118,6 +120,7 @@ struct fmha_bwd_args ...@@ -118,6 +120,7 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias; ck_tile::index_t batch_stride_dbias;
...@@ -164,6 +167,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -164,6 +167,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -174,6 +178,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -174,6 +178,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_lsed, args.batch_stride_lsed,
args.split_stride_dq_acc, args.split_stride_dq_acc,
...@@ -210,6 +215,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -210,6 +215,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -220,6 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -220,6 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
...@@ -228,6 +235,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -228,6 +235,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_do, args.batch_stride_do,
args.batch_stride_lsed, args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
...@@ -300,7 +308,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) ...@@ -300,7 +308,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.hdim_q, args.hdim_q,
args.stride_q, args.stride_q,
args.stride_dq_acc,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc); args.split_stride_dq_acc);
} }
else else
...@@ -311,8 +321,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) ...@@ -311,8 +321,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.stride_q, args.stride_q,
args.stride_dq_acc,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_dq_acc,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_dq_acc,
args.split_stride_dq_acc); args.split_stride_dq_acc);
} }
}(); }();
......
...@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1' ...@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
set -x set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do for hdim in 64 ; do
for mode in 0 1 ; do for mode in 0 1 ; do
for bias in "n" "e" "a"; do for bias in "n" "e" "a"; do
for dbias in 0 1 ; do for dbias in 0 1 ; do
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment