Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2464edd0
Commit
2464edd0
authored
Aug 29, 2023
by
letaoqin
Browse files
add comments for grouped host code
parent
28459058
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+5
-2
No files found.
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
2464edd0
...
@@ -767,12 +767,13 @@ int run(int argc, char* argv[])
...
@@ -767,12 +767,13 @@ int run(int argc, char* argv[])
auto
v_g_o_n
=
v_g_n_os
[
i
].
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_os
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm0_grad_invoker
.
Run
(
RefGemm0GradArg
{
ref_gemm0_grad_invoker
.
Run
(
RefGemm0GradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_drop_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
// dP = dP_dropout x Z
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_ns
[
i
],
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
z_g_m_ns
[
i
],
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
for
(
int
o
=
0
;
o
<
O
;
o
++
)
...
@@ -786,12 +787,14 @@ int run(int argc, char* argv[])
...
@@ -786,12 +787,14 @@ int run(int argc, char* argv[])
ck
::
type_convert
<
AccDataType
>
(
p_g_m_ns
[
i
](
idx_gmn
))
*
ck
::
type_convert
<
AccDataType
>
(
p_g_m_ns
[
i
](
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
});
// dV = P_drop^T * dY
auto
p_drop_g_n_m
=
p_drop_g_m_ns
[
i
].
Transpose
({
0
,
2
,
1
});
auto
p_drop_g_n_m
=
p_drop_g_m_ns
[
i
].
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
p_drop_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.0
f
}});
// dQ = alpha * dS * K
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_m_n
,
k_g_n_ks
[
i
],
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_m_n
,
k_g_n_ks
[
i
],
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
ref_gemm1_grad_invoker
.
Run
(
RefGemm1GradArg
{
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment