Commit 07214d76 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d0e2ace6
...@@ -939,9 +939,10 @@ struct logsoftmax ...@@ -939,9 +939,10 @@ struct logsoftmax
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
if (axis < 0 || axis >= inputs[0].lens().size()) if(axis < 0 || axis >= inputs[0].lens().size())
{ {
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + " is out of range"); MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
} }
return inputs.at(0); return inputs.at(0);
} }
......
...@@ -229,8 +229,9 @@ struct onnx_parser ...@@ -229,8 +229,9 @@ struct onnx_parser
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s); return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
} }
instruction_ref instruction_ref parse_logsoftmax(const std::string&,
parse_logsoftmax(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args) const attribute_map& attributes,
std::vector<instruction_ref> args)
{ {
int axis = 1; int axis = 1;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
......
...@@ -619,7 +619,7 @@ struct cpu_logsoftmax ...@@ -619,7 +619,7 @@ struct cpu_logsoftmax
std::string name() const { return "cpu::logsoftmax"; } std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template<typename T> template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{ {
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis); std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
...@@ -636,14 +636,15 @@ struct cpu_logsoftmax ...@@ -636,14 +636,15 @@ struct cpu_logsoftmax
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::numeric_limits<value_type>::lowest()); std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end())); batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
}); });
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index]; output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
}); });
...@@ -653,14 +654,15 @@ struct cpu_logsoftmax ...@@ -653,14 +654,15 @@ struct cpu_logsoftmax
batch_sum[index] += std::exp(output(idx.begin(), idx.end())); batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
}); });
for (std::size_t i = 0; i < batch_sum.size(); ++i) for(std::size_t i = 0; i < batch_sum.size(); ++i)
{ {
batch_sum[i] = std::log(batch_sum[i]); batch_sum[i] = std::log(batch_sum[i]);
} }
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index]; output(idx.begin(), idx.end()) =
input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index];
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment