Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b89a88b5
Commit
b89a88b5
authored
Sep 19, 2022
by
Adam Osewski
Browse files
Merge branch 'develop' into wavelet_model
parents
41d5fca7
43c898f6
Changes
261
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
15 deletions
+22
-15
test/softmax/test_softmax_util.hpp
test/softmax/test_softmax_util.hpp
+22
-15
No files found.
test/softmax/test_softmax_util.hpp
View file @
b89a88b5
...
...
@@ -9,7 +9,8 @@
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp"
...
...
@@ -51,19 +52,23 @@ class TestSoftmax : public ::testing::Test
using
ReferenceInstance
=
tensor_operation
::
host
::
ReferenceSoftmax
<
InDataType
,
OutDataType
,
AccDataType
>
;
using
DeviceInstance
=
tensor_operation
::
device
::
DeviceSoftmax
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceInstance
=
tensor_operation
::
device
::
DeviceSoftmaxImpl
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
,
NumReduceDim
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
TestSoftmax
()
:
ref_instance_invoker_
(
ReferenceInstance
{}.
MakeInvoker
())
{}
...
...
@@ -97,7 +102,9 @@ class TestSoftmax : public ::testing::Test
&
alpha
,
&
beta
,
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
());
out_dev
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
...
...
Prev
1
…
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment