Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b1e544e2
"...composable_kernel_rocm.git" did not exist on "836b7e557d028cc2d7c6b341352253fd81003e54"
Commit
b1e544e2
authored
Nov 16, 2022
by
Anthony Chang
Browse files
ready to plug in kernel
parent
4f6d52c1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
182 additions
and
106 deletions
+182
-106
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+182
-106
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
b1e544e2
...
...
@@ -15,6 +15,8 @@ Outputs:
*/
#define PRINT_HOST 1
#include <iostream>
#include <numeric>
#include <initializer_list>
...
...
@@ -97,38 +99,60 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough
,
PassThrough
,
Scale
>
;
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorV
,
typename
TensorS
,
typename
TensorP
,
typename
TensorY
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
const
TensorK
&
k_g_n_k
,
const
TensorV
&
v_g_n_o
,
const
float
alpha
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
)
{
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
q_g_m_k
,
k_g_k_n
,
s_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
#if 0
// Ref Gemm dP: dP = dY * V^T
// fp16 in, fp16 out
using ReferenceGemmPGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Ref Gemm dQ: dQ = alpha * dS * K
// fp16 in, fp16 out
using ReferenceGemmQGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
DataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref Gemm dK: dK = alpha * dS^T * Q
// fp16 in, fp16 out
using ReferenceGemmKGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
DataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif
// P = Softmax(S)
// >>> scipy.special.softmax(numpy.eye(4), 1)
// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
auto
ref_softmax_invoker
=
ref_softmax
.
MakeInvoker
();
auto
ref_softmax_argument
=
ref_softmax
.
MakeArgument
(
s_g_m_n
,
p_g_m_n
,
1
,
0
,
{
2
});
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// Y = P * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
...
...
@@ -136,9 +160,9 @@ int run(int argc, char* argv[])
bool
time_kernel
=
false
;
// Overall QKV matrices shape
//
Y
_g_m_o = Softmax(Q_g_m_k * K_g_k_n) * V_g_n_o
//
Y
_g0_g1_m_o = reshape(
Y
_g_m_o, [G0, G1, M, O])
//
Y
_g0_m_g1_o = permute(
Y
_g0_g1_m_o, [0, 2, 1, 3])
//
y
_g_m_o = Softmax(
alpha *
Q_g_m_k * K_g_k_n) * V_g_n_o
//
y
_g0_g1_m_o = reshape(
y
_g_m_o, [G0, G1, M, O])
//
y
_g0_m_g1_o = permute(
y
_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
4
;
ck
::
index_t
N
=
4
;
ck
::
index_t
K
=
4
;
...
...
@@ -219,20 +243,13 @@ int run(int argc, char* argv[])
Tensor
<
DataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
v_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
// Tensor<DataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor
<
DataType
>
qgrad_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
// Tensor<DataType> qgrad_gs_ms_ks_device(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
// Tensor<DataType> kgrad_gs_ns_ks_device(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
// Tensor<DataType> vgrad_gs_os_ns_device(v_gs_os_ns_lengths, v_gs_os_ns_strides);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -262,15 +279,38 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
10
});
}
#if 0
// calculate y beforehand
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
);
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
k_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
v_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
#endif
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
// TODO ANT: attention backward kernel
#if 0
...
...
@@ -280,7 +320,7 @@ int run(int argc, char* argv[])
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y
grad
_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
...
...
@@ -323,44 +363,31 @@ int run(int argc, char* argv[])
<< gemm.GetTypeString() << std::endl;
#endif
bool
pass
=
true
;
if
(
do_verification
)
{
Tensor
<
DataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
Tensor
<
DataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
DataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object in bwd pass
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object in bwd pass
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
// scratch object in bwd pass
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
// permute
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
// TODO ANT: os_ns -> ns_os ?
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
){
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
*
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"v_g_n_o ref:
\n
"
<<
v_g_n_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
}
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
...
...
@@ -400,77 +427,126 @@ int run(int argc, char* argv[])
// Gradients
auto
ref_gemm_grad
=
ReferenceGemmGradInstance
{};
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
// dP = dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
}
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
){
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
for
(
int
o
=
0
;
o
<
O
;
o
++
)
{
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ygrad_g_m_o
(
idx_gmo
)
*
y_g_m_o
(
idx_gmo
);
}
self
(
idx_gmn
)
=
p_g_m_n
(
idx_gmn
)
*
(
pgrad_g_m_n
(
idx_gmn
)
-
ygrad_dot_y
);
});
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"y_g_m_o ref:
\n
"
<<
y_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
\n
"
;
std
::
cout
<<
"p_g_m_n ref:
\n
"
<<
p_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"y_g_m_o ref:
\n
"
<<
y_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
}
// dV = P^T * dY
auto
p_g_n_m
=
p_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
p_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_g_n_m ref:
\n
"
<<
p_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_g_n_m ref:
\n
"
<<
p_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
}
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"sgrad_g_m_n ref:
\n
"
<<
sgrad_g_m_n
;
std
::
cout
<<
"k_g_n_k ref:
\n
"
<<
k_g_n_k
;
std
::
cout
<<
"qgrad_g_m_k ref:
\n
"
<<
qgrad_g_m_k
;
}
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
if
(
PRINT_HOST
)
{
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"sgrad_g_n_m ref:
\n
"
<<
sgrad_g_n_m
;
std
::
cout
<<
"q_g_m_k ref:
\n
"
<<
q_g_m_k
;
std
::
cout
<<
"kgrad_g_n_k ref:
\n
"
<<
kgrad_g_n_k
;
}
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
vgrad_device_buf
.
FromDevice
(
vgrad_gs_os_ns_device_result
.
mData
.
data
());
// permute
// y_gs_ms_os.ForEach([&](auto& self, auto idx) {
// const size_t& g0 = idx[0];
// const size_t& g1 = idx[1];
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
//
const size_t g = g0 * G1 + g1;
const
size_t
g
=
g0
*
G1
+
g1
;
// self(idx) = y_g_m_o(g, idx[2], idx[3]);
// });
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
// return ck::utils::check_err(y_gs_ms_os_device_result.mData, y_gs_ms_os.mData)
// ? 0
// : 1;
std
::
cout
<<
"Checking qgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
qgrad_gs_ms_ks_device_result
.
mData
,
qgrad_gs_ms_ks_host_result
.
mData
);
std
::
cout
<<
"Checking kgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
kgrad_gs_ns_ks_device_result
.
mData
,
kgrad_gs_ns_ks_host_result
.
mData
);
std
::
cout
<<
"Checking vgrad:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
vgrad_gs_os_ns_device_result
.
mData
,
vgrad_gs_os_ns_host_result
.
mData
);
}
return
0
;
return
pass
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment