Commit ffbc0918 authored by Paul
Browse files

Compile-time fixes

parent 2851a6e9
......@@ -196,7 +196,7 @@ struct block
{
return [=](auto f) {
// TODO: Assert same elements
idx.local_stride(x.elements(), [&](auto j) { f(x[j], xs[j]...); });
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
};
}
};
......
......@@ -9,6 +9,7 @@ namespace migraphx {
template <class Lens, class Strides>
struct shape
{
using shape_type = shape;
using index_array = typename Lens::base_array;
Lens lens = {};
Strides strides = {};
......
......@@ -2,18 +2,18 @@
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
template <index_int Axis, class Input, class Output>
void softmax(Input input, Output output)
__device__ void softmax(Input input, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.outer(output,
r.inner(output,
input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
});
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment