"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "e0b6ce021595c933b17f99853600762be1a1704f"
Commit 201c8182 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add to process the case of input axis being 0

parent aa521f17
...@@ -622,15 +622,30 @@ struct cpu_logsoftmax ...@@ -622,15 +622,30 @@ struct cpu_logsoftmax
template <typename T> template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
{ {
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis); if (axis == 0)
return batch_shape.index(batch_idx.begin(), batch_idx.end()); {
return 0;
}
else
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end());
}
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto lens = output_shape.lens(); auto lens = output_shape.lens();
std::vector<std::size_t> batch_lens(lens.begin(), lens.begin() + op.axis); std::vector<std::size_t> batch_lens{};
if (op.axis == 0)
{
batch_lens.push_back(1);
}
else
{
batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
}
shape batch_shape{migraphx::shape::uint32_type, batch_lens}; shape batch_shape{migraphx::shape::uint32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment