Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bc9e2f25
Commit
bc9e2f25
authored
Jan 29, 2023
by
ltqin
Browse files
host dV correct
parent
4e79cc4b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
173 additions
and
16 deletions
+173
-16
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
...emm/batched_multihead_attention_backward_fp16_dropout.cpp
+79
-16
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
...rary/reference_tensor_operation/cpu/reference_dropout.hpp
+94
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
View file @
bc9e2f25
...
@@ -44,6 +44,7 @@ Kernel outputs:
...
@@ -44,6 +44,7 @@ Kernel outputs:
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -187,12 +188,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
...
@@ -187,12 +188,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>
;
Scale
>
;
// Ref dropout
using
ReferenceDropoutInstance
=
ck
::
tensor_operation
::
host
::
ReferenceDropout
<
ushort
,
DataType
,
DataType
>
;
template
<
typename
TensorQ
,
template
<
typename
TensorQ
,
typename
TensorK
,
typename
TensorK
,
typename
TensorV
,
typename
TensorV
,
typename
TensorS
,
typename
TensorS
,
typename
TensorP
,
typename
TensorP
,
typename
TensorZ
,
typename
TensorY
,
typename
TensorY
,
typename
TensorLSE
=
TensorP
>
typename
TensorLSE
=
TensorP
>
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
void
run_attention_fwd_host
(
const
TensorQ
&
q_g_m_k
,
...
@@ -202,7 +207,10 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
...
@@ -202,7 +207,10 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorS
&
s_g_m_n
,
TensorS
&
s_g_m_n
,
TensorP
&
p_g_m_n
,
TensorP
&
p_g_m_n
,
TensorY
&
y_g_m_o
,
TensorY
&
y_g_m_o
,
TensorLSE
&
lse_g_m
)
TensorLSE
&
lse_g_m
,
TensorP
&
p_drop_g_m_n
,
TensorZ
&
z_g_m_n
,
ushort
p_dropout
)
{
{
// S = alpha * Q * K^T
// S = alpha * Q * K^T
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
auto
k_g_k_n
=
k_g_n_k
.
Transpose
({
0
,
2
,
1
});
...
@@ -230,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
...
@@ -230,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// Y = P * V
// P_dropout
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// std::cout << "p_drop_g_m_n ref:\n" << p_drop_g_m_n;
// Y = P_dropout * V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
p_
drop_
g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
}
...
@@ -263,6 +278,7 @@ int run(int argc, char* argv[])
...
@@ -263,6 +278,7 @@ int run(int argc, char* argv[])
float
p_drop
=
0.2
;
float
p_drop
=
0.2
;
float
p_dropout
=
1
-
p_drop
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
rp_dropout
=
1.0
/
p_dropout
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
@@ -433,6 +449,7 @@ int run(int argc, char* argv[])
...
@@ -433,6 +449,7 @@ int run(int argc, char* argv[])
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
...
@@ -447,7 +464,17 @@ int run(int argc, char* argv[])
...
@@ -447,7 +464,17 @@ int run(int argc, char* argv[])
lse_gs_ms
.
ForEach
(
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
);
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
);
y_gs_ms_os
.
ForEach
(
y_gs_ms_os
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
...
@@ -551,6 +578,23 @@ int run(int argc, char* argv[])
...
@@ -551,6 +578,23 @@ int run(int argc, char* argv[])
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
// run forward again for y, because z_g_m_n was updated
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
...
@@ -560,6 +604,7 @@ int run(int argc, char* argv[])
...
@@ -560,6 +604,7 @@ int run(int argc, char* argv[])
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
sgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
pgrad_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
Tensor
<
DataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
...
@@ -581,18 +626,24 @@ int run(int argc, char* argv[])
...
@@ -581,18 +626,24 @@ int run(int argc, char* argv[])
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
auto
ref_gemm_grad_invoker
=
ref_gemm_grad
.
MakeInvoker
();
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
using
RefGemmGradArg
=
ReferenceGemmGradInstance
::
Argument
;
// dP = dY * V^T
// dP
_dropout
= dY * V^T
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
auto
v_g_o_n
=
v_g_n_o
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ygrad_g_m_o
,
v_g_o_n
,
pgrad_g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
v_g_o_n
,
pgrad_
drop_
g_m_n
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
y
grad_g_m_
o
;
std
::
cout
<<
"ygrad_
drop_
g_m_o ref:
\n
"
<<
p
grad_
drop_
g_m_
n
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
std
::
cout
<<
"pgrad_g_m_n ref:
\n
"
<<
pgrad_g_m_n
;
}
}
#endif
#endif
// dP = dP_dropout . Z
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
...
@@ -616,22 +667,30 @@ int run(int argc, char* argv[])
...
@@ -616,22 +667,30 @@ int run(int argc, char* argv[])
}
}
#endif
#endif
// dV = P^T * dY
// dV = rp_dropout * P_drop^T * dY
auto
p_g_n_m
=
p_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
pdrop_g_n_m
=
p_drop_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
pdrop_g_n_m
,
p_g_n_m
,
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
1.
f
}});
ygrad_g_m_o
,
vgrad_g_n_o
,
PassThrough
{},
PassThrough
{},
Scale
{
rp_dropout
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"===== dV = P^T * dY
\n
"
;
std
::
cout
<<
"p_g_n_m ref:
\n
"
<<
p_g_n_m
;
std
::
cout
<<
"p
drop
_g_n_m ref:
\n
"
<<
pdro
p_g_n_m
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
std
::
cout
<<
"vgrad_g_n_o ref:
\n
"
<<
vgrad_g_n_o
;
}
}
#endif
#endif
// dQ = alpha * dS * K
// dQ = alpha * dS * K
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_m_n
,
sgrad_g_m_n
,
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
k_g_n_k
,
qgrad_g_m_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
*
rp_dropout
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
std
::
cout
<<
"===== dQ = alpha * dS * K
\n
"
;
...
@@ -643,8 +702,12 @@ int run(int argc, char* argv[])
...
@@ -643,8 +702,12 @@ int run(int argc, char* argv[])
// dK = alpha * dS^T * Q
// dK = alpha * dS^T * Q
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
auto
sgrad_g_n_m
=
sgrad_g_m_n
.
Transpose
({
0
,
2
,
1
});
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_n_m
,
sgrad_g_n_m
,
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
q_g_m_k
,
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
*
rp_dropout
}});
#if PRINT_HOST
#if PRINT_HOST
{
{
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
std
::
cout
<<
"===== dK = alpha * dS^T * Q
\n
"
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp
0 → 100644
View file @
bc9e2f25
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cassert>
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

/// CPU reference implementation of dropout used to verify the GPU kernels.
///
/// For every index idx:
///     out(idx) = (ref(idx) < p_dropout) ? in(idx) : 0
///
/// i.e. `ref` holds the per-element random values and `p_dropout` is the
/// keep-threshold expressed in RefDataType units (e.g. for ushort the caller
/// passes uint16_t(floor(keep_prob * 65535.0)) — NOTE(review): threshold
/// semantics inferred from the caller; confirm against the kernel). No
/// rescaling by 1/keep_prob is performed here; callers apply rp_dropout
/// separately.
template <typename RefDataType, typename InDataType, typename OutDataType>
struct ReferenceDropout : public device::BaseOperator
{
    // Argument: borrows all tensors by reference; they must outlive the run.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<RefDataType>& ref,
                 const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 RefDataType p_dropout)
            : ref_(ref), in_(in), out_(out), p_dropout_(p_dropout)
        {
        }

        const Tensor<RefDataType>& ref_; // random reference values, same shape as in_/out_
        const Tensor<InDataType>& in_;   // input tensor
        Tensor<OutDataType>& out_;       // output tensor, written element-wise
        RefDataType p_dropout_;          // keep-threshold in RefDataType units
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        /// Element-wise dropout: keep in_(idx) when ref_(idx) < p_dropout_, else 0.
        /// Returns 0 (elapsed time is not measured for host references).
        float Run(const Argument& arg)
        {
            arg.out_.ForEach([&](auto& self, auto idx) {
                self(idx) = arg.ref_(idx) < arg.p_dropout_ ? arg.in_(idx) : 0;
            });
            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            // dynamic_cast returns nullptr if p_arg is not a ReferenceDropout
            // Argument; dereferencing it unchecked would be UB. Fail loudly.
            const auto* p_concrete_arg = dynamic_cast<const Argument*>(p_arg);
            assert(p_concrete_arg != nullptr &&
                   "ReferenceDropout::Invoker::Run: argument type mismatch");
            return Run(*p_concrete_arg);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<RefDataType>& ref,
                             const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             RefDataType p_dropout)
    {
        return Argument{ref, in, out, p_dropout};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceDropout"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment