Unverified commit 52abc2f3, authored by Qianfeng, committed by GitHub
Browse files

Use double for all scaling values and floating-point constant values at the Device Op API (#557)

* Use double as alpha/beta values type in reduce device op api

* Use double as alpha/beta values type in softmax device op api

* Use double as alpha/beta values type in multiple-reduce device op api

* Use double as epsilon value type in normalization/elementwise-normalization device op api
parent 1cfa8760
...@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
{ {
Argument(const Tensor<InDataType>& in, Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out, Tensor<OutDataType>& out,
AccDataType alpha, double alpha,
AccDataType beta, double beta,
const std::vector<index_t> sm_reduce_dims) const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims) : in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
{ {
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
// std::cout << "debug: scalar dims: "; // std::cout << "debug: scalar dims: ";
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++) for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
{ {
...@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& in, static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out, Tensor<OutDataType>& out,
AccDataType alpha, double alpha,
AccDataType beta, double beta,
const std::vector<index_t> sm_reduce_dims) const std::vector<index_t> sm_reduce_dims)
{ {
return Argument{in, out, alpha, beta, sm_reduce_dims}; return Argument{in, out, alpha, beta, sm_reduce_dims};
......
...@@ -332,8 +332,8 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -332,8 +332,8 @@ bool profile_reduce_impl_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in.mData.data(), in.mData.data(),
nullptr, nullptr,
out_ref.mData.data(), out_ref.mData.data(),
...@@ -361,8 +361,8 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -361,8 +361,8 @@ bool profile_reduce_impl_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
......
...@@ -48,8 +48,8 @@ bool profile_softmax_impl(int do_verification, ...@@ -48,8 +48,8 @@ bool profile_softmax_impl(int do_verification,
std::vector<index_t> in_length, std::vector<index_t> in_length,
std::vector<index_t> in_strides, std::vector<index_t> in_strides,
std::vector<index_t> reduce_dims, std::vector<index_t> reduce_dims,
AccDataType alpha, double alpha,
AccDataType beta) double beta)
{ {
if(Rank != in_length.size()) if(Rank != in_length.size())
{ {
...@@ -122,8 +122,8 @@ bool profile_softmax_impl(int do_verification, ...@@ -122,8 +122,8 @@ bool profile_softmax_impl(int do_verification,
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths, auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
in_tensor_strides, in_tensor_strides,
reduce_dims, reduce_dims,
&alpha, alpha,
&beta, beta,
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
......
...@@ -99,8 +99,8 @@ int profile_softmax(int argc, char* argv[]) ...@@ -99,8 +99,8 @@ int profile_softmax(int argc, char* argv[])
length, length,
stride, stride,
reduce, reduce,
float(alpha), double(alpha),
float(beta)); double(beta));
} }
else if(data_type == SoftmaxDataType::F32_F32) else if(data_type == SoftmaxDataType::F32_F32)
{ {
...@@ -111,8 +111,8 @@ int profile_softmax(int argc, char* argv[]) ...@@ -111,8 +111,8 @@ int profile_softmax(int argc, char* argv[])
length, length,
stride, stride,
reduce, reduce,
float(alpha), double(alpha),
float(beta)); double(beta));
} }
else else
{ {
...@@ -131,8 +131,8 @@ int profile_softmax(int argc, char* argv[]) ...@@ -131,8 +131,8 @@ int profile_softmax(int argc, char* argv[])
length, length,
stride, stride,
reduce, reduce,
float(alpha), double(alpha),
float(beta)); double(beta));
} }
else if(data_type == SoftmaxDataType::F32_F32) else if(data_type == SoftmaxDataType::F32_F32)
{ {
...@@ -143,8 +143,8 @@ int profile_softmax(int argc, char* argv[]) ...@@ -143,8 +143,8 @@ int profile_softmax(int argc, char* argv[])
length, length,
stride, stride,
reduce, reduce,
float(alpha), double(alpha),
float(beta)); double(beta));
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment