Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
39fc3d4b
Commit
39fc3d4b
authored
Jul 11, 2024
by
danyao12
Browse files
fix group deterministic bugs
parent
8c967d76
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
10 deletions
+6
-10
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+0
-5
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+1
-1
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+5
-4
No files found.
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
39fc3d4b
...
@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
hdim_v
=
arg_parser
.
get_int
(
"d_v"
);
ck_tile
::
index_t
hdim_v
=
arg_parser
.
get_int
(
"d_v"
);
if
(
hdim_v
<
0
)
if
(
hdim_v
<
0
)
hdim_v
=
hdim_q
;
hdim_v
=
hdim_q
;
if
(
hdim_q
%
2
!=
0
||
hdim_v
%
2
!=
0
)
{
std
::
cerr
<<
"FMHA Bwd kernel currently only supports even headdim"
<<
std
::
endl
;
return
false
;
}
bool
i_perm
=
arg_parser
.
get_bool
(
"iperm"
);
// if true, will be batch * nhead * seqlen * hdim
bool
i_perm
=
arg_parser
.
get_bool
(
"iperm"
);
// if true, will be batch * nhead * seqlen * hdim
bool
o_perm
=
arg_parser
.
get_bool
(
"operm"
);
// if false, will be batch * seqlen * nhead * hdim
bool
o_perm
=
arg_parser
.
get_bool
(
"operm"
);
// if false, will be batch * seqlen * nhead * hdim
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
39fc3d4b
...
@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
...
@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
return
FmhaBwdConvertQGradKernel
::
MakeKargs
(
args
.
dq_acc_ptr
,
return
FmhaBwdConvertQGradKernel
::
MakeKargs
(
args
.
dq_acc_ptr
,
args
.
dq_ptr
,
args
.
dq_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
seq
len
_k_ptr
,
args
.
seq
start
_k_ptr
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
stride_q
,
args
.
stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
39fc3d4b
...
@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs
<
0
>>
FmhaBwdConvertQGradEmptyKargs
<
0
>>
{
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seq
len
_k_ptr
;
const
int32_t
*
seq
start
_k_ptr
;
};
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
...
@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel
MakeKargs
(
const
void
*
dq_acc_ptr
,
MakeKargs
(
const
void
*
dq_acc_ptr
,
void
*
dq_ptr
,
void
*
dq_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seq
len
_k_ptr
,
const
void
*
seq
start
_k_ptr
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
stride_dq
,
ck_tile
::
index_t
nhead_stride_dq
,
ck_tile
::
index_t
nhead_stride_dq
,
...
@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel
nhead_stride_dq
},
nhead_stride_dq
},
{},
{},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seq
len
_k_ptr
)};
reinterpret_cast
<
const
int32_t
*>
(
seq
start
_k_ptr
)};
if
constexpr
(
kIsDeterministic
)
if
constexpr
(
kIsDeterministic
)
{
{
...
@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel
// get real # queries & # keys under group mode
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
kargs
.
seqlen_k
=
kargs
.
seqlen_k_ptr
[
i_batch
];
const
auto
adjusted_seqstart_k_ptr
=
kargs
.
seqstart_k_ptr
+
i_batch
;
kargs
.
seqlen_k
=
adjusted_seqstart_k_ptr
[
1
]
-
adjusted_seqstart_k_ptr
[
0
];
// # of required blocks is different in each groups, terminate unnecessary blocks
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
// earlier
if
(
kargs
.
seqlen_q
<=
i_m0
)
if
(
kargs
.
seqlen_q
<=
i_m0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment