Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
224a7b02
Commit
224a7b02
authored
Jul 27, 2024
by
danyao12
Browse files
dq_acc stride
parent
99ed2c1a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
1 deletion
+17
-1
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+3
-0
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+13
-0
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+1
-1
No files found.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
224a7b02
...
@@ -495,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -495,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o
,
stride_o
,
stride_randval
,
stride_randval
,
stride_do
,
stride_do
,
stride_q
,
// stride_dq_acc
stride_dk
,
stride_dk
,
stride_dv
,
stride_dv
,
stride_dbias
,
stride_dbias
,
...
@@ -506,6 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -506,6 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval
,
nhead_stride_randval
,
nhead_stride_do
,
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_lsed
,
nhead_stride_q
,
// nhead_stride_dq_acc
nhead_stride_dbias
,
nhead_stride_dbias
,
batch_stride_q
,
batch_stride_q
,
batch_stride_k
,
batch_stride_k
,
...
@@ -515,6 +517,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -515,6 +517,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval
,
batch_stride_randval
,
batch_stride_do
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_lsed
,
batch_stride_q
,
// batch_stride_dq_acc
batch_stride_dk
,
batch_stride_dk
,
batch_stride_dv
,
batch_stride_dv
,
batch_stride_dbias
,
batch_stride_dbias
,
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
224a7b02
...
@@ -98,6 +98,7 @@ struct fmha_bwd_args
...
@@ -98,6 +98,7 @@ struct fmha_bwd_args
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
ck_tile
::
index_t
stride_dv
;
ck_tile
::
index_t
stride_dbias
;
ck_tile
::
index_t
stride_dbias
;
...
@@ -109,6 +110,7 @@ struct fmha_bwd_args
...
@@ -109,6 +110,7 @@ struct fmha_bwd_args
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dbias
;
ck_tile
::
index_t
nhead_stride_dbias
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
...
@@ -118,6 +120,7 @@ struct fmha_bwd_args
...
@@ -118,6 +120,7 @@ struct fmha_bwd_args
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
ck_tile
::
index_t
batch_stride_dv
;
ck_tile
::
index_t
batch_stride_dbias
;
ck_tile
::
index_t
batch_stride_dbias
;
...
@@ -164,6 +167,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -164,6 +167,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_do
,
args
.
stride_do
,
args
.
stride_dq_acc
,
args
.
stride_dk
,
args
.
stride_dk
,
args
.
stride_dv
,
args
.
stride_dv
,
args
.
stride_dbias
,
args
.
stride_dbias
,
...
@@ -174,6 +178,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -174,6 +178,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_do
,
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dbias
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_lsed
,
args
.
batch_stride_lsed
,
args
.
split_stride_dq_acc
,
args
.
split_stride_dq_acc
,
...
@@ -210,6 +215,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -210,6 +215,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_randval
,
args
.
stride_do
,
args
.
stride_do
,
args
.
stride_dq_acc
,
args
.
stride_dk
,
args
.
stride_dk
,
args
.
stride_dv
,
args
.
stride_dv
,
args
.
stride_dbias
,
args
.
stride_dbias
,
...
@@ -220,6 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -220,6 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
nhead_stride_randval
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_do
,
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dbias
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
...
@@ -228,6 +235,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -228,6 +235,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
batch_stride_randval
,
args
.
batch_stride_randval
,
args
.
batch_stride_do
,
args
.
batch_stride_do
,
args
.
batch_stride_lsed
,
args
.
batch_stride_lsed
,
args
.
batch_stride_dq_acc
,
args
.
batch_stride_dk
,
args
.
batch_stride_dk
,
args
.
batch_stride_dv
,
args
.
batch_stride_dv
,
args
.
batch_stride_dbias
,
args
.
batch_stride_dbias
,
...
@@ -300,7 +308,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -300,7 +308,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args
.
seqstart_k_ptr
,
args
.
seqstart_k_ptr
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_dq_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_dq_acc
,
args
.
split_stride_dq_acc
);
args
.
split_stride_dq_acc
);
}
}
else
else
...
@@ -311,8 +321,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -311,8 +321,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args
.
seqlen_k
,
args
.
seqlen_k
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_dq_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_dq_acc
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_dq_acc
,
args
.
split_stride_dq_acc
);
args
.
split_stride_dq_acc
);
}
}
}();
}();
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
View file @
224a7b02
...
@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
...
@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
set
-x
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
perm
in
0 1
;
do
for
hdim
in
32 64 128 25
6
;
do
for
hdim
in
6
4
;
do
for
mode
in
0 1
;
do
for
mode
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
dbias
in
0 1
;
do
for
dbias
in
0 1
;
do
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment