Unverified Commit b0ece214 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Make clip a pointwise op (#1043)

Make clip a pointwise op
parent fc42d852
......@@ -7,7 +7,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -21,25 +21,26 @@ struct clip
{
std::string name() const { return "clip"; }
value attributes() const
{
return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_type();
check_shapes{inputs, *this}.has(3).same_type().same_dims();
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto input, auto min_val, auto max_val) {
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
});
return result;
}
};
......
......@@ -28,9 +28,9 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment