Unverified Commit 0325c1a4 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Clip update for onnx (#455)



* fix pad calc

* modify clip for more args

* formatting

* add test, flip order, revert to unary

* fix error msg

* add min and max args to clip

* formatting

* fixes to quantization

* formatting

* fix logic and add extra test

* formatting

* fix logic, add extra test

* formatting

* fix bug in test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent c4e53a33
......@@ -18,29 +18,30 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct clip : unary<clip>
struct clip
{
float max_val = std::numeric_limits<float>::max();
float min_val = std::numeric_limits<float>::min();
std::string name() const { return "clip"; }
clip() {}
clip(float max, float min) : max_val(max), min_val(min) {}
auto apply() const
shape compute_shape(std::vector<shape> inputs) const
{
auto max = max_val;
auto min = min_val;
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
check_shapes{inputs}.has(3).same_type();
return inputs.front();
}
template <class Self, class F>
static auto reflect(Self& self, F f)
argument compute(const shape& output_shape, std::vector<argument> args) const
{
return pack(f(self.max_val, "max"), f(self.min_val, "min"));
argument result{output_shape};
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto input, auto min_val, auto max_val) {
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
});
return result;
}
};
......
......@@ -330,16 +330,48 @@ struct onnx_parser
instruction_ref
parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::clip op;
if(contains(info.attributes, "max"))
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
bool max_used = false;
if(args.size() == 3)
{
op.max_val = parse_value(info.attributes.at("max")).at<float>();
min_arg = args[1];
max_arg = args[2];
min_used = true;
max_used = true;
}
if(contains(info.attributes, "min"))
else if(args.size() == 2)
{
op.min_val = parse_value(info.attributes.at("min")).at<float>();
min_arg = args[1];
min_used = true;
}
return prog.add_instruction(op, std::move(args));
// if using previous opset for attributes
else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
{
float min_val = parse_value(info.attributes.at("min")).at<float>();
float max_val = parse_value(info.attributes.at("max")).at<float>();
min_arg = prog.add_literal(min_val);
max_arg = prog.add_literal(max_val);
min_used = true;
max_used = true;
}
if(min_used)
min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);
if(max_used)
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
if(min_used and max_used)
return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
if(min_used)
return prog.add_instruction(op::max{}, args[0], min_arg);
return prog.add_instruction(op::identity{}, args[0]);
}
template <class Op>
......
......@@ -80,9 +80,14 @@ instruction_ref insert_quant_ins(program& prog,
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = prog.add_literal(127.0f);
auto min_clip = prog.add_literal(-128.0f);
max_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = prog.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins);
prog.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
}
else
......
......@@ -14,7 +14,7 @@ shape hip_clip::compute_shape(std::vector<shape> inputs) const
argument hip_clip::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::clip(ctx.get_stream().get(), args.back(), args.front(), op.max_val, op.min_val);
device::clip(ctx.get_stream().get(), args.back(), args.front(), args.at(1), args.at(2));
return args.back();
}
......
......@@ -10,12 +10,12 @@ void add_clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const float max,
const float min)
const argument& min_arg,
const argument& max_arg)
{
nary(stream, result, arg1, arg2)([max, min](auto x, auto y) __device__ {
return ::min<decltype(x + y)>(::max<decltype(x)>(min, x + y), max);
});
nary(stream, result, arg1, arg2, min_arg, max_arg)(
[](auto x, auto y, auto min, auto max)
__device__ { return ::min<decltype(x + y)>(::max<decltype(x)>(min, x + y), max); });
}
void add_clip(hipStream_t stream,
......@@ -23,12 +23,13 @@ void add_clip(hipStream_t stream,
const argument& arg1,
const argument& arg2,
const argument& arg3,
const float max,
const float min)
const argument& min_arg,
const argument& max_arg)
{
nary(stream, result, arg1, arg2, arg3)([max, min](auto x, auto y, auto z) __device__ {
return ::min<decltype(x + y + z)>(::max<decltype(x)>(min, x + y + z), max);
});
nary(stream, result, arg1, arg2, arg3, min_arg, max_arg)(
[](auto x, auto y, auto z, auto min, auto max) __device__ {
return ::min<decltype(x + y + z)>(::max<decltype(x)>(min, x + y + z), max);
});
}
} // namespace device
......
......@@ -9,10 +9,11 @@ namespace device {
void clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const float max,
const float min)
const argument& min_val,
const argument& max_val)
{
nary(stream, result, arg1)([max, min](auto x) __device__ {
nary(stream, result, arg1, min_val, max_val)([](auto x, auto min, auto max) __device__ {
return ::min<decltype(x)>(::max<decltype(x)>(min, x), max);
});
}
......
......@@ -184,29 +184,22 @@ struct hip_triadd
struct hip_triadd_clip
{
op::clip op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::clip::reflect(self.op, f);
}
std::string name() const { return "hip::triadd_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
check_shapes{inputs, *this}.has(6);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(ctx.get_stream().get(),
args.at(3),
args.at(5),
args.at(0),
args.at(1),
args.at(2),
op.max_val,
op.min_val);
return args.at(3);
args.at(3),
args.at(4));
return args.at(5);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
......@@ -216,24 +209,17 @@ struct hip_triadd_clip
struct hip_add_clip
{
op::clip op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return op::clip::reflect(self.op, f);
}
std::string name() const { return "hip::add_clip"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_clip(
ctx.get_stream().get(), args.at(2), args.at(0), args.at(1), op.max_val, op.min_val);
return args.at(2);
ctx.get_stream().get(), args.at(4), args.at(0), args.at(1), args.at(2), args.at(3));
return args.at(4);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
......@@ -337,19 +323,20 @@ struct find_add_clip
void apply(program& p, match::matcher_result r) const
{
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto&& op = any_cast<gpu::hip_clip>(ins->get_operator()).op;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto ins_args = ins->inputs();
auto add_args = add_ins->inputs();
move_standard_front(add_args);
move_broadcasted_back(add_args);
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
// Use the allocation from the clip operator
add_args.pop_back();
add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_clip{op}, args);
p.replace_instruction(ins, hip_add_clip{}, add_args);
else if(add_ins->name() == "hip::triadd")
p.replace_instruction(ins, hip_triadd_clip{op}, args);
p.replace_instruction(ins, hip_triadd_clip{}, add_args);
}
};
......
......@@ -15,16 +15,16 @@ void add_clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
float max,
float min);
const argument& min_arg,
const argument& max_arg);
void add_clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3,
float max,
float min);
const argument& min_arg,
const argument& max_arg);
} // namespace device
} // namespace gpu
......
......@@ -10,7 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void clip(hipStream_t stream, const argument& result, const argument& arg1, float max, float min);
void clip(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& min_val,
const argument& max_val);
} // namespace device
} // namespace gpu
......
......@@ -179,7 +179,7 @@ struct tf_parser
add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0});
// add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
......@@ -210,6 +210,7 @@ struct tf_parser
add_mem_op("OneHot", &tf_parser::parse_onehot, false);
add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Relu6", &tf_parser::parse_relu6);
add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Shape", &tf_parser::parse_shape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
......@@ -771,6 +772,18 @@ struct tf_parser
return prog.add_instruction(op, l0);
}
instruction_ref
parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = prog.add_literal(0.0f);
auto max_val = prog.add_literal(6.0f);
min_val = prog.add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = prog.add_instruction(op::multibroadcast{input_lens}, max_val);
return prog.add_instruction(op::clip{}, args.front(), min_val, max_val);
}
instruction_ref
parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......
......@@ -2043,11 +2043,12 @@ TEST_CASE(clip_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}});
migraphx::op::clip op;
op.max_val = 6.0;
op.min_val = 0.0;
p.add_instruction(op, l);
auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l, min_val, max_val);
p.compile(migraphx::cpu::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
......
......@@ -529,8 +529,12 @@ struct test_acosh : verify_program<test_acosh>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{100.0f, 1.1f}, x);
auto x = p.add_parameter("x", s);
auto min_val = p.add_literal(1.1f);
auto max_val = p.add_literal(100.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, max_val);
auto cx = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::acosh{}, cx);
return p;
}
......@@ -542,8 +546,12 @@ struct test_atanh : verify_program<test_atanh>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = p.add_parameter("x", s);
auto cx = p.add_instruction(migraphx::op::clip{0.95f, -0.95f}, x);
auto x = p.add_parameter("x", s);
auto min_val = p.add_literal(-0.95);
auto max_val = p.add_literal(0.95);
min_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, max_val);
auto cx = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::atanh{}, cx);
return p;
}
......@@ -931,6 +939,7 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
migraphx::program create_program() const
{
migraphx::program p;
std::vector<size_t> input_lens{4, 3, 3, 3};
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
......@@ -942,7 +951,11 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto bcast_add =
p.add_instruction(migraphx::op::broadcast{1, conv->get_shape().lens()}, bias);
auto bias_add = p.add_instruction(migraphx::op::add{}, conv, bcast_add);
p.add_instruction(migraphx::op::clip{6.0f, 0.0f}, bias_add);
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
p.add_instruction(migraphx::op::clip{}, bias_add, min_val, max_val);
return p;
}
};
......@@ -1928,8 +1941,12 @@ struct test_clip : verify_program<test_clip>
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, x);
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
return p;
}
};
......@@ -4387,9 +4404,14 @@ struct test_rsqrt : verify_program<test_rsqrt>
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
auto x = p.add_parameter("x", s);
auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, x);
std::vector<size_t> input_lens{1, 3, 16, 16};
migraphx::shape s{migraphx::shape::float_type, input_lens};
auto x = p.add_parameter("x", s);
auto min_val = p.add_literal(1.0f);
auto max_val = p.add_literal(std::numeric_limits<float>::max());
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
auto l0 = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::rsqrt{}, l0);
return p;
};
......
No preview for this file type
clip_test_op11_no_args:H
01"Clipclip_test_op11_no_argsZ
0

b
1

B
\ No newline at end of file
......@@ -268,6 +268,43 @@ def clip_test():
return ([node], [x], [y])
@onnx_test
def clip_test_op11():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])
max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [6.0])
node = onnx.helper.make_node('Clip',
inputs=['0', 'min', 'max'],
outputs=['1'])
return ([node], [x], [y], [min_val, max_val])
@onnx_test
def clip_test_op11_min_only():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])
node = onnx.helper.make_node('Clip', inputs=['0', 'min'], outputs=['1'])
return ([node], [x], [y], [min_val])
@onnx_test
def clip_test_op11_no_args():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
node = onnx.helper.make_node('Clip', inputs=['0'], outputs=['1'])
return ([node], [x], [y])
@onnx_test
def concat_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3])
......
......@@ -222,13 +222,53 @@ TEST_CASE(ceil_test)
TEST_CASE(clip_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_onnx("clip_test_op11.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11_min_only)
{
migraphx::program p;
auto min_val = p.add_literal(0.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
p.add_instruction(migraphx::op::max{}, l0, min_val);
auto prog = optimize_onnx("clip_test_op11_min_only.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11_no_args)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_onnx("clip_test_op11_no_args.onnx");
EXPECT(p == prog);
}
TEST_CASE(concat_test)
{
migraphx::program p;
......
......@@ -14,6 +14,31 @@
#include "test.hpp"
#include <migraphx/half.hpp>
migraphx::instruction_ref
create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction_ref input)
{
auto input_lens = input->get_shape().lens();
auto max_val = p.add_literal(max);
auto min_val = p.add_literal(min);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
return p.add_instruction(migraphx::op::clip{}, input, min_val, max_val);
}
migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
migraphx::program& p,
float max,
float min,
migraphx::instruction_ref input)
{
auto input_lens = input->get_shape().lens();
auto max_val = p.add_literal(max);
auto min_val = p.add_literal(min);
max_val = p.insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, max_val);
min_val = p.insert_instruction(insert_loc, migraphx::op::multibroadcast{input_lens}, min_val);
return p.insert_instruction(insert_loc, migraphx::op::clip{}, input, min_val, max_val);
}
TEST_CASE(param_add)
{
auto create_program_float = [](bool add_return = false) {
......@@ -308,7 +333,7 @@ TEST_CASE(dot_float)
auto fa = p.add_literal(migraphx::literal(sa, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa);
auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
......@@ -317,7 +342,7 @@ TEST_CASE(dot_float)
auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -372,7 +397,7 @@ TEST_CASE(dot_double_2args)
auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa);
auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
......@@ -383,7 +408,7 @@ TEST_CASE(dot_double_2args)
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -438,7 +463,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto sfta = p.add_literal(migraphx::literal(sa, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
......@@ -447,7 +472,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -505,7 +530,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
......@@ -516,7 +541,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -557,7 +582,7 @@ TEST_CASE(dot_int32_one_arg)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, fpa);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
auto q_dot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qa);
......@@ -617,7 +642,7 @@ TEST_CASE(dot_int32)
auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto ca = create_clip_op(p, 127.0f, -128.0f, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
......@@ -628,7 +653,7 @@ TEST_CASE(dot_int32)
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -692,7 +717,7 @@ TEST_CASE(dot_float_convert)
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto cb = create_clip_op(insert_loc, p, 127.0f, -128.0f, rb);
auto qb =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
......@@ -738,7 +763,7 @@ TEST_CASE(conv_float)
auto fx = p.add_literal(migraphx::literal(sx, vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, px);
auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type
......@@ -747,7 +772,7 @@ TEST_CASE(conv_float)
auto fw = p.add_literal(migraphx::literal(sw, vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, pw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
......@@ -793,7 +818,7 @@ TEST_CASE(conv_int32)
auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx);
auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type
......@@ -804,7 +829,7 @@ TEST_CASE(conv_int32)
auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
......@@ -849,7 +874,7 @@ TEST_CASE(conv_half)
auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx));
auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx);
auto rx = p.add_instruction(migraphx::op::round{}, mx);
auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx);
auto cx = create_clip_op(p, 127.0f, -128.0f, rx);
auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx);
// quantize parameter b to int8 type
......@@ -860,7 +885,7 @@ TEST_CASE(conv_half)
auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw));
auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw);
auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw);
auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw);
auto cw = create_clip_op(insert_loc, p, 127.0f, -128.0f, rw);
auto qw =
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw);
......
......@@ -443,8 +443,13 @@ TEST_CASE(relu_test)
TEST_CASE(relu6_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
std::vector<size_t> input_lens{1, 3, 16, 16};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment