gaoqiong / composable_kernel · Commits

Commit fe659502
Authored Apr 15, 2022 by rocking

    Add verification of softmax

Parent: dba65b1c

Showing 1 changed file with 110 additions and 12 deletions (+110 −12):
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
@@ -5,11 +5,14 @@
 #include <stdlib.h>
 #include <half.hpp>
 #include <math.h>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "host_reduce_util.hpp"
+#include "host_reduction.hpp"
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
 #include "device_gemm_xdl_c_shuffle.hpp"
@@ -89,9 +92,7 @@ constexpr int Rank = 2;
 constexpr int NumReduceDim = 1;
 constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
 constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
-constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
-constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
-// constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
+constexpr bool ReducePropagateNan = false;

 using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
 using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;

 using ReduceMaxInElementwiseOperation =
@@ -112,7 +113,7 @@ using DeviceReduceMaxInstance =
         ReduceMaxOp,
         ReduceMaxInElementwiseOperation,
         ReduceMaxAccElementwiseOperation,
-        PropagateNan,
+        ReducePropagateNan,
         false,
         256,
         4,
@@ -132,7 +133,7 @@ using DeviceReduceSumInstance =
         ReduceSumOp,
         ReduceSumInElementwiseOperation,
         ReduceSumAccElementwiseOperation,
-        PropagateNan,
+        ReducePropagateNan,
         false,
         256,
         4,
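Editorial note, not part of the diff: the deleted code derived PropagateNan from a NanOpt switch, while the new ReducePropagateNan constant simply hard-codes "do not propagate". As a minimal standalone sketch of what such a flag typically means for a max-reduction (an illustration of the general semantics, not CK's actual implementation):

#include <cmath>
#include <cstdio>
#include <vector>

// With propagate_nan = true, any NaN input poisons the result; with false,
// NaN inputs are skipped because (x > acc) is false whenever x is NaN.
float reduce_max(const std::vector<float>& v, bool propagate_nan)
{
    float acc = -INFINITY;
    for(float x : v)
    {
        if(propagate_nan && std::isnan(x))
            return NAN;
        acc = x > acc ? x : acc;
    }
    return acc;
}

int main()
{
    const std::vector<float> v{1.0f, NAN, 3.0f};
    std::printf("propagate NaN: %f\n", reduce_max(v, true));  // nan
    std::printf("skip NaN:      %f\n", reduce_max(v, false)); // 3.000000
}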
@@ -170,9 +171,47 @@ using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
 using DeviceElementwiseDivInstance = ck::tensor_operation::device::
     DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;

-using ReferenceGemmInstance = ck::tensor_operation::host::
+using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;

+using HostReduceMaxInstance = ReductionHost<CDataType,
+                                            CDataType,
+                                            CDataType,
+                                            ReduceMaxId,
+                                            Rank,
+                                            NumReduceDim,
+                                            ReducePropagateNan,
+                                            false>;
+
+using HostReduceSumInstance = ReductionHost<CDataType,
+                                            CDataType,
+                                            CDataType,
+                                            ReduceSumId,
+                                            Rank,
+                                            NumReduceDim,
+                                            ReducePropagateNan,
+                                            false>;
+
+template <typename HostTensorA,
+          typename HostTensorB,
+          typename HostTensorC,
+          typename Functor,
+          int broadcastDim>
+void host_broadcast2D(
+    HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
+{
+    for(int m = 0; m < M; ++m)
+    {
+        for(int n = 0; n < N; ++n)
+        {
+            if constexpr(broadcastDim == 1)
+                functor(C(m, n), A(m, n), B(n));
+            else
+                functor(C(m, n), A(m, n), B(m));
+        }
+    }
+}
+
 int main(int argc, char* argv[])
 {
     bool do_verification = 0;
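Editorial note, not part of the diff: host_broadcast2D above applies a three-argument functor element-wise over an M x N tensor while broadcasting its second input B along one dimension; with broadcastDim == 1 there is one B value per column n, which is exactly how the verification later in this commit reuses it with Sub_Exp and Div. A self-contained sketch of the same broadcast semantics (Grid2D is a hypothetical stand-in for the repository's Tensor class):

#include <cmath>
#include <cstdio>
#include <vector>

struct Grid2D // row-major M x N buffer standing in for Tensor<CDataType>
{
    int M, N;
    std::vector<float> data;
    Grid2D(int m, int n) : M(m), N(n), data(m * n, 0.f) {}
    float& operator()(int m, int n) { return data[m * N + n]; }
    float operator()(int m, int n) const { return data[m * N + n]; }
};

// broadcastDim == 1: B is indexed by n only (one value per column), mirroring
// how per-column max/sum vectors are broadcast over all rows m.
template <int broadcastDim, typename F>
void broadcast2D(Grid2D& C, const Grid2D& A, const std::vector<float>& B, F f)
{
    for(int m = 0; m < A.M; ++m)
        for(int n = 0; n < A.N; ++n)
            f(C(m, n), A(m, n), broadcastDim == 1 ? B[n] : B[m]);
}

int main()
{
    Grid2D a(2, 3), c(2, 3);
    for(int i = 0; i < 6; ++i)
        a.data[i] = float(i); // rows: {0,1,2}, {3,4,5}
    const std::vector<float> col_max{3.f, 4.f, 5.f}; // per-column max, as from reduceDims{0}
    broadcast2D<1>(c, a, col_max,
                   [](float& o, float x, float b) { o = std::exp(x - b); });
    std::printf("c(0,0)=%f c(1,2)=%f\n", c(0, 0), c(1, 2)); // exp(-3), exp(0)
}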
@@ -189,7 +228,6 @@ int main(int argc, char* argv[])
     ck::index_t StrideC = 4096;

     const std::vector<int> reduceDims{0};
-    const std::vector<int> reduceInvariantDims{1};

     if(argc == 4)
     {
@@ -237,7 +275,7 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<int> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+    Tensor<CDataType> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
                               std::vector<std::size_t>({1}));
     Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
@@ -370,8 +408,8 @@ int main(int argc, char* argv[])
         reduce_n_shape,
         reduce_n_stride,
         reduceDims,
-        1,
-        0,
+        1, // alpha
+        0, // beta
         exp_m_n_device_buf.GetDeviceBuffer(),
         exp_n_sum_device_buf.GetDeviceBuffer(),
         indices_device_buf.GetDeviceBuffer(),
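Editorial note, not part of the diff: the two new comments document the scalar arguments of the reduction's Run call. Assuming CK follows the usual BLAS-style convention here, the call computes out = alpha * reduce(in) + beta * out_prev, so alpha = 1 and beta = 0 make the reduction overwrite whatever was in the output buffer. A trivial scalar illustration of that convention:

#include <cstdio>

int main()
{
    const float in[4] = {1.f, 2.f, 3.f, 4.f};
    const float alpha = 1.f, beta = 0.f;

    float out = 100.f; // stale previous contents of the output buffer
    float acc = 0.f;
    for(float x : in)
        acc += x; // the reduction itself (a sum here)
    out = alpha * acc + beta * out; // beta = 0 discards the stale value
    std::printf("%f\n", out);       // 10.000000
}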
@@ -410,6 +448,66 @@ int main(int argc, char* argv[])
     broadcastDiv_invoker_ptr->Run(broadcastDiv_argument_ptr.get(), nrepeat);

     // TODO = do_verification
-    (void)do_verification;
+    if(do_verification)
+    {
+        std::cout << "verification..." << std::endl;
+
+        const std::vector<int> reduceInvariantDims{1};
+
+        Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+        Tensor<CDataType> host_c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                       std::vector<std::size_t>({1}));
+        Tensor<int> host_indices(host_c_n_max.mDesc.GetLengths());
+        Tensor<CDataType> host_exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+        Tensor<CDataType> host_exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                         std::vector<std::size_t>({1}));
+        Tensor<CDataType> host_softmax_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+        auto host_gemm          = HostGemmInstance{};
+        auto host_gemm_invoker  = host_gemm.MakeInvoker();
+        auto host_gemm_argument = host_gemm.MakeArgument(
+            a_m_k, b_k_n, host_c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+
+        auto host_reduce_max = HostReduceMaxInstance{
+            host_c_m_n.mDesc, host_c_n_max.mDesc, reduceInvariantDims, reduceDims};
+        auto host_reduce_sum = HostReduceSumInstance{
+            host_exp_m_n.mDesc, host_exp_n_sum.mDesc, reduceInvariantDims, reduceDims};
+
+        host_gemm_invoker.Run(host_gemm_argument);
+
+        host_reduce_max.Run(1, // alpha
+                            reinterpret_cast<const CDataType*>(host_c_m_n.mData.data()),
+                            0, // beta
+                            reinterpret_cast<CDataType*>(host_c_n_max.mData.data()),
+                            host_indices.mData.data());
+
+        host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
+            host_exp_m_n, host_c_m_n, host_c_n_max, M, N, Sub_Exp{});
+
+        host_reduce_sum.Run(1, // alpha
+                            reinterpret_cast<const CDataType*>(host_exp_m_n.mData.data()),
+                            0, // beta
+                            reinterpret_cast<CDataType*>(host_exp_n_sum.mData.data()),
+                            host_indices.mData.data());
+
+        host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
+            host_softmax_m_n, host_exp_m_n, host_exp_n_sum, M, N, Div{});
+
+        c_m_n_device_buf.FromDevice(c_m_n.mData.data());
+        c_n_max_device_buf.FromDevice(c_n_max.mData.data());
+        exp_m_n_device_buf.FromDevice(exp_m_n.mData.data());
+        exp_n_sum_device_buf.FromDevice(exp_n_sum.mData.data());
+        softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data());
+
+        bool result = true;
+        if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
+            std::cout << "[PASS] - c_m_n" << std::endl;
+        if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
+            std::cout << "[PASS] - c_n_max" << std::endl;
+        if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
+            std::cout << "[PASS] - exp_m_n" << std::endl;
+        if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
+            std::cout << "[PASS] - exp_n_sum" << std::endl;
+        if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
+            std::cout << "[PASS] - softmax_m_n" << std::endl;
+    }
+
     return 0;
 }
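Editorial note, not part of the diff: taken together, the new verification path recomputes the whole pipeline on the host (reference GEMM, per-column max, exp of the shifted values, per-column sum, division) and compares every intermediate device buffer against its host counterpart with ck::utils::check_err. A compact scalar restatement of the quantity being checked, in illustrative fp32 with made-up sizes rather than the example's fp16 CDataType:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, K = 2, N = 2;
    const std::vector<float> A{1, 2, 3, 4}, B{1, 0, 0, 1}; // row-major
    std::vector<float> C(M * N, 0.f), softmax(M * N, 0.f);

    // 1. GEMM: C = A * B
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                C[m * N + n] += A[m * K + k] * B[k * N + n];

    // 2-4. per-column softmax over dimension 0 (reduceDims{0}):
    // subtract the column max before exponentiating for numerical stability.
    for(int n = 0; n < N; ++n)
    {
        float mx = -INFINITY, sum = 0.f;
        for(int m = 0; m < M; ++m)
            mx = std::max(mx, C[m * N + n]);
        for(int m = 0; m < M; ++m)
            sum += std::exp(C[m * N + n] - mx);
        for(int m = 0; m < M; ++m)
            softmax[m * N + n] = std::exp(C[m * N + n] - mx) / sum;
    }

    for(int m = 0; m < M; ++m)
        std::printf("%f %f\n", softmax[m * N + 0], softmax[m * N + 1]);
}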