gaoqiong / composable_kernel · Commits

Commit 9ee2997d, authored Feb 16, 2023 by guangzlu

added bf16 fwd attn dropout verify

parent 067e71a8
Showing 2 changed files with 17 additions and 6 deletions:

- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp (+11, -3)
- library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp (+6, -3)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp

```diff
@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using BF16 = ck::bhalf_t;
 using F32  = float;
+using U16  = unsigned short;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -42,6 +44,7 @@ using B1DataType = BF16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = BF16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -78,6 +81,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
@@ -98,8 +102,8 @@ using DeviceGemmInstance =
         128,         // MPerBlock
         128,         // NPerBlock
         32,          // KPerBlock
-        64,          // Gemm1NPerBlock
-        32,          // Gemm1KPerBlock
+        128,         // Gemm1NPerBlock
+        64,          // Gemm1KPerBlock
         8,           // AK1
         8,           // BK1
         2,           // B1K1
@@ -107,7 +111,7 @@ using DeviceGemmInstance =
         32,          // NPerXDL
         1,           // MXdlPerWave
         4,           // NXdlPerWave
-        2,           // Gemm1NXdlPerWave
+        4,           // Gemm1NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
         B1ElementOp,
         CElementOp>;
 
+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"
 
 int main(int argc, char* argv[]) { return run(argc, argv); }
```
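The new `ReferenceDropoutInstance` gives the host-side verify path the same keep/scale rule the device kernel applies: an element survives when its per-element random value (the Z tensor) falls below the 16-bit dropout threshold, and survivors are rescaled by `rp_dropout = 1/(1 - p)` so the expected value is preserved. Below is a minimal standalone sketch of that rule; the threshold construction `(1 - p) * 65535` is an illustrative assumption, not necessarily how CK derives `p_dropout_in_16bits`.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const float p_dropout  = 0.2f;                      // drop probability
    const float rp_dropout = 1.0f / (1.0f - p_dropout); // rescale so E[out] == in

    // Hypothetical threshold in the Z tensor's 16-bit domain: keep an element
    // when its random value lands in the lower (1-p) fraction of the range.
    const std::uint16_t p_dropout_in_16bits =
        static_cast<std::uint16_t>((1.0f - p_dropout) * 65535.0f);

    const std::vector<std::uint16_t> z  = {1000, 60000, 30000}; // per-element randoms
    const std::vector<float>         in = {0.5f, 0.5f, 0.5f};
    std::vector<float>               out(in.size());

    // Same keep/scale rule as ReferenceDropout's Invoker:
    // out = z < threshold ? in * rp_dropout : 0
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i] = z[i] < p_dropout_in_16bits ? in[i] * rp_dropout : 0.0f;

    for(const float v : out)
        std::cout << v << ' '; // prints: 0.625 0 0.625
    std::cout << '\n';
}
```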
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp

```diff
@@ -31,14 +31,14 @@ struct ReferenceDropout : public device::BaseOperator
           in_(in),
           out_(out),
           p_dropout_in_16bits_(p_dropout_in_16bits),
-          rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
+          rp_dropout_(rp_dropout)
     {
     }
 
     const Tensor<RefDataType>& ref_;
     const Tensor<InDataType>& in_;
     Tensor<OutDataType>& out_;
     RefDataType p_dropout_in_16bits_;
-    OutDataType rp_dropout_;
+    float rp_dropout_;
 };
 
 // Invoker
@@ -48,7 +48,10 @@ struct ReferenceDropout : public device::BaseOperator
     {
         arg.out_.ForEach([&](auto& self, auto idx) {
-            self(idx) = arg.ref_(idx) < arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
+            self(idx) = arg.ref_(idx) < arg.p_dropout_in_16bits_
+                            ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
+                                                            ck::type_convert<float>(arg.rp_dropout_))
+                            : 0;
         });
         return 0;
     }
```
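The substance of the reference_dropout.hpp change is numeric: `rp_dropout_` is now stored as `float`, and the multiply happens in `float` with a single conversion to `OutDataType` at the end, instead of rounding the scale factor to the output type (bf16 in this example) before multiplying. bf16 carries only 8 mantissa bits, so a pre-rounded scale factor can shift every kept element by an ulp or more. A small self-contained demo of the effect, emulating bf16 by truncating a float's low 16 bits (round-toward-zero for simplicity; `ck::bhalf_t` conversion may round differently):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Emulate bf16 storage by zeroing the low 16 bits of a float
// (truncation; illustration only, not ck::type_convert's exact rounding).
static float to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}

int main()
{
    const float in_f32 = 0.3f;                  // an attention probability
    const float rp     = 1.0f / (1.0f - 0.15f); // 1/(1-p), not exactly representable

    const float in_bf16 = to_bf16(in_f32); // the value the kernel actually stores

    // Old reference path: scale factor rounded to bf16 before the multiply.
    const float old_result = to_bf16(in_bf16 * to_bf16(rp));
    // New reference path: multiply in float, convert the product once.
    const float new_result = to_bf16(in_bf16 * rp);

    // The two differ in the last bf16 ulp (~0.349609 vs ~0.350586).
    std::cout << old_result << " vs " << new_result << '\n';
}
```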