Commit a40c5f4a authored by Khalique's avatar Khalique
Browse files

fix edge case for dims=1

parent 7f7cbbc0
......@@ -533,7 +533,7 @@ struct cpu_softmax
template <typename T>
std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
{
idx.erase(idx.begin() + axis);
idx[axis] = 0;
return batch_shape.index(idx);
}
......@@ -541,7 +541,7 @@ struct cpu_softmax
{
argument result{output_shape};
auto batch_lens = output_shape.lens();
batch_lens.erase(batch_lens.begin() + op.axis);
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment