Commit f2540aa5 authored by rocking's avatar rocking
Browse files

Add exponential

parent c8b4ac22
......@@ -4,6 +4,7 @@
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
......@@ -116,16 +117,19 @@ using DeviceReduceInstance =
1,
1>;
struct Sub
// Binary element-wise functor: stores exp(src1 - src2) into dst.
struct Sub_Exp
{
    __host__ __device__ constexpr void
    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
    {
        // Stage the difference in dst first, so it is held as a CDataType value
        // before the exponential is taken.
        dst = src1 - src2;

        // FIXME - use float16 exponential
        // For now: widen the staged value to float, apply exp, and narrow the result back.
        dst = static_cast<CDataType>(exp(static_cast<float>(dst)));
    }
};
// Device element-wise operation instance used for the broadcast subtract(+exp) step;
// all three data-type template arguments are CDataType.
// NOTE(review): the meaning of the trailing integer template arguments is defined by
// DeviceElementwise_2D and is not visible here — confirm against its declaration.
using DeviceElementwiseInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
// Host-side ReferenceGemm instance over ADataType/BDataType/CDataType, with PassThrough
// supplied for its last three template arguments.
// NOTE(review): presumably used to check the device GEMM result — confirm against the
// verification code, which is not visible here.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
......@@ -289,25 +293,25 @@ int main(int argc, char* argv[])
reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
// do broadcast sub + exp
auto broadcastSub = DeviceElementwiseInstance{};
auto broadcastSub_argument_ptr =
broadcastSub.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
auto broadcastSubExp = DeviceElementwiseInstance{};
auto broadcastSubExp_argument_ptr =
broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
c_m_n_max_device_buf.GetDeviceBuffer(),
d_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{StrideC, 1},
{0, 1},
{StrideC, 1},
Sub{});
Sub_Exp{});
if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get()))
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!");
};
auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer();
broadcastSub_invoker_ptr->Run(broadcastSub_argument_ptr.get(), nrepeat);
auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
// TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
// TODO - do_verification
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment