Commit c9bc461c authored by Paul
Browse files

Slice inner

parent b9b761eb
...@@ -52,10 +52,12 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -52,10 +52,12 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto axis = v.at("axis").to<int64_t>(); auto axis = v.at("axis").to<int64_t>();
auto block_size = compute_block_size(inputs[0].lens()[axis], 256); auto relements = inputs[0].lens()[axis];
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options; hip_compile_options options;
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements(), block_size), 256); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "softmax_kernel"; options.kernel_name = "softmax_kernel";
......
...@@ -191,13 +191,12 @@ struct block ...@@ -191,13 +191,12 @@ struct block
f(); f();
} }
template <class T, class... Ts> template <class F>
__device__ auto inner(T x, Ts... xs) const __device__ auto inner(F f) const
{ {
return [=](auto f) { return sliced(slicer, [=](auto x, auto... xs) {
// TODO: Assert same elements
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}; });
} }
}; };
...@@ -246,15 +245,15 @@ struct lane ...@@ -246,15 +245,15 @@ struct lane
f(); f();
} }
template <class T, class... Ts> template <class F>
__device__ auto inner(T x, Ts... xs) const __device__ auto inner(F f) const
{ {
return [=](auto f) { return sliced(slicer, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(x[j], xs[j]...);
} }
}; });
} }
}; };
......
...@@ -13,8 +13,8 @@ __device__ void softmax(Input input, Output output) ...@@ -13,8 +13,8 @@ __device__ void softmax(Input input, Output output)
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner(output, r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; }); input);
}); });
} }
......
...@@ -186,7 +186,6 @@ struct miopen_apply ...@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment