Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f2540aa5
Commit
f2540aa5
authored
Apr 13, 2022
by
rocking
Browse files
Add exponential
parent
c8b4ac22
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
9 deletions
+13
-9
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+13
-9
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
f2540aa5
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include <math.h>
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -116,16 +117,19 @@ using DeviceReduceInstance =
...
@@ -116,16 +117,19 @@ using DeviceReduceInstance =
1
,
1
,
1
>
;
1
>
;
struct
Sub
struct
Sub
_Exp
{
{
__host__
__device__
constexpr
void
operator
()(
CDataType
&
dst
,
const
CDataType
&
src1
,
const
CDataType
&
src2
)
const
__host__
__device__
constexpr
void
operator
()(
CDataType
&
dst
,
const
CDataType
&
src1
,
const
CDataType
&
src2
)
const
{
{
dst
=
src1
-
src2
;
dst
=
src1
-
src2
;
// FIXME - use float16 exponential
float
dst_f32
=
static_cast
<
float
>
(
dst
);
dst
=
static_cast
<
CDataType
>
(
exp
(
dst_f32
));
}
}
};
};
using
DeviceElementwiseInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceElementwiseInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
Sub
,
16
,
16
,
8
,
8
,
1
,
1
,
1
,
1
,
1
>
;
DeviceElementwise_2D
<
CDataType
,
CDataType
,
CDataType
,
Sub
_Exp
,
16
,
16
,
8
,
8
,
1
,
1
,
1
,
1
,
1
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
...
@@ -289,25 +293,25 @@ int main(int argc, char* argv[])
...
@@ -289,25 +293,25 @@ int main(int argc, char* argv[])
reduce_max_invoker_ptr
->
Run
(
reduce_max_argument_ptr
.
get
(),
nrepeat
);
reduce_max_invoker_ptr
->
Run
(
reduce_max_argument_ptr
.
get
(),
nrepeat
);
// do broadcast sub
// do broadcast sub
auto
broadcastSub
=
DeviceElementwiseInstance
{};
auto
broadcastSub
Exp
=
DeviceElementwiseInstance
{};
auto
broadcastSub_argument_ptr
=
auto
broadcastSub
Exp
_argument_ptr
=
broadcastSub
.
MakeArgumentPointer
(
c_m_n_device_buf
.
GetDeviceBuffer
(),
broadcastSub
Exp
.
MakeArgumentPointer
(
c_m_n_device_buf
.
GetDeviceBuffer
(),
c_m_n_max_device_buf
.
GetDeviceBuffer
(),
c_m_n_max_device_buf
.
GetDeviceBuffer
(),
d_m_n_device_buf
.
GetDeviceBuffer
(),
d_m_n_device_buf
.
GetDeviceBuffer
(),
{
M
,
N
},
{
M
,
N
},
{
StrideC
,
1
},
{
StrideC
,
1
},
{
0
,
1
},
{
0
,
1
},
{
StrideC
,
1
},
{
StrideC
,
1
},
Sub
{});
Sub
_Exp
{});
if
(
!
broadcastSub
.
IsSupportedArgument
(
broadcastSub_argument_ptr
.
get
()))
if
(
!
broadcastSub
Exp
.
IsSupportedArgument
(
broadcastSub
Exp
_argument_ptr
.
get
()))
{
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!"
);
"DeviceElementwise_2D instance, exiting!"
);
};
};
auto
broadcastSub_invoker_ptr
=
broadcastSub
.
MakeInvokerPointer
();
auto
broadcastSub
Exp
_invoker_ptr
=
broadcastSub
Exp
.
MakeInvokerPointer
();
broadcastSub_invoker_ptr
->
Run
(
broadcastSub_argument_ptr
.
get
(),
nrepeat
);
broadcastSub
Exp
_invoker_ptr
->
Run
(
broadcastSub
Exp
_argument_ptr
.
get
(),
nrepeat
);
// TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
// TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
// TODO = do_verification
// TODO = do_verification
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment