Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
41631747
Commit
41631747
authored
Nov 14, 2024
by
Ye Wang
Browse files
reproducer for dbias in shape 1hss
parent
0c9012fb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
31 deletions
+35
-31
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+2
-2
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+33
-29
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
41631747
...
@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -501,7 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
continue
if
receipt
==
3
:
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
,
'bias'
,
'alibi'
]
cond
&=
dpad
==
dvpad
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
41631747
...
@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
));
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
));
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host
(
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host
(
bias
.
type
==
bias_enum
::
elementwise_bias
bias
.
type
==
bias_enum
::
elementwise_bias
?
get_lengths
(
i_perm
,
1
,
1
,
shape_seqlen_q
,
max_seqlen_k
)
?
get_lengths
(
i_perm
,
1
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
AccDataType
>
alibi_slope_host
(
ck_tile
::
HostTensor
<
AccDataType
>
alibi_slope_host
(
bias
.
type
==
bias_enum
::
alibi
bias
.
type
==
bias_enum
::
alibi
...
@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -321,7 +321,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
get_lengths
(
o_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_v
));
get_lengths
(
o_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_v
));
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host
(
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host
(
use_dbias
use_dbias
?
get_lengths
(
i_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
?
get_lengths
(
i_perm
,
1
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
AccDataType
>
dq_acc_host
(
ck_tile
::
HostTensor
<
AccDataType
>
dq_acc_host
(
i_perm
i_perm
...
@@ -448,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -448,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
stride_q
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_q
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_v
=
(
i_perm
?
hdim_v
:
nhead_k
*
hdim_v
);
const
ck_tile
::
index_t
stride_v
=
(
i_perm
?
hdim_v
:
nhead_k
*
hdim_v
);
const
ck_tile
::
index_t
stride_bias
=
(
max
_seqlen_k
)
;
const
ck_tile
::
index_t
stride_bias
=
i_perm
?
shape_seqlen_k
:
nhead
*
shape
_seqlen_k
;
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_do
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_do
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
...
@@ -459,7 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -459,7 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_k
=
(
i_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_k
=
(
i_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_v
=
(
i_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_v
=
(
i_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_bias
=
0
;
const
ck_tile
::
index_t
nhead_stride_bias
=
i_perm
?
shape_seqlen_q
*
shape_seqlen_k
:
shape_seqlen_k
;
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_do
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_do
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
...
@@ -477,7 +477,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -477,7 +477,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_lsed
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lsed
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_dk
=
(
nhead
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_dk
=
(
nhead
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_dv
=
(
nhead
*
shape_seqlen_k
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_dv
=
(
nhead
*
shape_seqlen_k
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_dbias
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
)
;
const
ck_tile
::
index_t
batch_stride_dbias
=
0
;
const
ck_tile
::
index_t
split_stride_dq_acc
=
const
ck_tile
::
index_t
split_stride_dq_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
(
shape_batch
*
nhead
*
shape_seqlen_q
*
hdim_q
);
...
@@ -657,12 +657,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -657,12 +657,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
bias
.
type
==
bias_enum
::
elementwise_bias
)
if
(
bias
.
type
==
bias_enum
::
elementwise_bias
)
{
{
// elementwise bias
// elementwise bias
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host_ref
({
1
,
real_seqlen_q
,
real_seqlen_k
});
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host_ref
({
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// clang-format off
// clang-format off
if
(
i_perm
)
if
(
i_perm
)
bias_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
bias_host
(
0
,
0
,
i
[
1
]
+
query_offset
,
i
[
2
]);
});
bias_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
bias_host
(
0
,
i
[
0
]
,
i
[
1
]
+
query_offset
,
i
[
2
]);
});
else
else
bias_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
bias_host
(
0
,
i
[
1
]
+
query_offset
,
0
,
i
[
2
]);
});
bias_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
bias_host
(
0
,
i
[
1
]
+
query_offset
,
i
[
0
]
,
i
[
2
]);
});
// clang-format on
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
...
@@ -813,6 +813,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -813,6 +813,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
dv_buf
.
FromDevice
(
dv_host
.
data
());
dv_buf
.
FromDevice
(
dv_host
.
data
());
dbias_buf
.
FromDevice
(
dbias_host
.
data
());
dbias_buf
.
FromDevice
(
dbias_host
.
data
());
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_result
(
{
nhead
,
max_seqlen_q
,
max_seqlen_k
});
// dbias_g_m_n
if
(
use_dbias
){
if
(
i_perm
)
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
0
,
idx
[
0
],
idx
[
1
],
idx
[
2
]);
});
else
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
0
,
idx
[
1
],
idx
[
0
],
idx
[
2
]);
});
}
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_ref
(
{
nhead
,
max_seqlen_q
,
max_seqlen_k
});
// dbias_g_m_n
dbias_host_ref
.
SetZero
();
for
(
ck_tile
::
index_t
wb
=
0
;
wb
<
batch
;
++
wb
)
for
(
ck_tile
::
index_t
wb
=
0
;
wb
<
batch
;
++
wb
)
{
{
const
ck_tile
::
index_t
real_seqlen_q
=
seqstart_q_host
[
wb
+
1
]
-
seqstart_q_host
[
wb
];
const
ck_tile
::
index_t
real_seqlen_q
=
seqstart_q_host
[
wb
+
1
]
-
seqstart_q_host
[
wb
];
...
@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -830,8 +840,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// ds_g_m_n low precision
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// ds_g_m_n low precision
ck_tile
::
HostTensor
<
AccDataType
>
dp_hp_host_ref
(
ck_tile
::
HostTensor
<
AccDataType
>
dp_hp_host_ref
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dp_g_m_n high precision
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dp_g_m_n high precision
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_ref
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dbias_g_m_n
ck_tile
::
HostTensor
<
QGradDataType
>
dq_host_ref
({
nhead
,
real_seqlen_q
,
hdim_q
});
// dq_g_m_k
ck_tile
::
HostTensor
<
QGradDataType
>
dq_host_ref
({
nhead
,
real_seqlen_q
,
hdim_q
});
// dq_g_m_k
ck_tile
::
HostTensor
<
KGradDataType
>
dk_host_ref
({
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
ck_tile
::
HostTensor
<
KGradDataType
>
dk_host_ref
({
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_ref
({
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_ref
({
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
...
@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -870,7 +878,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
use_dbias
)
if
(
use_dbias
)
{
{
ds_hp_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ds_hp_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
dbias_host_ref
(
idx
)
=
ck_tile
::
type_convert
<
BiasGradDataType
>
(
self
(
idx
));
dbias_host_ref
(
idx
)
+
=
ck_tile
::
type_convert
<
BiasGradDataType
>
(
self
(
idx
));
});
});
}
}
...
@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -912,8 +920,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
{
nhead
,
real_seqlen_k
,
hdim_q
});
// dk_g_n_k
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_result
(
ck_tile
::
HostTensor
<
VGradDataType
>
dv_host_result
(
{
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
{
nhead
,
real_seqlen_k
,
hdim_v
});
// dv_g_n_o
ck_tile
::
HostTensor
<
BiasGradDataType
>
dbias_host_result
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
// dbias_g_m_n
// clang-format off
// clang-format off
// permute
// permute
...
@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -926,11 +932,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
i_perm
)
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
0
],
idx
[
1
]
+
key_offset
,
idx
[
2
]);
});
if
(
i_perm
)
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
0
],
idx
[
1
]
+
key_offset
,
idx
[
2
]);
});
else
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
1
]
+
key_offset
,
idx
[
0
],
idx
[
2
]);
});
else
dv_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dv_host
(
b
,
idx
[
1
]
+
key_offset
,
idx
[
0
],
idx
[
2
]);
});
if
(
use_dbias
)
{
if
(
i_perm
)
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
});
else
dbias_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
dbias_host
(
b
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
]);
});
}
// clang-format on
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
...
@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -950,17 +951,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
rtol
,
rtol
,
atol
);
atol
);
bool
dbias_cur_pass
=
true
;
pass
&=
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
);
if
(
use_dbias
)
if
(
!
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
))
{
dbias_cur_pass
=
ck_tile
::
check_err
(
dbias_host_result
,
dbias_host_ref
,
std
::
string
(
"Error: BiasGrad Incorrect results!"
),
rtol
,
atol
);
}
pass
&=
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
&
dbias_cur_pass
);
if
(
!
(
dq_cur_pass
&
dk_cur_pass
&
dv_cur_pass
&
dbias_cur_pass
))
{
{
std
::
cerr
<<
"mismatch found at batch: "
<<
wb
<<
std
::
endl
std
::
cerr
<<
"mismatch found at batch: "
<<
wb
<<
std
::
endl
<<
"
\t
seqlen_q: "
<<
real_seqlen_q
<<
std
::
endl
<<
"
\t
seqlen_q: "
<<
real_seqlen_q
<<
std
::
endl
...
@@ -971,6 +963,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -971,6 +963,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
break
;
break
;
}
}
}
}
bool
dbias_cur_pass
=
true
;
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
if
(
use_dbias
)
{
dbias_cur_pass
=
ck_tile
::
check_err
(
dbias_host_result
,
dbias_host_ref
,
std
::
string
(
"Error: BiasGrad Incorrect results!"
),
rtol
,
atol
);
}
pass
&=
dbias_cur_pass
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment