Unverified commit 52abc2f3 authored by Qianfeng, committed by GitHub

Use double for all scaling values and floating-point constant values at the Device Op API (#557)

* Use double as the alpha/beta value type in the reduce device op API

* Use double as the alpha/beta value type in the softmax device op API

* Use double as the alpha/beta value type in the multiple-reduce device op API

* Use double as the epsilon value type in the normalization/elementwise-normalization device op API
parent 1cfa8760
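
The pattern applied across these files: the host-facing argument type for the scaling values (alpha/beta, and epsilon for the normalization ops) becomes double, and each operator converts to its own accumulation type internally. Below is a minimal, hypothetical sketch of that pattern; ExampleArgument is an invented stand-in, not a composable_kernel type.

```cpp
// Minimal sketch of the pattern this commit applies; ExampleArgument is a
// hypothetical stand-in, not a type from composable_kernel.
#include <iostream>

template <typename AccDataType>
struct ExampleArgument
{
    // The public interface always takes double, so callers do not need to
    // know the operator's accumulation type.
    ExampleArgument(double alpha, double beta)
        : alpha_(static_cast<AccDataType>(alpha)), beta_(static_cast<AccDataType>(beta))
    {
    }

    AccDataType alpha_; // working-precision copies used by the kernel/reference code
    AccDataType beta_;
};

int main()
{
    // The same caller-side code works whether the op accumulates in float or double.
    ExampleArgument<float> arg_f(2.0, 0.5);
    ExampleArgument<double> arg_d(2.0, 0.5);
    std::cout << arg_f.alpha_ << " " << arg_d.beta_ << "\n";
    return 0;
}
```

Keeping double at the boundary means callers never have to name AccDataType just to pass a scale factor, at the cost of one narrowing conversion inside the op.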
@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
 {
     Argument(const Tensor<InDataType>& in,
              Tensor<OutDataType>& out,
-             AccDataType alpha,
-             AccDataType beta,
+             double alpha,
+             double beta,
              const std::vector<index_t> sm_reduce_dims)
-        : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
+        : in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
     {
+        alpha_ = static_cast<AccDataType>(alpha);
+        beta_ = static_cast<AccDataType>(beta);
         // std::cout << "debug: scalar dims: ";
         for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
         {
@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
     static auto MakeArgument(const Tensor<InDataType>& in,
                              Tensor<OutDataType>& out,
-                             AccDataType alpha,
-                             AccDataType beta,
+                             double alpha,
+                             double beta,
                              const std::vector<index_t> sm_reduce_dims)
     {
         return Argument{in, out, alpha, beta, sm_reduce_dims};
......
@@ -332,8 +332,8 @@ bool profile_reduce_impl_impl(bool do_verification,
                               arrOutLengths,
                               arrOutStrides,
                               reduceDims,
-                              alpha,
-                              beta,
+                              static_cast<double>(alpha),
+                              static_cast<double>(beta),
                               in.mData.data(),
                               nullptr,
                               out_ref.mData.data(),
@@ -361,8 +361,8 @@ bool profile_reduce_impl_impl(bool do_verification,
                               arrOutLengths,
                               arrOutStrides,
                               reduceDims,
-                              alpha,
-                              beta,
+                              static_cast<double>(alpha),
+                              static_cast<double>(beta),
                               in_dev.GetDeviceBuffer(),
                               nullptr,
                               out_dev.GetDeviceBuffer(),
......
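
In the reduce profiler, the host-side alpha and beta keep their original type; the call sites simply widen them with static_cast<double>. Since float-to-double conversion is exact, the operator sees the same value it would have received before. A standalone check of that property (not profiler code):

```cpp
// Standalone check (not profiler code): converting a float scaling value to
// double and back is lossless, so passing alpha/beta as double at the device
// op API cannot change the value the operator eventually uses.
#include <cassert>

int main()
{
    const float alpha    = 0.1f; // not exactly representable in binary, which is irrelevant here
    const double widened = static_cast<double>(alpha); // exact: double can represent every float
    assert(static_cast<float>(widened) == alpha);      // round-trips to the same value
    return 0;
}
```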
@@ -48,8 +48,8 @@ bool profile_softmax_impl(int do_verification,
                           std::vector<index_t> in_length,
                           std::vector<index_t> in_strides,
                           std::vector<index_t> reduce_dims,
-                          AccDataType alpha,
-                          AccDataType beta)
+                          double alpha,
+                          double beta)
 {
     if(Rank != in_length.size())
     {
@@ -122,8 +122,8 @@ bool profile_softmax_impl(int do_verification,
         auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
                                                           in_tensor_strides,
                                                           reduce_dims,
-                                                          &alpha,
-                                                          &beta,
+                                                          alpha,
+                                                          beta,
                                                           in_dev.GetDeviceBuffer(),
                                                           out_dev.GetDeviceBuffer(),
                                                           PassThrough{},
......
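
With the scalar type fixed to double, the softmax profiler no longer passes alpha/beta through pointers (&alpha, &beta): MakeArgumentPointer now receives them by value, as the hunk above shows. A hedged sketch of why a typed by-value scalar is preferable to an untyped pointer, using invented types (ScaledOpByPointer, ScaledOpByValue) rather than the real device op interface:

```cpp
// Hedged sketch with invented names; it only illustrates why a typed by-value
// double parameter is safer than an untyped pointer for a scalar.
#include <iostream>

struct ScaledOpByPointer
{
    // Old style: the caller must pass a pointer to the exact type the op
    // expects (e.g. float), which the compiler cannot check through void*.
    void SetScale(const void* alpha) { alpha_ = *static_cast<const float*>(alpha); }
    float alpha_ = 0.f;
};

struct ScaledOpByValue
{
    // New style: any arithmetic value converts to double at the call site,
    // and the op narrows it internally to its accumulation type.
    void SetScale(double alpha) { alpha_ = static_cast<float>(alpha); }
    float alpha_ = 0.f;
};

int main()
{
    float a = 2.f;
    ScaledOpByPointer p;
    p.SetScale(&a); // works only because 'a' really is a float
    ScaledOpByValue v;
    v.SetScale(2); // int, float, or double arguments are all fine
    std::cout << p.alpha_ << " " << v.alpha_ << "\n";
    return 0;
}
```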
@@ -99,8 +99,8 @@ int profile_softmax(int argc, char* argv[])
                                                  length,
                                                  stride,
                                                  reduce,
-                                                 float(alpha),
-                                                 float(beta));
+                                                 double(alpha),
+                                                 double(beta));
     }
     else if(data_type == SoftmaxDataType::F32_F32)
     {
@@ -111,8 +111,8 @@ int profile_softmax(int argc, char* argv[])
                                                  length,
                                                  stride,
                                                  reduce,
-                                                 float(alpha),
-                                                 float(beta));
+                                                 double(alpha),
+                                                 double(beta));
     }
     else
     {
@@ -131,8 +131,8 @@ int profile_softmax(int argc, char* argv[])
                                                  length,
                                                  stride,
                                                  reduce,
-                                                 float(alpha),
-                                                 float(beta));
+                                                 double(alpha),
+                                                 double(beta));
     }
     else if(data_type == SoftmaxDataType::F32_F32)
     {
@@ -143,8 +143,8 @@ int profile_softmax(int argc, char* argv[])
                                                  length,
                                                  stride,
                                                  reduce,
-                                                 float(alpha),
-                                                 float(beta));
+                                                 double(alpha),
+                                                 double(beta));
     }
     else
     {
......
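
In the profiler driver, both the F16_F16 and F32_F32 branches now hand the parsed alpha/beta to profile_softmax_impl as double, so the scalar type no longer follows the tensor data type. The functional-style cast double(alpha) used here is equivalent to static_cast<double>(alpha) for arithmetic types; a trivial standalone check (not driver code):

```cpp
// Trivial standalone check (not driver code): for arithmetic types the
// functional-style cast double(x) and static_cast<double>(x) give the same result.
#include <cassert>

int main()
{
    float alpha = 3.f;
    assert(double(alpha) == static_cast<double>(alpha));
    return 0;
}
```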