Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
98df59c6
Commit
98df59c6
authored
Aug 16, 2023
by
letaoqin
Browse files
bias data type convert
parent
64c9f790
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
6 deletions
+7
-6
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
...ten_bias/run_batched_multihead_attention_bias_forward.inc
+2
-2
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
...ten_bias/run_grouped_multihead_attention_bias_forward.inc
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
...ion/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
+3
-2
No files found.
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
View file @
98df59c6
...
@@ -322,12 +322,12 @@ int run(int argc, char* argv[])
...
@@ -322,12 +322,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// bias
// bias
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
d_g_m_n
(
idx
);
});
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
ck
::
type_convert
<
AccDataType
>
(
d_g_m_n
(
idx
)
)
;
});
// masking
// masking
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
self
(
idx
)
=
-
ck
::
NumericLimits
<
AccDataType
>::
Infinity
();
});
});
// softmax
// softmax
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc
View file @
98df59c6
...
@@ -396,12 +396,12 @@ int run(int argc, char* argv[])
...
@@ -396,12 +396,12 @@ int run(int argc, char* argv[])
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// bias
// bias
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
d_g_m_n
(
idx
);
});
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
ck
::
type_convert
<
AccDataType
>
(
d_g_m_n
(
idx
)
)
;
});
// masking
// masking
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
self
(
idx
)
=
-
ck
::
NumericLimits
<
AccDataType
>::
Infinity
();
});
});
// softmax
// softmax
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
98df59c6
...
@@ -1307,8 +1307,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -1307,8 +1307,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_thread_buf
);
d0_thread_buf
);
// acc add bias
// acc add bias
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}(
static_for
<
0
,
m0
*
n0
*
n2
*
n4
,
1
>
{}([
&
](
auto
i
)
{
[
&
](
auto
i
)
{
acc_thread_buf
(
i
)
+=
d0_thread_buf
[
i
];
});
acc_thread_buf
(
i
)
+=
ck
::
type_convert
<
FloatGemmAcc
>
(
d0_thread_buf
[
i
]);
});
d0_threadwise_copy
.
MoveSrcSliceWindow
(
d0_threadwise_copy
.
MoveSrcSliceWindow
(
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment