gaoqiong/composable_kernel, commit 0c359ae8

Authored Apr 14, 2023 by danyao12

    ZDataType can be set to U16/INT32 in fwd&bwd&train examples

Parent: 3b57967f

Showing 7 changed files with 104 additions and 55 deletions (+104 -55)
Changed files:
  +12 -11  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
  +6  -5   example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp
  +15 -14  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
  +11 -10  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp
  +6  -5   example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp
  +11 -10  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp
  +43 -0   library/include/ck/library/utility/check_err.hpp
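The change is mechanical across the six examples: each one already aliases U16, the commit adds an INT32 alias next to it, and ZDataType becomes the single line to edit when switching the dropout-mask element type. A minimal standalone sketch of that toggle (U16, INT32 and ZDataType mirror the names in the diff; main() is illustrative only):

#include <cstdint>
#include <iostream>

using U16   = unsigned short; // already present in the examples
using INT32 = int32_t;        // alias added by this commit

// Switch the dropout-mask element type by editing this one alias:
using ZDataType = U16; // INT32

int main()
{
    std::cout << "ZDataType occupies " << sizeof(ZDataType) << " byte(s)\n";
    return 0;
}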
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp

@@ -55,6 +55,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -68,7 +69,7 @@ using GemmDataType = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
-using ZDataType       = U16;
+using ZDataType       = U16; // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -422,7 +423,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -442,7 +443,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -549,7 +550,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     float alpha      = 1.f / std::sqrt(K);
...
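Generalizing ushort/uint16_t to ZDataType works because the keep-threshold is still computed on a 16-bit scale (p_dropout * 65535), a value that fits in either alias. A standalone sketch of that computation under both choices (everything outside the two casts is illustrative only):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>

using U16   = unsigned short;
using INT32 = int32_t;

int main()
{
    float p_drop    = 0.1f;
    float p_dropout = 1 - p_drop;

    // The old code hard-coded uint16_t(std::floor(p_dropout * 65535.0));
    // routing the cast through ZDataType lets either alias compile unchanged.
    U16   as_u16   = U16(std::floor(p_dropout * 65535.0));
    INT32 as_int32 = INT32(std::floor(p_dropout * 65535.0));

    // Same numeric threshold, different storage type.
    assert(static_cast<int32_t>(as_u16) == as_int32);
    std::cout << "keep-threshold = " << as_int32 << '\n'; // 58981 for p_drop = 0.1
    return 0;
}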
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp

@@ -38,6 +38,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -49,7 +50,7 @@ using B1DataType = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = DataType;
-using ZDataType        = U16;
+using ZDataType        = U16; // INT32
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp

@@ -64,6 +64,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -71,13 +72,13 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;
-using InputDataType   = BF16;
+using InputDataType   = F16;
-using OutputDataType  = F32;
+using OutputDataType  = F16;
-using GemmDataType    = BF16;
+using GemmDataType    = F16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
-using ZDataType       = U16;
+using ZDataType       = INT32; // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -641,7 +642,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -661,7 +662,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -768,7 +769,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     float alpha      = 1.f / std::sqrt(K);
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward.cpp

@@ -54,6 +54,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -67,7 +68,7 @@ using GemmDataType = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
-using ZDataType       = U16;
+using ZDataType       = INT32; // U16
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -421,7 +422,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -441,7 +442,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -536,7 +537,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     auto gemm = DeviceGemmInstance{};
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp

@@ -38,6 +38,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -49,7 +50,7 @@ using B1DataType = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = DataType;
-using ZDataType        = U16;
+using ZDataType        = INT32; // U16
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp

@@ -63,6 +63,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using U16  = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
@@ -76,7 +77,7 @@ using GemmDataType = BF16;
 using AccDataType     = F32;
 using ShuffleDataType = F32;
 using LSEDataType     = F32;
-using ZDataType       = U16;
+using ZDataType       = INT32; // U16
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -640,7 +641,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -660,7 +661,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -755,7 +756,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     auto gemm_fwd = DeviceGemmInstanceFWD{};
...
library/include/ck/library/utility/check_err.hpp

@@ -257,5 +257,48 @@ check_err(const Range& out, const RefRange& ref, unsigned short atol = 1)
     return res;
 }
 
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, int32_t>,
+    bool>::type
+check_err(const Range& out, const RefRange& ref, int32_t atol = 1)
+{
+    const std::string& msg = "Error: Incorrect U16 results!";
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count   = 0;
+    int32_t err     = 0;
+    int32_t max_err = std::numeric_limits<int32_t>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const int32_t o = *std::next(std::begin(out), i);
+        const int32_t r = *std::next(std::begin(ref), i);
+        err             = (o > r) ? o - r : r - o;
+        if(err > atol)
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << " out[" << i << "] != ref[" << i
+                          << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cerr << std::setw(12) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
 } // namespace utils
 } // namespace ck
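The new overload is selected by SFINAE when both ranges hold int32_t, mirroring the unsigned short overload directly above it in the header. A hypothetical usage sketch (the header path and the ck::utils::check_err call follow the library layout shown in this diff; the data is made up):

#include <cstdint>
#include <vector>
#include "ck/library/utility/check_err.hpp"

int main()
{
    std::vector<int32_t> out{10, 21, 30};
    std::vector<int32_t> ref{10, 20, 30};

    // With the default atol = 1, the |21 - 20| difference is tolerated and
    // check_err returns true; a larger mismatch prints a diagnostic and
    // returns false.
    bool ok = ck::utils::check_err(out, ref);
    return ok ? 0 : 1;
}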