gaoqiong / composable_kernel · Commits

Commit 3185cbf9, authored Jan 12, 2023 by fsx950223
add verification
parent fe6ee651

Showing 2 changed files with 141 additions and 207 deletions:

  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp (+128, -169)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp (+13, -38)

example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp @ 3185cbf9
@@ -244,7 +244,7 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    float K     = 64;
+    float K     = 128;
     float alpha = 1.f / std::sqrt(K);
     bool input_permute = false;
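Raising the head dimension K from 64 to 128 is safe here only because the softmax scale alpha is derived from K on the following line rather than hard-coded. For reference, the forward pass in the comment above is standard scaled-dot-product attention per batch slice g (a restatement, not part of the diff):

\[
Y_g = \operatorname{softmax}\!\left(\alpha\, Q_g K_g^{\top}\right) V_g,
\qquad \alpha = \frac{1}{\sqrt{K}} .
\]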
@@ -297,6 +297,12 @@ int run(int argc, char* argv[])
     std::vector<DataType*> p_vgrad;
     std::vector<const DataType*> p_ygrad;
+    std::vector<Tensor<DataType>> q_g_m_ks;
+    std::vector<Tensor<DataType>> k_g_n_ks;
+    std::vector<Tensor<DataType>> v_g_n_os;
+    std::vector<Tensor<AccDataType>> s_g_m_ns;
+    std::vector<Tensor<DataType>> p_g_m_ns;
+    std::vector<Tensor<DataType>> y_g_m_os;
     std::vector<Tensor<DataType>> q_tensors;
     std::vector<Tensor<DataType>> k_tensors;
     std::vector<Tensor<DataType>> v_tensors;
@@ -486,6 +492,12 @@ int run(int argc, char* argv[])
     lse_gs_ms.ForEach(
         [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
+    q_g_m_ks.push_back(q_g_m_k);
+    k_g_n_ks.push_back(k_g_n_k);
+    v_g_n_os.push_back(v_g_n_o);
+    s_g_m_ns.push_back(s_g_m_n);
+    p_g_m_ns.push_back(p_g_m_n);
+    y_g_m_os.push_back(y_g_m_o);
     q_tensors.push_back(q_gs_ms_ks);
     k_tensors.push_back(k_gs_ns_ks);
     v_tensors.push_back(v_gs_os_ns);
@@ -566,172 +578,119 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << gemm.GetTypeString() << std::endl;
-    // bool pass = true;
-    // if(do_verification)
-    // {
-    //     kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
-    //     vgrad_device_buf.SetZero();
-    //     invoker.Run(argument, StreamConfig{nullptr, false});
-    //     Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-    //     Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-    //     Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
-    //     Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-    //     Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-    //     Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
-    //     Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
-    //     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-    //         ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
-    //     });
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-    //         std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-    //         std::cout << "v_g_n_o ref:\n" << v_g_n_o;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //     }
-    // #endif
-    //     // Gradients
-    //     auto ref_gemm_grad         = ReferenceGemmGradInstance{};
-    //     auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-    //     using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
-    //     // dP = dY * V^T
-    //     auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dP = dY * V^T\n";
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-    //         std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-    //     }
-    // #endif
-    //     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-    //     sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
-    //         float ygrad_dot_y = 0;
-    //         for(int o = 0; o < O; o++)
-    //         {
-    //             auto idx_gmo = idx_gmn;
-    //             idx_gmo[2]   = o;
-    //             ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
-    //         }
-    //         self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
-    //     });
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
-    //         std::cout << "p_g_m_n ref:\n" << p_g_m_n;
-    //         std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-    //         std::cout << "y_g_m_o ref:\n" << y_g_m_o;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-    //     }
-    // #endif
-    //     // dV = P^T * dY
-    //     auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dV = P^T * dY\n";
-    //         std::cout << "p_g_n_m ref:\n" << p_g_n_m;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
-    //     }
-    // #endif
-    //     // dQ = alpha * dS * K
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dQ = alpha * dS * K\n";
-    //         std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-    //         std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-    //         std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
-    //     }
-    // #endif
-    //     // dK = alpha * dS^T * Q
-    //     auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dK = alpha * dS^T * Q\n";
-    //         std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
-    //         std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-    //         std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
-    //     }
-    // #endif
-    //     Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    //     Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    //     Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    //     Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    //     Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    //     Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    //     qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
-    //     kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
-    //     vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
-    //     // permute
-    //     qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g   = g0 * G1 + g1;
-    //         self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
-    //     });
-    //     kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g   = g0 * G1 + g1;
-    //         self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
-    //     });
-    //     vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g   = g0 * G1 + g1;
-    //         self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
-    //     });
-    //     std::cout << "Checking qgrad:\n";
-    //     pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
-    //                                  qgrad_gs_ms_ks_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    //     std::cout << "Checking kgrad:\n";
-    //     pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
-    //                                  kgrad_gs_ns_ks_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    //     std::cout << "Checking vgrad:\n";
-    //     pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
-    //                                  vgrad_gs_os_ns_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    // }
-    // return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
-    return 0;
+    bool pass = true;
+    if(do_verification)
+    {
+        for(int i = 0; i < group_count; i++)
+        {
+            qgrad_tensors_device[i]->SetZero();
+            kgrad_tensors_device[i]->SetZero();
+            vgrad_tensors_device[i]->SetZero();
+        }
+        invoker.Run(argument, StreamConfig{nullptr, false});
+        for(std::size_t i = 0; i < group_count; i++)
+        {
+            int G0         = v_tensors[i].GetLengths()[0];
+            int G1         = v_tensors[i].GetLengths()[1];
+            int O          = v_tensors[i].GetLengths()[2];
+            int N          = v_tensors[i].GetLengths()[3];
+            int M          = q_tensors[i].GetLengths()[2];
+            int K          = q_tensors[i].GetLengths()[3];
+            int BatchCount = G0 * G1;
+            Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
+            Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
+            Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
+            Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
+            Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+            Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
+            Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
+            ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
+                ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            });
+            auto ref_gemm_grad         = ReferenceGemmGradInstance{};
+            auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
+            using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
+            // dP = dY * V^T
+            auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
+                float ygrad_dot_y = 0;
+                for(int o = 0; o < O; o++)
+                {
+                    auto idx_gmo = idx_gmn;
+                    idx_gmo[2]   = o;
+                    ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_os[i](idx_gmo);
+                }
+                self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+            });
+            auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+            auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+            Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
+                                                        q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
+                                                        k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
+                                                        v_tensors[i].GetStrides());
+            Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
+                                                          q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
+                                                          k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
+                                                          v_tensors[i].GetStrides());
+            qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
+            kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+            vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
+            // permute
+            qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx)        = qgrad_g_m_k(g, idx[2], idx[3]);
+            });
+            kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx)        = kgrad_g_n_k(g, idx[2], idx[3]);
+            });
+            vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx)        = vgrad_g_n_o(g, idx[3], idx[2]);
+            });
+            std::cout << "Checking qgrad:\n";
+            pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
+                                         qgrad_gs_ms_ks_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+            std::cout << "Checking kgrad:\n";
+            pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
+                                         kgrad_gs_ns_ks_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+            std::cout << "Checking vgrad:\n";
+            pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
+                                         vgrad_gs_os_ns_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+        }
+    }
+    return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
 }

 int main(int argc, char* argv[]) { return run(argc, argv); }
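Taken together, the new verification path recomputes the multi-head-attention backward pass on the host, one group at a time, and compares it against the device buffers. Restating the reference computation above in equation form, per batch slice, following the comments in the code (a restatement, not part of the diff):

\[
\begin{aligned}
dP &= dY\, V^{\top}, \\
dS_{ij} &= P_{ij}\,\bigl(dP_{ij} - dY_i \cdot Y_i\bigr), \\
dV &= P^{\top} dY, \\
dQ &= \alpha\, dS\, K, \\
dK &= \alpha\, dS^{\top} Q.
\end{aligned}
\]

The dS line is the standard softmax backward identity: for P = softmax(S) applied row-wise, dS_j = P_j (dP_j - sum_k dP_k P_k), and because Y = P V and dP = dY V^T, the row-wise correction term sum_k dP_k P_k collapses to dY_i . Y_i. Below is a minimal standalone C++ check of that identity against a central finite difference. It is an illustrative toy, not part of the commit; the logit and gradient values are made up.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// softmax of one logit row, numerically stabilized by max-subtraction
static std::vector<double> softmax(const std::vector<double>& s)
{
    double m = s[0];
    for(double v : s)
        m = std::max(m, v);
    std::vector<double> p(s.size());
    double sum = 0;
    for(std::size_t j = 0; j < s.size(); ++j)
    {
        p[j] = std::exp(s[j] - m);
        sum += p[j];
    }
    for(double& v : p)
        v /= sum;
    return p;
}

int main()
{
    std::vector<double> s{0.3, -1.2, 0.7, 0.1};  // logits S (one row)
    std::vector<double> dp{0.5, -0.4, 0.2, 0.9}; // upstream gradient dP = dL/dP

    // closed form: dS_j = P_j * (dP_j - dot(dP, P))
    std::vector<double> p = softmax(s);
    double dot            = 0;
    for(std::size_t j = 0; j < s.size(); ++j)
        dot += dp[j] * p[j];

    // central finite difference of L(s) = dot(dP, softmax(s))
    const double eps    = 1e-6;
    double max_abs_diff = 0;
    for(std::size_t j = 0; j < s.size(); ++j)
    {
        double ds = p[j] * (dp[j] - dot);

        auto sp = s, sm = s;
        sp[j] += eps;
        sm[j] -= eps;
        std::vector<double> pp = softmax(sp), pm = softmax(sm);
        double lp = 0, lm = 0;
        for(std::size_t k = 0; k < s.size(); ++k)
        {
            lp += dp[k] * pp[k];
            lm += dp[k] * pm[k];
        }
        double ds_fd = (lp - lm) / (2 * eps);

        max_abs_diff = std::max(max_abs_diff, std::fabs(ds - ds_fd));
        std::printf("dS[%zu] closed=%+.8f fd=%+.8f\n", j, ds, ds_fd);
    }
    std::printf("max |closed - fd| = %.3e -> %s\n",
                max_abs_diff,
                max_abs_diff < 1e-6 ? "pass" : "fail");
}

Compiled with any C++11 compiler, the closed form and the finite difference should agree to well within the 1e-6 tolerance the program checks, which is the same identity the example's dS ForEach relies on.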
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp @ 3185cbf9
@@ -844,31 +844,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
             // ignore = acc1_biases_gs_ms_gemm1ns_strides;
         }
-        // void Print() const
-        // {
-        //     std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
-        //               << a_grid_desc_g_m_k_.GetLength(I1) << ", "
-        //               << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
-        //     // a_grid_desc_g_m_k_.Print();
-        //     std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
-        //               << b_grid_desc_g_n_k_.GetLength(I1) << ", "
-        //               << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
-        //     // b_grid_desc_g_n_k_.Print();
-        //     std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
-        //               << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
-        //               << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
-        //     // b1_grid_desc_g_n_k_.Print();
-        //     std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
-        //               << c_grid_desc_g_m_n_.GetLength(I1) << ", "
-        //               << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
-        //     // c_grid_desc_g_m_n_.Print();
-        //     std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
-        //               << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
-        //     std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
-        //               << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
-        //               << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
-        // }
         // element-wise op
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -914,8 +889,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
+            const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
                                                                                  GroupKernelArg,
                                                                                  AElementwiseOperation,
                                                                                  BElementwiseOperation,
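The `launch_kernel` lambda above receives `has_main_k_block_loop_` as a deduced parameter rather than a plain `bool`. This is the usual pattern for turning a runtime decision into a compile-time kernel specialization: the caller passes a `std::integral_constant<bool, ...>`, and the lambda uses its `::value` as a template argument. A minimal host-side sketch of that dispatch style follows (my illustration under that assumption; `run_kernel` and the variable values are hypothetical, not from this file):

#include <iostream>
#include <type_traits>

// Stand-in for the GPU kernel template: the bool is baked in at compile time,
// so each branch gets its own instantiation instead of a runtime branch
// inside the kernel body.
template <bool HasMainKBlockLoop>
void run_kernel()
{
    if constexpr(HasMainKBlockLoop)
        std::cout << "variant with main K-block loop\n";
    else
        std::cout << "tail-only variant\n";
}

int main()
{
    const bool has_main_k_block_loop = true; // would come from the problem size

    // The lambda's auto parameter deduces to std::true_type or std::false_type,
    // so its ::value is a constant expression usable as a template argument.
    auto launch_kernel = [&](auto has_main_k_block_loop_) {
        run_kernel<decltype(has_main_k_block_loop_)::value>();
    };

    if(has_main_k_block_loop)
        launch_kernel(std::integral_constant<bool, true>{});
    else
        launch_kernel(std::integral_constant<bool, false>{});
}

(Requires C++17 for `if constexpr`; a pair of tag-dispatch overloads achieves the same effect in C++14.)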