Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9b4c780a
Commit
9b4c780a
authored
Jul 14, 2023
by
danyao12
Browse files
recovery kloop v1
parent
fb445cb6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
9 deletions
+19
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...u/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+19
-9
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
9b4c780a
...
@@ -1945,10 +1945,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -1945,10 +1945,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
bool
masked_flag
=
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
);
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
))
s_element_op
(
s_slash_p_thread_buf
(
i
),
{
masked_flag
?
-
ck
::
NumericLimits
<
float
>::
Infinity
()
s_slash_p_thread_buf
(
i
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
:
s_slash_p_thread_buf
[
i
]);
}
else
{
s_element_op
(
s_slash_p_thread_buf
(
i
),
s_slash_p_thread_buf
[
i
]);
}
});
});
}
}
else
else
...
@@ -2011,11 +2015,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -2011,11 +2015,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
constexpr
auto
m
=
constexpr
auto
m
=
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
// dS and P has same thread buf layout
// dS and P has same thread buf layout
bool
undropped_flag
=
s_slash_p_thread_buf
[
i
]
>=
0
;
if
(
s_slash_p_thread_buf
[
i
]
>=
0
)
sgrad_thread_buf
(
i
)
=
{
s_slash_p_thread_buf
[
i
]
*
sgrad_thread_buf
(
i
)
=
(
undropped_flag
?
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}])
s_slash_p_thread_buf
[
i
]
*
:
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
}
else
{
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}];
}
});
});
// gemm dQ
// gemm dQ
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment