Unverified commit 52abc2f3, authored by Qianfeng, committed by GitHub
Browse files

Use double for all scaling values and floating-point constant values in the Device Op API (#557)

* Use double as alpha/beta values type in reduce device op api

* Use double as alpha/beta values type in softmax device op api

* Use double as alpha/beta values type in multiple-reduce device op api

* Use double as epsilon value type in normalization/elementwise-normalization device op api
parent 1cfa8760
...@@ -47,8 +47,8 @@ int main(int argc, char* argv[]) ...@@ -47,8 +47,8 @@ int main(int argc, char* argv[])
ck::index_t num_elements = ck::index_t num_elements =
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>()); std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
AccDataType alpha{2.0f}; double alpha{2.0};
AccDataType beta{2.0f}; double beta{2.0};
SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
...@@ -82,8 +82,8 @@ int main(int argc, char* argv[]) ...@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides, in_strides,
reduce_dims, reduce_dims,
&alpha, alpha,
&beta, beta,
in.GetDeviceBuffer(), in.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
...@@ -129,8 +129,8 @@ int main(int argc, char* argv[]) ...@@ -129,8 +129,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides, in_strides,
reduce_dims, reduce_dims,
&alpha, alpha,
&beta, beta,
in.GetDeviceBuffer(), in.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
...@@ -147,4 +147,4 @@ int main(int argc, char* argv[]) ...@@ -147,4 +147,4 @@ int main(int argc, char* argv[])
} }
return 0; return 0;
} }
\ No newline at end of file
...@@ -61,8 +61,8 @@ int main(int argc, char* argv[]) ...@@ -61,8 +61,8 @@ int main(int argc, char* argv[])
for(auto dim : reduce_dims) for(auto dim : reduce_dims)
reduce_length *= in_lengths[dim]; reduce_length *= in_lengths[dim];
float alpha{1.0f}; double alpha{1.0};
float beta{0.0f}; double beta{0.0};
SimpleDeviceMem in(sizeof(InDataType) * num_in_elements); SimpleDeviceMem in(sizeof(InDataType) * num_in_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements);
......
...@@ -267,8 +267,8 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -267,8 +267,8 @@ int reduce_blockwise_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in.mData.data(), in.mData.data(),
nullptr, nullptr,
out_ref.mData.data(), out_ref.mData.data(),
...@@ -295,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -295,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
......
...@@ -226,8 +226,8 @@ int main(int argc, char* argv[]) ...@@ -226,8 +226,8 @@ int main(int argc, char* argv[])
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_1.mData.data(), in_1.mData.data(),
nullptr, nullptr,
out_ref.mData.data(), out_ref.mData.data(),
...@@ -254,8 +254,8 @@ int main(int argc, char* argv[]) ...@@ -254,8 +254,8 @@ int main(int argc, char* argv[])
arrInLengths_2, arrInLengths_2,
arrInStrides_2, arrInStrides_2,
reduceDims_1, reduceDims_1,
1.0f, 1.0,
0.0f, 0.0,
in_1_dev.GetDeviceBuffer(), in_1_dev.GetDeviceBuffer(),
nullptr, nullptr,
in_2_dev.GetDeviceBuffer(), in_2_dev.GetDeviceBuffer(),
...@@ -278,8 +278,8 @@ int main(int argc, char* argv[]) ...@@ -278,8 +278,8 @@ int main(int argc, char* argv[])
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims_2, reduceDims_2,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_2_dev.GetDeviceBuffer(), in_2_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
......
...@@ -180,8 +180,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -180,8 +180,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in.mData.data(), in.mData.data(),
nullptr, nullptr,
out_ref.mData.data(), out_ref.mData.data(),
...@@ -208,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -208,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
......
...@@ -56,8 +56,8 @@ class SimpleAppArgs ...@@ -56,8 +56,8 @@ class SimpleAppArgs
int option_index = 0; int option_index = 0;
public: public:
std::vector<size_t> inLengths = {8, 128, 2048}; std::vector<size_t> inLengths = {8, 128, 2048};
std::vector<AccDataType> scales = {2.0f, 2.0f}; std::vector<double> scales = {2.0, 2.0};
bool do_verification = true; bool do_verification = true;
int init_method = 2; int init_method = 2;
...@@ -151,8 +151,8 @@ int main(int argc, char* argv[]) ...@@ -151,8 +151,8 @@ int main(int argc, char* argv[])
auto inStrides = in.mDesc.GetStrides(); auto inStrides = in.mDesc.GetStrides();
auto outStrides = out.mDesc.GetStrides(); auto outStrides = out.mDesc.GetStrides();
AccDataType alpha = args.scales[0]; double alpha = args.scales[0];
AccDataType beta = args.scales[1]; double beta = args.scales[1];
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl; std::cout << "out: " << out.mDesc << std::endl;
...@@ -221,8 +221,8 @@ int main(int argc, char* argv[]) ...@@ -221,8 +221,8 @@ int main(int argc, char* argv[])
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
i_inStrides, i_inStrides,
reduceDims, reduceDims,
&alpha, alpha,
&beta, beta,
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
......
...@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
size_t invariant_total_length = n; size_t invariant_total_length = n;
size_t reduce_total_length = h * w * c; size_t reduce_total_length = h * w * c;
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f); const double alpha = 1.0f;
const AccDataType beta = ck::type_convert<AccDataType>(0.0f); const double beta = 0.0f;
std::size_t num_thread = 1; std::size_t num_thread = 1;
...@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n, ...@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
i_outLengths, i_outLengths,
{i_outStrides, i_outStrides}, {i_outStrides, i_outStrides},
reduceDims, reduceDims,
{&alpha, &alpha}, {alpha, alpha},
{&beta, &beta}, {beta, beta},
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()}, {mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}), ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
......
...@@ -32,7 +32,7 @@ struct DeviceElementwiseNormalization : public BaseOperator ...@@ -32,7 +32,7 @@ struct DeviceElementwiseNormalization : public BaseOperator
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator ...@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator ...@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const void* p_x, const void* p_x,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -33,8 +33,8 @@ struct DeviceReduce : public BaseOperator ...@@ -33,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const std::array<index_t, NumOutDim> outLengths, const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides, const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev, const void* in_index_dev,
void* out_dev, void* out_dev,
......
...@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator ...@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension // @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension // @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied // @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling // @param[in] alpha double type value
// value as type AccDataType // @param[in] beta double type value
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] in_dev Typeless const pointer in device memory storing the input // @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor // tensor
// @param out_dev Typeless pointer in device memory storing the output tensor // @param out_dev Typeless pointer in device memory storing the output tensor
...@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator ...@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
const void* alpha, double alpha,
const void* beta, double beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
......
...@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl ...@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
XElementwiseOperation x_elementwise_op, XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
AccDataType epsilon, double epsilon,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y)
: epsilon_(epsilon), : p_gamma_(p_gamma),
p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
x_elementwise_op_(x_elementwise_op), x_elementwise_op_(x_elementwise_op),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
for(int i = 0; i < NumInput; i++) for(int i = 0; i < NumInput; i++)
...@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl ...@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths, const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims, const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas, const std::array<double, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas, const std::array<double, NumReduction>& betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers, const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
...@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++) for(size_t i = 0; i < NumReduction; i++)
{ {
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]); alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]); beta_values_(i) = static_cast<AccDataType>(betas[i]);
}; };
in_dev_ = static_cast<const InDataType*>(in_dev); in_dev_ = static_cast<const InDataType*>(in_dev);
...@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank, ...@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths, const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims, const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas, const std::array<double, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas, const std::array<double, NumReduction>& betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers, const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
...@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++) for(size_t i = 0; i < NumReduction; i++)
{ {
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]); alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]); beta_values_(i) = static_cast<AccDataType>(betas[i]);
}; };
in_dev_ = static_cast<const InDataType*>(in_dev); in_dev_ = static_cast<const InDataType*>(in_dev);
...@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank, ...@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths, const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray, const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas, const std::array<double, NumReduction> alphas,
const std::array<const void*, NumReduction> betas, const std::array<double, NumReduction> betas,
const void* in_dev, const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers, const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple,
......
...@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op, AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y)
: epsilon_(epsilon), : p_x_(p_x),
p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op) acc_elementwise_op_(acc_elementwise_op)
{ {
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims); xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
...@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, double epsilon,
const void* p_x, const void* p_x,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
......
...@@ -217,8 +217,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType, ...@@ -217,8 +217,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const InDataType* in_dev, const InDataType* in_dev,
const IndexDataType* in_index_dev, const IndexDataType* in_index_dev,
OutDataType* out_dev, OutDataType* out_dev,
...@@ -502,8 +502,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType, ...@@ -502,8 +502,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev, const void* in_index_dev,
void* out_dev, void* out_dev,
......
...@@ -165,8 +165,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType, ...@@ -165,8 +165,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_index_dev, IndexDataType* out_index_dev,
...@@ -341,8 +341,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType, ...@@ -341,8 +341,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev, const void* in_index_dev,
void* out_dev, void* out_dev,
......
...@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
Argument(const std::vector<index_t> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType alpha, double alpha,
AccDataType beta, double beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) AccElementwiseOp acc_elementwise_op)
: alpha_{alpha}, : in_dev_{in_dev},
beta_{beta},
in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
if(Rank != inLengths.size() || Rank != inStrides.size() || if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size()) NumReduceDim != reduceDims.size())
{ {
...@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
static auto MakeArgument(const std::vector<index_t> inLengths, static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
const AccDataType alpha, double alpha,
const AccDataType beta, double beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
...@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths, std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
const void* alpha, double alpha,
const void* beta, double beta,
const void* in_dev, const void* in_dev,
void* out_dev, void* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
...@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
reduceDims, reduceDims,
*static_cast<const AccDataType*>(alpha), alpha,
*static_cast<const AccDataType*>(beta), beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
in_elementwise_op, in_elementwise_op,
......
...@@ -56,8 +56,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType, ...@@ -56,8 +56,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const InDataType* in_host, const InDataType* in_host,
OutDataType* out_host, OutDataType* out_host,
IndexDataType* out_index_host, IndexDataType* out_index_host,
...@@ -388,8 +388,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType, ...@@ -388,8 +388,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
const std::array<index_t, NumDstDim> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, double alpha,
float beta, double beta,
const void* in_host, const void* in_host,
const void* in_index_host, const void* in_index_host,
void* out_host, void* out_host,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment