Commit ffbc0918 authored by Paul's avatar Paul
Browse files

Compile-time fixes

parent 2851a6e9
...@@ -196,7 +196,7 @@ struct block ...@@ -196,7 +196,7 @@ struct block
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Assert same elements // TODO: Assert same elements
idx.local_stride(x.elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}; };
} }
}; };
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
template <class Lens, class Strides> template <class Lens, class Strides>
struct shape struct shape
{ {
using shape_type = shape;
using index_array = typename Lens::base_array; using index_array = typename Lens::base_array;
Lens lens = {}; Lens lens = {};
Strides strides = {}; Strides strides = {};
......
...@@ -2,18 +2,18 @@ ...@@ -2,18 +2,18 @@
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP #define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp> #include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/basic_ops.hpp> #include <migraphx/kernels/ops.hpp>
namespace migraphx { namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.outer(output, r.inner(output,
input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; }); input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment