Commit 201c8182 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add to process the case of input axis being 0

parent aa521f17
...@@ -621,16 +621,31 @@ struct cpu_logsoftmax ...@@ -621,16 +621,31 @@ struct cpu_logsoftmax
template <typename T> template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{
if (axis == 0)
{
return 0;
}
else
{ {
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis); std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end()); return batch_shape.index(batch_idx.begin(), batch_idx.end());
} }
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto lens = output_shape.lens(); auto lens = output_shape.lens();
std::vector<std::size_t> batch_lens(lens.begin(), lens.begin() + op.axis); std::vector<std::size_t> batch_lens{};
if (op.axis == 0)
{
batch_lens.push_back(1);
}
else
{
batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
}
shape batch_shape{migraphx::shape::uint32_type, batch_lens}; shape batch_shape{migraphx::shape::uint32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment