Commit 82fee1e7 authored by Shucai Xiao

temp code backup

parent 420d2363
@@ -28,9 +28,14 @@ struct binary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+        auto s1 = args[0].get_shape();
+        auto s2 = args[1].get_shape();
+        if(s1 == s2 and s1.packed())
         {
+            shape std_shape{s1.type(), s1.lens()};
+            auto input1 = make_view(std_shape, args[0].data());
+            auto input2 = make_view(std_shape, args[1].data());
+            auto output = make_view(std_shape, result.data());
             std::transform(input1.begin(),
                            input1.end(),
                            input2.begin(),
@@ -39,12 +44,13 @@ struct binary : op_name<Derived>
         }
         else
         {
+            visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
                 shape_for_each(output.get_shape(), [&](const auto& idx) {
                     output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
                         input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
                 });
-            }
-        });
+            });
+        }
         return result;
     }
...
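Note: the packed fast path above depends on make_view (not shown in this diff) yielding flat, iterable views over the raw buffers. As a rough standalone sketch of the idea, not MIGraphX code: when both inputs are packed and share a shape, an elementwise binary op collapses to one 1-D std::transform instead of the per-index shape_for_each walk.

// Standalone sketch of the packed fast path; plain std::vector stands in for the
// MIGraphX buffers, and std::plus stands in for Derived::apply().
#include <algorithm>
#include <functional>
#include <vector>

std::vector<float> add_packed(const std::vector<float>& a, const std::vector<float>& b)
{
    std::vector<float> out(a.size());
    // One flat pass over contiguous memory; no multi-dimensional index math needed.
    std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::plus<float>{});
    return out;
}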
@@ -26,7 +26,9 @@ struct convert : unary<convert>
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.target_type, "target_type"));
+        return pack(f(self.target_type, "target_type"),
+                    f(self.scale, "scale"),
+                    f(self.shift, "shift"));
     }

     shape compute_shape(std::vector<shape> inputs) const
...
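Note: registering scale and shift in reflect() only serializes the new fields; this hunk does not show where they are applied. One hypothetical reading (an assumption, not taken from this commit) is an affine step y = x * scale + shift performed before the cast to target_type, roughly:

// Hypothetical use of the new scale/shift fields; quantize_value and the clamping
// bounds are illustrative assumptions, not part of the convert operator shown above.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t quantize_value(float x, float scale, float shift)
{
    float y = x * scale + shift;                              // assumed affine rescale
    y = std::min(127.0f, std::max(-128.0f, std::round(y)));   // clamp to the int8 range
    return static_cast<int8_t>(y);
}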
@@ -27,18 +27,22 @@ struct unary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        result.visit([&](auto output) {
-            args[0].visit([&](auto input) {
-                if(input.get_shape().packed())
+        auto in_shape = args[0].get_shape();
+        if (in_shape.packed())
         {
+            shape std_in_shape{in_shape.type(), in_shape.lens()};
+            shape std_out_shape{output_shape.type(), output_shape.lens()};
+            auto input = make_view(std_in_shape, args[0].cast());
+            auto output = make_view(std_out_shape, result.cast());
             std::transform(input.begin(),
                            input.end(),
                            output.begin(),
                            static_cast<const Derived&>(*this).apply());
+            return result;
         }
         else
         {
+            result.visit([&](auto output) {
+                args[0].visit([&](auto input) {
                 shape_for_each(output.get_shape(), [&](const auto& idx) {
                     output(idx.begin(), idx.end()) =
                         static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
@@ -47,6 +51,7 @@ struct unary : op_name<Derived>
                 return result;
             });
         });
+        }
         return result;
     }
...
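Note: make_view and the argument-less cast() used in this hunk are not declared anywhere in this excerpt, so their exact signatures are unknown here. What the fast path actually requires of the view is only a non-owning range with begin()/end() over contiguous elements; a hypothetical stand-in might look like:

// Hypothetical stand-in for the view type the transform loop assumes; the names and
// signatures below are guesses, not the actual MIGraphX make_view declaration.
#include <cstddef>

template <class T>
struct flat_view
{
    T* data;
    std::size_t size;
    T* begin() const { return data; }
    T* end() const { return data + size; }
};

template <class T>
flat_view<T> make_flat_view(T* ptr, std::size_t n)
{
    return {ptr, n};
}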
@@ -4,6 +4,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/op/convert.hpp>
 #include <migraphx/op/dot.hpp>
+#include <migraphx/op/mul.hpp>
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/quant_convolution.hpp>
@@ -197,7 +198,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
         }
         auto op = ins->get_operator();
-        auto ins_shape = compute_shape(op, converted_inputs);
+        shape ins_shape{};
+        // just to compute the output shape
+        if (ins->name() == "dot")
+        {
+            ins_shape = compute_shape(op::quant_dot{}, converted_inputs);
+        }
+        else
+        {
+            ins_shape = compute_shape(op::quant_convolution{}, converted_inputs);
+        }
         if(ins_shape.type() != orig_type)
         {
             // check the dead code case to avoid assert
@@ -239,17 +250,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
                 ins,
                 op::quant_convolution{padding, stride, dilation, padding_mode, group},
                 converted_inputs);
-            auto conv_lens = conv_res->get_shape().lens();
-            auto fl = prog.add_literal(literal(adjust_factor));
-            auto adj_fact = prog.insert_instruction(ins, op::multibroadcast{conv_lens}, fl);
-            prog.replace_instruction(ins, adj_fact);
+            auto conv_s = conv_res->get_shape();
+            std::vector<float> vec_fact(conv_s.elements(), adjust_factor);
+            auto fl = prog.add_literal(literal{conv_s, vec_fact});
+            auto ad_res = prog.insert_instruction(ins, op::mul{}, conv_res, fl);
+            prog.replace_instruction(ins, ad_res);
         }
         else
         {
            MIGRAPHX_THROW("INT8_QUANTIZE: does not support operator" + ins->name());
        }
-        prog.replace_instruction(ins, op, converted_inputs);
     }
 }
...
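Note: in the convolution branch the scalar adjust_factor is no longer multibroadcast over the output; instead a literal with the full output shape is filled with adjust_factor and combined with the convolution result through op::mul. Numerically both forms are the same elementwise scaling, sketched standalone below with plain vectors instead of MIGraphX instructions:

// Standalone sketch of the adjustment step; apply_adjust_factor is an illustrative
// helper, not a function from this commit.
#include <cstddef>
#include <vector>

std::vector<float> apply_adjust_factor(const std::vector<float>& conv_out, float adjust_factor)
{
    // Equivalent to mul(conv_out, literal of the same shape filled with adjust_factor).
    std::vector<float> scaled(conv_out.size());
    for(std::size_t i = 0; i < conv_out.size(); ++i)
        scaled[i] = conv_out[i] * adjust_factor;
    return scaled;
}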