Unverified Commit 0325c1a4 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Clip update for onnx (#455)



* fix pad calc

* modify clip for more args

* formatting

* add test, flip order, revert to unary

* fix error msg

* add min and max args to clip

* formatting

* fixes to quantization

* formatting

* fix logic and add extra test

* formatting

* fix logic, add extra test

* formatting

* fix bug in test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent c4e53a33
...@@ -18,29 +18,30 @@ namespace migraphx { ...@@ -18,29 +18,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct clip : unary<clip> struct clip
{ {
float max_val = std::numeric_limits<float>::max(); std::string name() const { return "clip"; }
float min_val = std::numeric_limits<float>::min();
clip() {} shape compute_shape(std::vector<shape> inputs) const
clip(float max, float min) : max_val(max), min_val(min) {}
auto apply() const
{ {
auto max = max_val; check_shapes{inputs}.has(3).same_type();
auto min = min_val; return inputs.front();
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
} }
template <class Self, class F> argument compute(const shape& output_shape, std::vector<argument> args) const
static auto reflect(Self& self, F f)
{ {
return pack(f(self.max_val, "max"), f(self.min_val, "min")); argument result{output_shape};
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto input, auto min_val, auto max_val) {
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
});
return result;
} }
}; };
......
...@@ -330,16 +330,48 @@ struct onnx_parser ...@@ -330,16 +330,48 @@ struct onnx_parser
instruction_ref instruction_ref
parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args) parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::clip op; auto input_lens = args[0]->get_shape().lens();
if(contains(info.attributes, "max")) instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
bool max_used = false;
if(args.size() == 3)
{ {
op.max_val = parse_value(info.attributes.at("max")).at<float>(); min_arg = args[1];
max_arg = args[2];
min_used = true;
max_used = true;
} }
if(contains(info.attributes, "min")) else if(args.size() == 2)
{ {
op.min_val = parse_value(info.attributes.at("min")).at<float>(); min_arg = args[1];
min_used = true;
} }
return prog.add_instruction(op, std::move(args)); // if using previous opset for attributes
else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
{
float min_val = parse_value(info.attributes.at("min")).at<float>();
float max_val = parse_value(info.attributes.at("max")).at<float>();
min_arg = prog.add_literal(min_val);
max_arg = prog.add_literal(max_val);
min_used = true;
max_used = true;
}
if(min_used)
min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);
if(max_used)
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
if(min_used and max_used)
return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
if(min_used)
return prog.add_instruction(op::max{}, args[0], min_arg);
return prog.add_instruction(op::identity{}, args[0]);
} }
template <class Op> template <class Op>
......
...@@ -80,9 +80,14 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -80,9 +80,14 @@ instruction_ref insert_quant_ins(program& prog,
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
} }
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins); auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins = auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins); prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins); quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
} }
else else
......
...@@ -14,7 +14,7 @@ shape hip_clip::compute_shape(std::vector<shape> inputs) const ...@@ -14,7 +14,7 @@ shape hip_clip::compute_shape(std::vector<shape> inputs) const
argument hip_clip::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_clip::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::clip(ctx.get_stream().get(), args.back(), args.front(), op.max_val, op.min_val); device::clip(ctx.get_stream().get(), args.back(), args.front(), args.at(1), args.at(2));
return args.back(); return args.back();
} }
......
...@@ -10,12 +10,12 @@ void add_clip(hipStream_t stream, ...@@ -10,12 +10,12 @@ void add_clip(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const float max, const argument& min_arg,
const float min) const argument& max_arg)
{ {
nary(stream, result, arg1, arg2)([max, min](auto x, auto y) __device__ { nary(stream, result, arg1, arg2, min_arg, max_arg)(
return ::min<decltype(x + y)>(::max<decltype(x)>(min, x + y), max); [](auto x, auto y, auto min, auto max)
}); __device__ { return ::min<decltype(x + y)>(::max<decltype(x)>(min, x + y), max); });
} }
void add_clip(hipStream_t stream, void add_clip(hipStream_t stream,
...@@ -23,12 +23,13 @@ void add_clip(hipStream_t stream, ...@@ -23,12 +23,13 @@ void add_clip(hipStream_t stream,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3, const argument& arg3,
const float max, const argument& min_arg,
const float min) const argument& max_arg)
{ {
nary(stream, result, arg1, arg2, arg3)([max, min](auto x, auto y, auto z) __device__ { nary(stream, result, arg1, arg2, arg3, min_arg, max_arg)(
return ::min<decltype(x + y + z)>(::max<decltype(x)>(min, x + y + z), max); [](auto x, auto y, auto z, auto min, auto max) __device__ {
}); return ::min<decltype(x + y + z)>(::max<decltype(x)>(min, x + y + z), max);
});
} }
} // namespace device } // namespace device
......
...@@ -9,10 +9,11 @@ namespace device { ...@@ -9,10 +9,11 @@ namespace device {
void clip(hipStream_t stream, void clip(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const float max, const argument& min_val,
const float min) const argument& max_val)
{ {
nary(stream, result, arg1)([max, min](auto x) __device__ {
nary(stream, result, arg1, min_val, max_val)([](auto x, auto min, auto max) __device__ {
return ::min<decltype(x)>(::max<decltype(x)>(min, x), max); return ::min<decltype(x)>(::max<decltype(x)>(min, x), max);
}); });
} }
......
...@@ -184,29 +184,22 @@ struct hip_triadd ...@@ -184,29 +184,22 @@ struct hip_triadd
struct hip_triadd_clip struct hip_triadd_clip
{ {
op::clip op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::clip::reflect(self.op, f);
}
std::string name() const { return "hip::triadd_clip"; } std::string name() const { return "hip::triadd_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(4); check_shapes{inputs, *this}.has(6);
return inputs.front(); return inputs.front();
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::add_clip(ctx.get_stream().get(), device::add_clip(ctx.get_stream().get(),
args.at(3), args.at(5),
args.at(0), args.at(0),
args.at(1), args.at(1),
args.at(2), args.at(2),
op.max_val, args.at(3),
op.min_val); args.at(4));
return args.at(3); return args.at(5);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
...@@ -216,24 +209,17 @@ struct hip_triadd_clip ...@@ -216,24 +209,17 @@ struct hip_triadd_clip
struct hip_add_clip struct hip_add_clip
{ {
op::clip op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::clip::reflect(self.op, f);
}
std::string name() const { return "hip::add_clip"; } std::string name() const { return "hip::add_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(5);
return inputs.front(); return inputs.front();
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::add_clip( device::add_clip(
ctx.get_stream().get(), args.at(2), args.at(0), args.at(1), op.max_val, op.min_val); ctx.get_stream().get(), args.at(4), args.at(0), args.at(1), args.at(2), args.at(3));
return args.at(2); return args.at(4);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
...@@ -337,19 +323,20 @@ struct find_add_clip ...@@ -337,19 +323,20 @@ struct find_add_clip
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
auto&& op = any_cast<gpu::hip_clip>(ins->get_operator()).op; auto ins_args = ins->inputs();
auto args = add_ins->inputs(); auto add_args = add_ins->inputs();
move_standard_front(args); move_standard_front(add_args);
move_broadcasted_back(args); move_broadcasted_back(add_args);
// Use the allocation from the relu operator // Use the allocation from the clip operator
args.back() = ins->inputs().back(); add_args.pop_back();
add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
if(add_ins->name() == "gpu::add") if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_clip{op}, args); p.replace_instruction(ins, hip_add_clip{}, add_args);
else if(add_ins->name() == "hip::triadd") else if(add_ins->name() == "hip::triadd")
p.replace_instruction(ins, hip_triadd_clip{op}, args); p.replace_instruction(ins, hip_triadd_clip{}, add_args);
} }
}; };
......
...@@ -15,16 +15,16 @@ void add_clip(hipStream_t stream, ...@@ -15,16 +15,16 @@ void add_clip(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
float max, const argument& min_arg,
float min); const argument& max_arg);
void add_clip(hipStream_t stream, void add_clip(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3, const argument& arg3,
float max, const argument& min_arg,
float min); const argument& max_arg);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -10,7 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void clip(hipStream_t stream, const argument& result, const argument& arg1, float max, float min); void clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& min_val,
const argument& max_val);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -179,7 +179,7 @@ struct tf_parser ...@@ -179,7 +179,7 @@ struct tf_parser
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{}); add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0}); // add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{}); add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{}); add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{}); add_generic_op("StopGradient", op::identity{});
...@@ -210,6 +210,7 @@ struct tf_parser ...@@ -210,6 +210,7 @@ struct tf_parser
add_mem_op("OneHot", &tf_parser::parse_onehot, false); add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Relu6", &tf_parser::parse_relu6);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Shape", &tf_parser::parse_shape, false); add_mem_op("Shape", &tf_parser::parse_shape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
...@@ -771,6 +772,18 @@ struct tf_parser ...@@ -771,6 +772,18 @@ struct tf_parser
return prog.add_instruction(op, l0); return prog.add_instruction(op, l0);
} }
instruction_ref
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = prog.add_literal(0.0f);
auto max_val = prog.add_literal(6.0f);
min_val = prog.add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = prog.add_instruction(op::multibroadcast{input_lens}, max_val);
return prog.add_instruction(op::clip{}, args.front(), min_val, max_val);
}
instruction_ref instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
......
...@@ -2043,11 +2043,12 @@ TEST_CASE(clip_test) ...@@ -2043,11 +2043,12 @@ TEST_CASE(clip_test)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}}); auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}});
migraphx::op::clip op; auto min_val = p.add_literal(0.0f);
op.max_val = 6.0; auto max_val = p.add_literal(6.0f);
op.min_val = 0.0; min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
p.add_instruction(op, l); max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l, min_val, max_val);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
......
...@@ -529,8 +529,12 @@ struct test_acosh : verify_program<test_acosh> ...@@ -529,8 +529,12 @@ struct test_acosh : verify_program<test_acosh>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{100.0f, 1.1f}, x); auto min_val = p.add_literal(1.1f);
auto max_val = p.add_literal(100.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, max_val);
auto cx = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::acosh{}, cx); p.add_instruction(migraphx::op::acosh{}, cx);
return p; return p;
} }
...@@ -542,8 +546,12 @@ struct test_atanh : verify_program<test_atanh> ...@@ -542,8 +546,12 @@ struct test_atanh : verify_program<test_atanh>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {16}}; migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{0.95f, -0.95f}, x); auto min_val = p.add_literal(-0.95);
auto max_val = p.add_literal(0.95);
min_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, max_val);
auto cx = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::atanh{}, cx); p.add_instruction(migraphx::op::atanh{}, cx);
return p; return p;
} }
...@@ -931,6 +939,7 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> ...@@ -931,6 +939,7 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
std::vector<size_t> input_lens{4, 3, 3, 3};
auto input = auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights = auto weights =
...@@ -942,7 +951,11 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> ...@@ -942,7 +951,11 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto bcast_add = auto bcast_add =
p.add_instruction(migraphx::op::broadcast{1, conv->get_shape().lens()}, bias); p.add_instruction(migraphx::op::broadcast{1, conv->get_shape().lens()}, bias);
auto bias_add = p.add_instruction(migraphx::op::add{}, conv, bcast_add); auto bias_add = p.add_instruction(migraphx::op::add{}, conv, bcast_add);
p.add_instruction(migraphx::op::clip{6.0f, 0.0f}, bias_add); auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
p.add_instruction(migraphx::op::clip{}, bias_add, min_val, max_val);
return p; return p;
} }
}; };
...@@ -1928,8 +1941,12 @@ struct test_clip : verify_program<test_clip> ...@@ -1928,8 +1941,12 @@ struct test_clip : verify_program<test_clip>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, x); auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
return p; return p;
} }
}; };
...@@ -4387,9 +4404,14 @@ struct test_rsqrt : verify_program<test_rsqrt> ...@@ -4387,9 +4404,14 @@ struct test_rsqrt : verify_program<test_rsqrt>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}}; std::vector<size_t> input_lens{1, 3, 16, 16};
auto x = p.add_parameter("x", s); migraphx::shape s{migraphx::shape::float_type, input_lens};
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x); auto x = p.add_parameter("x", s);
auto min_val = p.add_literal(1.0f);
auto max_val = p.add_literal(std::numeric_limits<float>::max());
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
auto l0 = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::rsqrt{}, l0); p.add_instruction(migraphx::op::rsqrt{}, l0);
return p; return p;
}; };
......
No preview for this file type
clip_test_op11_no_args:H
01"Clipclip_test_op11_no_argsZ
0

b
1

B
\ No newline at end of file
...@@ -268,6 +268,43 @@ def clip_test(): ...@@ -268,6 +268,43 @@ def clip_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def clip_test_op11():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])
max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [6.0])
node = onnx.helper.make_node('Clip',
inputs=['0', 'min', 'max'],
outputs=['1'])
return ([node], [x], [y], [min_val, max_val])
@onnx_test
def clip_test_op11_min_only():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])
node = onnx.helper.make_node('Clip', inputs=['0', 'min'], outputs=['1'])
return ([node], [x], [y], [min_val])
@onnx_test
def clip_test_op11_no_args():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
node = onnx.helper.make_node('Clip', inputs=['0'], outputs=['1'])
return ([node], [x], [y])
@onnx_test @onnx_test
def concat_test(): def concat_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3])
......
...@@ -222,13 +222,53 @@ TEST_CASE(ceil_test) ...@@ -222,13 +222,53 @@ TEST_CASE(ceil_test)
TEST_CASE(clip_test) TEST_CASE(clip_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test.onnx"); auto prog = optimize_onnx("clip_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(clip_test_op11)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test_op11.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11_min_only)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
p.add_instruction(migraphx::op::max{}, l0, min_val);
auto prog = optimize_onnx("clip_test_op11_min_only.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11_no_args)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("clip_test_op11_no_args.onnx");
EXPECT(p == prog);
}
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -14,6 +14,31 @@ ...@@ -14,6 +14,31 @@
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
auto input_lens = input->get_shape().lens();
auto max_val = p.add_literal(max);
auto min_val = p.add_literal(min);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
return p.add_instruction(migraphx::op::clip{}, input, min_val, max_val);
}
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
migraphx::program& p,
float max,
float min,
migraphx::instruction_ref input)
{
auto input_lens = input->get_shape().lens();
auto max_val = p.add_literal(max);
auto min_val = p.add_literal(min);
max_val = p.insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, max_val);
min_val = p.insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, min_val);
return p.insert_instruction(insert_loc, migraphx::op::clip{}, input, min_val, max_val);
}
TEST_CASE(param_add) TEST_CASE(param_add)
{ {
auto create_program_float = [](bool add_return = false) { auto create_program_float = [](bool add_return = false) {
...@@ -308,7 +333,7 @@ TEST_CASE(dot_float) ...@@ -308,7 +333,7 @@ TEST_CASE(dot_float)
auto fa = p.add_literal(migraphx::literal(sa, vfa)); auto fa = p.add_literal(migraphx::literal(sa, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa); auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa);
auto ra = p.add_instruction(migraphx::op::round{}, ma); auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -317,7 +342,7 @@ TEST_CASE(dot_float) ...@@ -317,7 +342,7 @@ TEST_CASE(dot_float)
auto fb = p.add_literal(migraphx::literal(sb, vfb)); auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -372,7 +397,7 @@ TEST_CASE(dot_double_2args) ...@@ -372,7 +397,7 @@ TEST_CASE(dot_double_2args)
auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa); auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa);
auto ra = p.add_instruction(migraphx::op::round{}, ma); auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -383,7 +408,7 @@ TEST_CASE(dot_double_2args) ...@@ -383,7 +408,7 @@ TEST_CASE(dot_double_2args)
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -438,7 +463,7 @@ TEST_CASE(dot_large_alpha_beta_float) ...@@ -438,7 +463,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto sfta = p.add_literal(migraphx::literal(sa, vsa)); auto sfta = p.add_literal(migraphx::literal(sa, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa); auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -447,7 +472,7 @@ TEST_CASE(dot_large_alpha_beta_float) ...@@ -447,7 +472,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto fb = p.add_literal(migraphx::literal(sb, vfb)); auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -505,7 +530,7 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -505,7 +530,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa)); auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa); auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -516,7 +541,7 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -516,7 +541,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -557,7 +582,7 @@ TEST_CASE(dot_int32_one_arg) ...@@ -557,7 +582,7 @@ TEST_CASE(dot_int32_one_arg)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa)); auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, fpa); auto msa = p.add_instruction(migraphx::op::add{}, sfta, fpa);
auto ra = p.add_instruction(migraphx::op::round{}, msa); auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
auto q_dot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qa); auto q_dot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qa);
...@@ -617,7 +642,7 @@ TEST_CASE(dot_int32) ...@@ -617,7 +642,7 @@ TEST_CASE(dot_int32)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa)); auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa); auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -628,7 +653,7 @@ TEST_CASE(dot_int32) ...@@ -628,7 +653,7 @@ TEST_CASE(dot_int32)
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -692,7 +717,7 @@ TEST_CASE(dot_float_convert) ...@@ -692,7 +717,7 @@ TEST_CASE(dot_float_convert)
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb = auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
...@@ -738,7 +763,7 @@ TEST_CASE(conv_float) ...@@ -738,7 +763,7 @@ TEST_CASE(conv_float)
auto fx = p.add_literal(migraphx::literal(sx, vfx)); auto fx = p.add_literal(migraphx::literal(sx, vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, px); auto mx = p.add_instruction(migraphx::op::mul{}, fx, px);
auto rx = p.add_instruction(migraphx::op::round{}, mx); auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -747,7 +772,7 @@ TEST_CASE(conv_float) ...@@ -747,7 +772,7 @@ TEST_CASE(conv_float)
auto fw = p.add_literal(migraphx::literal(sw, vfw)); auto fw = p.add_literal(migraphx::literal(sw, vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, pw); auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, pw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
...@@ -793,7 +818,7 @@ TEST_CASE(conv_int32) ...@@ -793,7 +818,7 @@ TEST_CASE(conv_int32)
auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx)); auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx); auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx);
auto rx = p.add_instruction(migraphx::op::round{}, mx); auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -804,7 +829,7 @@ TEST_CASE(conv_int32) ...@@ -804,7 +829,7 @@ TEST_CASE(conv_int32)
auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw)); auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw); auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
...@@ -849,7 +874,7 @@ TEST_CASE(conv_half) ...@@ -849,7 +874,7 @@ TEST_CASE(conv_half)
auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx)); auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx); auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx);
auto rx = p.add_instruction(migraphx::op::round{}, mx); auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type // quantize parameter b to int8 type
...@@ -860,7 +885,7 @@ TEST_CASE(conv_half) ...@@ -860,7 +885,7 @@ TEST_CASE(conv_half)
auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw)); auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw); auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw = auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
......
...@@ -443,8 +443,13 @@ TEST_CASE(relu_test) ...@@ -443,8 +443,13 @@ TEST_CASE(relu_test)
TEST_CASE(relu6_test) TEST_CASE(relu6_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); std::vector<size_t> input_lens{1, 3, 16, 16};
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false); auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment