gaoqiong / composable_kernel / Commits

Commit 3185cbf9, authored Jan 12, 2023 by fsx950223

    add verification

Parent: fe6ee651

Showing 2 changed files with 141 additions and 207 deletions:

- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp (+128 -169)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp (+13 -38)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp (view file @ 3185cbf9)
@@ -244,7 +244,7 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    float K     = 64;
+    float K     = 128;
     float alpha = 1.f / std::sqrt(K);
     bool input_permute = false;
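For reference, the forward pass this example verifies is the one stated in the comments above. Per flattened batch $g$,

$$ S = \alpha\, Q K^{\top}, \qquad P = \operatorname{softmax}(S) \ \text{(row-wise)}, \qquad Y = P V, \qquad \alpha = \frac{1}{\sqrt{d_k}}, $$

where the head dimension $d_k$ is the variable `K` above (bumped from 64 to 128 by this commit). The `lse_g_m` tensor used below presumably stores the per-row softmax log-sum-exp, $\mathrm{lse}_m = \log \sum_n e^{S_{mn}}$; that reading follows the usual attention-kernel convention and is an assumption, not something the diff states.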
@@ -297,6 +297,12 @@ int run(int argc, char* argv[])
     std::vector<DataType*> p_vgrad;
     std::vector<const DataType*> p_ygrad;
+    std::vector<Tensor<DataType>> q_g_m_ks;
+    std::vector<Tensor<DataType>> k_g_n_ks;
+    std::vector<Tensor<DataType>> v_g_n_os;
+    std::vector<Tensor<AccDataType>> s_g_m_ns;
+    std::vector<Tensor<DataType>> p_g_m_ns;
+    std::vector<Tensor<DataType>> y_g_m_os;
     std::vector<Tensor<DataType>> q_tensors;
     std::vector<Tensor<DataType>> k_tensors;
     std::vector<Tensor<DataType>> v_tensors;
@@ -478,14 +484,20 @@ int run(int argc, char* argv[])
             [&](auto& self, auto idx) {
                 v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
             });
         lse_gs_ms.ForEach(
             [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

         run_attention_fwd_host(
             q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);

         y_gs_ms_os.ForEach(
             [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
         lse_gs_ms.ForEach(
             [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });

+        q_g_m_ks.push_back(q_g_m_k);
+        k_g_n_ks.push_back(k_g_n_k);
+        v_g_n_os.push_back(v_g_n_o);
+        s_g_m_ns.push_back(s_g_m_n);
+        p_g_m_ns.push_back(p_g_m_n);
+        y_g_m_os.push_back(y_g_m_o);
         q_tensors.push_back(q_gs_ms_ks);
         k_tensors.push_back(k_gs_ns_ks);
         v_tensors.push_back(v_gs_os_ns);
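All of the `ForEach` lambdas above share one indexing convention: the two leading batch dimensions `(g0, g1)` are flattened into a single batch index `g = idx[0] * G1 + idx[1]`, and the V copy additionally swaps its last two axes. A minimal standalone sketch of that mapping (the helper names are illustrative, not part of the CK `Tensor` API):

    #include <array>
    #include <cstddef>

    // Flatten the (g0, g1) pair into one batch index, as in the lambdas above.
    inline std::size_t flatten_batch(std::size_t g0, std::size_t g1, std::size_t G1)
    {
        return g0 * G1 + g1; // row-major over the two leading batch dimensions
    }

    // Mirror of "v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);":
    // a 4-D (g0, g1, o, n) index becomes a 3-D (g, n, o) index.
    inline std::array<std::size_t, 3> to_g_n_o(const std::array<std::size_t, 4>& idx,
                                               std::size_t G1)
    {
        return {flatten_batch(idx[0], idx[1], G1), idx[3], idx[2]};
    }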
@@ -566,172 +578,119 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << gemm.GetTypeString() << std::endl;
-    // bool pass = true;
-    // if(do_verification)
-    // {
-    //     kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
-    //     vgrad_device_buf.SetZero();
-    //     invoker.Run(argument, StreamConfig{nullptr, false});
-    //     Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-    //     Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-    //     Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
-    //     Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-    //     Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-    //     Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
-    //     Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
-    //     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-    //         ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
-    //     });
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-    //         std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-    //         std::cout << "v_g_n_o ref:\n" << v_g_n_o;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //     }
-    // #endif
-    //     // Gradients
-    //     auto ref_gemm_grad = ReferenceGemmGradInstance{};
-    //     auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-    //     using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
-    //     // dP = dY * V^T
-    //     auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dP = dY * V^T\n";
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-    //         std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-    //     }
-    // #endif
-    //     // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-    //     sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
-    //         float ygrad_dot_y = 0;
-    //         for(int o = 0; o < O; o++)
-    //         {
-    //             auto idx_gmo = idx_gmn;
-    //             idx_gmo[2] = o;
-    //             ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
-    //         }
-    //         self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
-    //     });
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
-    //         std::cout << "p_g_m_n ref:\n" << p_g_m_n;
-    //         std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-    //         std::cout << "y_g_m_o ref:\n" << y_g_m_o;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-    //     }
-    // #endif
-    //     // dV = P^T * dY
-    //     auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dV = P^T * dY\n";
-    //         std::cout << "p_g_n_m ref:\n" << p_g_n_m;
-    //         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-    //         std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
-    //     }
-    // #endif
-    //     // dQ = alpha * dS * K
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dQ = alpha * dS * K\n";
-    //         std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-    //         std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-    //         std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
-    //     }
-    // #endif
-    //     // dK = alpha * dS^T * Q
-    //     auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-    //     ref_gemm_grad_invoker.Run(RefGemmGradArg{
-    //         sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-    // #if PRINT_HOST
-    //     {
-    //         std::cout << "===== dK = alpha * dS^T * Q\n";
-    //         std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
-    //         std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-    //         std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
-    //     }
-    // #endif
-    //     Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    //     Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    //     Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    //     Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    //     Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    //     Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-    //     qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
-    //     kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
-    //     vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
-    //     // permute
-    //     qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g = g0 * G1 + g1;
-    //         self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
-    //     });
-    //     kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g = g0 * G1 + g1;
-    //         self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
-    //     });
-    //     vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
-    //         const size_t& g0 = idx[0];
-    //         const size_t& g1 = idx[1];
-    //         const size_t g = g0 * G1 + g1;
-    //         self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
-    //     });
-    //     std::cout << "Checking qgrad:\n";
-    //     pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
-    //                                  qgrad_gs_ms_ks_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    //     std::cout << "Checking kgrad:\n";
-    //     pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
-    //                                  kgrad_gs_ns_ks_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    //     std::cout << "Checking vgrad:\n";
-    //     pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
-    //                                  vgrad_gs_os_ns_host_result.mData,
-    //                                  "error",
-    //                                  1e-2,
-    //                                  1e-2);
-    // }
-    // return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
-    return 0;
+    bool pass = true;
+    if(do_verification)
+    {
+        for(int i = 0; i < group_count; i++){
+            qgrad_tensors_device[i]->SetZero();
+            kgrad_tensors_device[i]->SetZero();
+            vgrad_tensors_device[i]->SetZero();
+        }
+        invoker.Run(argument, StreamConfig{nullptr, false});
+        for(std::size_t i = 0; i < group_count; i++){
+            int G0 = v_tensors[i].GetLengths()[0];
+            int G1 = v_tensors[i].GetLengths()[1];
+            int O  = v_tensors[i].GetLengths()[2];
+            int N  = v_tensors[i].GetLengths()[3];
+            int M  = q_tensors[i].GetLengths()[2];
+            int K  = q_tensors[i].GetLengths()[3];
+            int BatchCount = G0 * G1;
+            Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
+            Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
+            Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
+            Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
+            Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+            Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
+            Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
+            ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
+                ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            });
+            auto ref_gemm_grad         = ReferenceGemmGradInstance{};
+            auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
+            using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
+            // dP = dY * V^T
+            auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
+            sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
+                float ygrad_dot_y = 0;
+                for(int o = 0; o < O; o++)
+                {
+                    auto idx_gmo = idx_gmn;
+                    idx_gmo[2]   = o;
+                    ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_os[i](idx_gmo);
+                }
+                self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+            });
+            // dV = P^T * dY
+            auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+            // dQ = alpha * dS * K
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+            // dK = alpha * dS^T * Q
+            auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
+            ref_gemm_grad_invoker.Run(RefGemmGradArg{
+                sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+            Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
+                                                        q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
+                                                        k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
+                                                        v_tensors[i].GetStrides());
+            Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
+                                                          q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
+                                                          k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
+                                                          v_tensors[i].GetStrides());
+            qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
+            kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+            vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
+            // permute
+            qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
+            });
+            kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
+            });
+            vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0];
+                const size_t& g1 = idx[1];
+                const size_t g   = g0 * G1 + g1;
+                self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
+            });
+            std::cout << "Checking qgrad:\n";
+            pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
+                                         qgrad_gs_ms_ks_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+            std::cout << "Checking kgrad:\n";
+            pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
+                                         kgrad_gs_ns_ks_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+            std::cout << "Checking vgrad:\n";
+            pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
+                                         vgrad_gs_os_ns_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
+        }
+    }
+    return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
 }

 int main(int argc, char* argv[]) { return run(argc, argv); }
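Spelled out, the host reference above is the standard softmax-attention backward pass, matching the inline comments; the `ygrad_dot_y` accumulation is the row-wise dot product term:

$$ dP = dY\, V^{\top}, \qquad dS_{ij} = P_{ij}\left( dP_{ij} - \sum_{o} dY_{io}\, Y_{io} \right), $$
$$ dV = P^{\top} dY, \qquad dQ = \alpha\, dS\, K, \qquad dK = \alpha\, dS^{\top} Q. $$

Each matrix product is realized as a `ReferenceGemmGradInstance` call, and the results are permuted back to the `gs_ms_ks`-style layouts before `ck::utils::check_err` compares device against host with absolute and relative tolerances of 1e-2.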
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle.hpp (view file @ 3185cbf9)
@@ -50,7 +50,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
     const auto arg_ptr     = reinterpret_cast<const GroupKernelArg*>(
         cast_pointer_to_generic_address_space(group_kernel_args));

     index_t left = 0;
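With every group's workgroups launched in one grid, each block must first map its flat block id back to a group; the `left` variable above is the start of that search. The rest of the kernel body is truncated in this diff view, so the following is only a hedged host-side sketch of one common approach, a binary search over contiguous per-group block ranges (`GroupArg`, `block_start_`, and `block_end_` are stand-in names, not the exact CK implementation):

    #include <cassert>
    #include <vector>

    // Stand-in for the per-group kernel argument; only the block range matters here.
    struct GroupArg
    {
        int block_start_; // first block id owned by this group (inclusive)
        int block_end_;   // one past the last block id owned by this group
    };

    // Binary search: find the group whose [block_start_, block_end_) contains block_id.
    // Assumes the ranges are contiguous, sorted, and start at 0.
    int find_group(const std::vector<GroupArg>& args, int block_id)
    {
        int left = 0, right = static_cast<int>(args.size());
        while(left < right)
        {
            const int mid = (left + right) / 2;
            if(block_id >= args[mid].block_end_) { left = mid + 1; }
            else                                 { right = mid; }
        }
        return left;
    }

    int main()
    {
        const std::vector<GroupArg> args{{0, 4}, {4, 10}, {10, 16}};
        assert(find_group(args, 0) == 0);
        assert(find_group(args, 7) == 1);
        assert(find_group(args, 15) == 2);
        return 0;
    }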
@@ -718,9 +718,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
             const auto p_c_grid     = static_cast<const DataType*>(p_Cs[i]);
             const auto p_lse_grid   = static_cast<const LSEDataType*>(p_LSEs[i]);
             const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
             auto p_qgrad_grid       = static_cast<DataType*>(p_Qgrads[i]);
             auto p_kgrad_grid       = static_cast<DataType*>(p_Kgrads[i]);
             auto p_vgrad_grid       = static_cast<DataType*>(p_Vgrads[i]);

             const auto& problem_desc = problem_desc_vec[i];
@@ -844,31 +844,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
             // ignore = acc1_biases_gs_ms_gemm1ns_strides;
         }

-        // void Print() const
-        // {
-        //     std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
-        //               << a_grid_desc_g_m_k_.GetLength(I1) << ", "
-        //               << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
-        //     // a_grid_desc_g_m_k_.Print();
-        //     std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
-        //               << b_grid_desc_g_n_k_.GetLength(I1) << ", "
-        //               << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
-        //     // b_grid_desc_g_n_k_.Print();
-        //     std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
-        //               << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
-        //               << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
-        //     // b1_grid_desc_g_n_k_.Print();
-        //     std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
-        //               << c_grid_desc_g_m_n_.GetLength(I1) << ", "
-        //               << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
-        //     // c_grid_desc_g_m_n_.Print();
-        //     std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
-        //               << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
-        //     std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
-        //               << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
-        //               << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
-        // }

         // element-wise op
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -914,15 +889,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_) {
             const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
                                                                                  GroupKernelArg,
                                                                                  AElementwiseOperation,
                                                                                  BElementwiseOperation,
                                                                                  AccElementwiseOperation,
                                                                                  B1ElementwiseOperation,
                                                                                  CElementwiseOperation,
                                                                                  has_main_k_block_loop_>;

             return launch_and_time_kernel(stream_config,
...
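The `launch_kernel` lambda receives `has_main_k_block_loop_` as a compile-time value, so the flag is baked into the instantiated kernel type rather than branched on at run time. A minimal, self-contained sketch of this integral_constant dispatch pattern (generic C++; the names and the `run_gemm` stand-in are illustrative, not CK's exact API):

    #include <iostream>
    #include <type_traits>

    // Kernel stand-in: the bool is part of the type, so the branch below is
    // resolved at compile time in each instantiation.
    template <bool HasMainKBlockLoop>
    void run_gemm()
    {
        if constexpr(HasMainKBlockLoop) { std::cout << "main K loop path\n"; }
        else                            { std::cout << "tail-only path\n"; }
    }

    int main()
    {
        auto launch_kernel = [&](auto has_main_k_block_loop_) {
            // has_main_k_block_loop_ is std::true_type or std::false_type here.
            run_gemm<decltype(has_main_k_block_loop_)::value>();
        };

        bool has_loop = true; // runtime decision -> compile-time instantiation
        if(has_loop) { launch_kernel(std::integral_constant<bool, true>{}); }
        else         { launch_kernel(std::integral_constant<bool, false>{}); }
        return 0;
    }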