"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "907fd91ce9532751c43c4cfd853e1d5d95bb929b"
Commit f2540aa5 authored by rocking's avatar rocking
Browse files

Add exponential

parent c8b4ac22
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> #include <half.hpp>
#include <math.h>
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -116,16 +117,19 @@ using DeviceReduceInstance = ...@@ -116,16 +117,19 @@ using DeviceReduceInstance =
1, 1,
1>; 1>;
struct Sub struct Sub_Exp
{ {
__host__ __device__ constexpr void operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const __host__ __device__ constexpr void operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
{ {
dst = src1 - src2; dst = src1 - src2;
// FIXME - use float16 exponential
float dst_f32 = static_cast<float>(dst);
dst = static_cast<CDataType>(exp(dst_f32));
} }
}; };
using DeviceElementwiseInstance = ck::tensor_operation::device:: using DeviceElementwiseInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub, 16, 16, 8, 8, 1, 1, 1, 1, 1>; DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 16, 16, 8, 8, 1, 1, 1, 1, 1>;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>; ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
...@@ -289,25 +293,25 @@ int main(int argc, char* argv[]) ...@@ -289,25 +293,25 @@ int main(int argc, char* argv[])
reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat); reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
// do broadcast sub // do broadcast sub
auto broadcastSub = DeviceElementwiseInstance{}; auto broadcastSubExp = DeviceElementwiseInstance{};
auto broadcastSub_argument_ptr = auto broadcastSubExp_argument_ptr =
broadcastSub.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(), broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
c_m_n_max_device_buf.GetDeviceBuffer(), c_m_n_max_device_buf.GetDeviceBuffer(),
d_m_n_device_buf.GetDeviceBuffer(), d_m_n_device_buf.GetDeviceBuffer(),
{M, N}, {M, N},
{StrideC, 1}, {StrideC, 1},
{0, 1}, {0, 1},
{StrideC, 1}, {StrideC, 1},
Sub{}); Sub_Exp{});
if(!broadcastSub.IsSupportedArgument(broadcastSub_argument_ptr.get())) if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceElementwise_2D instance, exiting!"); "DeviceElementwise_2D instance, exiting!");
}; };
auto broadcastSub_invoker_ptr = broadcastSub.MakeInvokerPointer(); auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
broadcastSub_invoker_ptr->Run(broadcastSub_argument_ptr.get(), nrepeat); broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
// TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv // TODO - Need BroadcastSub + exponential + ReduceSum + BroadcastDiv
// TODO = do_verification // TODO = do_verification
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment