Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5d2a5a11
Commit
5d2a5a11
authored
Jul 30, 2024
by
danyao12
Browse files
more strides for fa integration
parent
fd28454d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
7 deletions
+31
-7
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+5
-0
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+14
-5
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+12
-2
No files found.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
5d2a5a11
...
...
@@ -496,6 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_randval
,
stride_do
,
stride_q
,
// stride_dq_acc
stride_q
,
// stride_dq
stride_dk
,
stride_dv
,
stride_dbias
,
...
...
@@ -508,6 +509,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_q
,
// nhead_stride_dq_acc
nhead_stride_q
,
// nhead_stride_dq
nhead_stride_k
,
// nhead_stride_dk
nhead_stride_v
,
// nhead_stride_dv
nhead_stride_dbias
,
batch_stride_q
,
batch_stride_k
,
...
...
@@ -518,6 +522,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_do
,
batch_stride_lsed
,
batch_stride_q
,
// batch_stride_dq_acc
batch_stride_q
,
// batch_stride_dq
batch_stride_dk
,
batch_stride_dv
,
batch_stride_dbias
,
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
5d2a5a11
...
...
@@ -99,6 +99,7 @@ struct fmha_bwd_args
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dq
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
ck_tile
::
index_t
stride_dbias
;
...
...
@@ -111,6 +112,9 @@ struct fmha_bwd_args
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dq
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
ck_tile
::
index_t
nhead_stride_dbias
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
...
...
@@ -121,6 +125,7 @@ struct fmha_bwd_args
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dq
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
ck_tile
::
index_t
batch_stride_dbias
;
...
...
@@ -179,6 +184,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_lsed
,
args
.
split_stride_dq_acc
,
...
...
@@ -227,6 +234,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args
.
nhead_stride_do
,
args
.
nhead_stride_lsed
,
args
.
nhead_stride_dq_acc
,
args
.
nhead_stride_dk
,
args
.
nhead_stride_dv
,
args
.
nhead_stride_dbias
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
...
...
@@ -307,9 +316,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args
.
seqstart_q_ptr
,
args
.
seqstart_k_ptr
,
args
.
hdim_q
,
args
.
stride_q
,
args
.
stride_dq
,
args
.
stride_dq_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_dq
,
args
.
nhead_stride_dq_acc
,
args
.
split_stride_dq_acc
);
}
...
...
@@ -320,11 +329,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
hdim_q
,
args
.
stride_q
,
args
.
stride_dq
,
args
.
stride_dq_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_dq
,
args
.
nhead_stride_dq_acc
,
args
.
batch_stride_q
,
args
.
batch_stride_dq
,
args
.
batch_stride_dq_acc
,
args
.
split_stride_dq_acc
);
}
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
5d2a5a11
...
...
@@ -147,6 +147,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
ck_tile
::
index_t
batch_stride_lsed
;
};
...
...
@@ -301,6 +303,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
...
...
@@ -350,6 +354,8 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
...
...
@@ -452,6 +458,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
split_stride_dq_acc
,
...
...
@@ -491,6 +499,8 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_do
,
nhead_stride_lsed
,
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
...
...
@@ -687,10 +697,10 @@ struct FmhaBwdDQDKDVKernel
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_do
+
batch_offset_do
;
KGradDataType
*
dk_ptr
=
reinterpret_cast
<
KGradDataType
*>
(
kargs
.
dk_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_k
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_dk
+
batch_offset_dk
;
VGradDataType
*
dv_ptr
=
reinterpret_cast
<
VGradDataType
*>
(
kargs
.
dv_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_v
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_dv
+
batch_offset_dv
;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment