Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
469509ba
Commit
469509ba
authored
Nov 15, 2024
by
Ye Wang
Browse files
debug 11HS dbias for TE
parent
8ef8a994
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
28 deletions
+30
-28
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+2
-2
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+28
-26
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
469509ba
...
...
@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
,
'bias'
,
'alibi'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
...
...
@@ -801,4 +801,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
\ No newline at end of file
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
469509ba
...
...
@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths
(
o_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_v
));
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host
(
use_dbias
?
get_lengths
(
i_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
?
get_lengths
(
i_perm
,
1
,
1
,
shape_seqlen_q
,
max_seqlen_k
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
AccDataType
>
dq_acc_host
(
i_perm
...
...
@@ -454,7 +454,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
stride_do
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_dk
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_dv
=
(
i_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_dbias
=
(
i_perm
?
max_seqlen_k
:
nhead
*
max_seqlen_k
);
const
ck_tile
::
index_t
stride_dbias
=
(
max_seqlen_k
);
// setup nhead_stride_* arguments
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_k
=
(
i_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
);
...
...
@@ -464,8 +464,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_do
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_lsed
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_dbias
=
(
i_perm
?
shape_seqlen_q
*
max_seqlen_k
:
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_dbias
=
0
;
// setup batch_stride_* arguments
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_k
=
(
nhead_k
*
shape_seqlen_k
*
hdim_q
);
...
...
@@ -477,7 +476,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_lsed
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_dk
=
(
nhead
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_dv
=
(
nhead
*
shape_seqlen_k
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_dbias
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
)
;
const
ck_tile
::
index_t
batch_stride_dbias
=
0
;
const
ck_tile
::
index_t
split_stride_dq_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
...
...
@@ -813,6 +812,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
dv_buf
.
FromDevice
(
dv_host
.
data
());
dbias_buf
.
FromDevice
(
dbias_host
.
data
());
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_result
(
{
1
,
1
,
max_seqlen_q
,
max_seqlen_k
});
// dbias_g_m_n
if
(
use_dbias
){
if
(
i_perm
)
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
idx
[
0
],
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
else
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
idx
[
0
],
idx
[
2
],
idx
[
1
],
idx
[
3
]);
});
}
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_ref
(
{
1
,
1
,
max_seqlen_q
,
max_seqlen_k
});
// dbias_g_m_n
dbias_host_ref
.
SetZero
();
for
(
ck_tile
::
index_t
wb
=
0
;
wb
<
batch
;
++
wb
)
{
const
ck_tile
::
index_t
real_seqlen_q
=
seqstart_q_host
[
wb
+
1
]
-
seqstart_q_host
[
wb
];
...
...
@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// ds_g_m_n low precision
ck_tile
::
HostTensor
<
AccDataType
>
dp_hp_host_ref
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dp_g_m_n high precision
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_ref
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dbias_g_m_n
ck_tile
::
HostTensor
<
QGradDataType
>
dq_host_ref
({
nhead
,
real_seqlen_q
,
hdim_q
});
// dq_g_m_k
ck_tile
::
HostTensor
<
KGradDataType
>
dk_host_ref
({
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_ref
({
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
...
...
@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
use_dbias
)
{
ds_hp_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
dbias_host_ref
(
idx
)
=
ck_tile
::
type_convert
<
BiasGradDataType
>
(
self
(
idx
));
dbias_host_ref
(
0
,
0
,
idx
[
1
],
idx
[
2
]
)
+
=
ck_tile
::
type_convert
<
BiasGradDataType
>
(
self
(
idx
));
});
}
...
...
@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_result
(
{
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_result
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dbias_g_m_n
// clang-format off
// permute
...
...
@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
i_perm
)
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
0
],
idx
[
1
]
+
key_offset
,
idx
[
2
]);
});
else
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
1
]
+
key_offset
,
idx
[
0
],
idx
[
2
]);
});
if
(
use_dbias
)
{
if
(
i_perm
)
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
});
else
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
b
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
]);
});
}
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
...
...
@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
rtol
,
atol
);
bool
dbias_cur_pass
=
true
;
if
(
use_dbias
)
{
dbias_cur_pass
=
ck_tile
::
check_err
(
dbias_host_result
,
dbias_host_ref
,
std
::
string
(
"Error: BiasGrad Incorrect results!"
),
rtol
,
atol
);
}
pass
&=
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
&
dbias_cur_pass
);
if
(
!
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
&
dbias_cur_pass
))
pass
&=
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
);
if
(
!
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
))
{
std
::
cerr
<<
"mismatch found at batch: "
<<
wb
<<
std
::
endl
<<
"
\t
seqlen_q: "
<<
real_seqlen_q
<<
std
::
endl
...
...
@@ -972,6 +964,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
bool
dbias_cur_pass
=
true
;
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
if
(
use_dbias
){
dbias_cur_pass
=
ck_tile
::
check_err
(
dbias_host_result
,
dbias_host_ref
,
std
::
string
(
"Error: BiasGrad Incorrect results!"
),
rtol
,
atol
);
}
pass
&=
dbias_cur_pass
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment