Commit 07214d76 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d0e2ace6
......@@ -939,9 +939,10 @@ struct logsoftmax
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if (axis < 0 || axis >= inputs[0].lens().size())
if(axis < 0 || axis >= inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + " is out of range");
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
......
......@@ -229,8 +229,9 @@ struct onnx_parser
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref
parse_logsoftmax(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args)
instruction_ref parse_logsoftmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
......
......@@ -619,7 +619,7 @@ struct cpu_logsoftmax
std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<typename T>
template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
......@@ -636,14 +636,15 @@ struct cpu_logsoftmax
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::numeric_limits<value_type>::lowest());
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
});
......@@ -653,14 +654,15 @@ struct cpu_logsoftmax
batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
});
for (std::size_t i = 0; i < batch_sum.size(); ++i)
for(std::size_t i = 0; i < batch_sum.size(); ++i)
{
batch_sum[i] = std::log(batch_sum[i]);
}
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index];
output(idx.begin(), idx.end()) =
input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index];
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment