Commit c9bc461c authored by Paul's avatar Paul
Browse files

Slice inner

parent b9b761eb
......@@ -52,10 +52,12 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto axis = v.at("axis").to<int64_t>();
auto block_size = compute_block_size(inputs[0].lens()[axis], 256);
auto relements = inputs[0].lens()[axis];
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements(), block_size), 256);
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
......
......@@ -191,13 +191,12 @@ struct block
f();
}
template <class T, class... Ts>
__device__ auto inner(T x, Ts... xs) const
template <class F>
__device__ auto inner(F f) const
{
return [=](auto f) {
// TODO: Assert same elements
return sliced(slicer, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
};
});
}
};
......@@ -246,15 +245,15 @@ struct lane
f();
}
template <class T, class... Ts>
__device__ auto inner(T x, Ts... xs) const
template <class F>
__device__ auto inner(F f) const
{
return [=](auto f) {
return sliced(slicer, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
f(x[j], xs[j]...);
}
};
});
}
};
......
......@@ -13,8 +13,8 @@ __device__ void softmax(Input input, Output output)
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner(output,
input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
input);
});
}
......
......@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter_none");
add_extend_op("softmax");
add_extend_op("topk");
add_batch_norm_inference_op();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment