Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ac3c1563
Commit
ac3c1563
authored
Dec 30, 2022
by
ltqin
Browse files
change if to #if
parent
61f4a7ee
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+12
-6
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
ac3c1563
...
@@ -523,13 +523,14 @@ int run(int argc, char* argv[])
...
@@ -523,13 +523,14 @@ int run(int argc, char* argv[])
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
}
}
#endif
// Gradients
// Gradients
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
...
@@ -540,13 +541,14 @@ int run(int argc, char* argv[])
...
@@ -540,13 +541,14 @@ int run(int argc, char* argv[])
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
}
}
#endif
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
...
@@ -559,7 +561,7 @@ int run(int argc, char* argv[])
...
@@ -559,7 +561,7 @@ int run(int argc, char* argv[])
}
}
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
});
});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
...
@@ -568,41 +570,45 @@ int run(int argc, char* argv[])
...
@@ -568,41 +570,45 @@ int run(int argc, char* argv[])
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
}
}
#endif
// dV = P^T * dY
// dV = P^T * dY
auto
p_g_n_m
=
p_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
p_g_n_m
=
p_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
p_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
p_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_g_n_m ref:
\n
"
<<
p_g_n_m
;
std
::
cout
<<
"p_g_n_m ref:
\n
"
<<
p_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
}
}
#endif
// dQ = alpha * dS * K
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
}
}
#endif
// dK = alpha * dS^T * Q
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
if
(
PRINT_HOST
)
#
if
PRINT_HOST
{
{
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
}
}
#endif
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment