Commit 82fee1e7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

temp code backup

parent 420d2363
......@@ -28,23 +28,29 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
    argument result{output_shape};
    // Shapes of the two inputs; the fast path below requires them to be
    // identical and packed so a flat, linear element walk is valid.
    auto s1 = args[0].get_shape();
    auto s2 = args[1].get_shape();
    if(s1 == s2 and s1.packed())
    {
        // Fast path: view both input buffers and the output through a
        // standard (contiguous) shape and apply the derived operator's
        // functor element-by-element in one pass.
        shape std_shape{s1.type(), s1.lens()};
        auto input1 = make_view(std_shape, args[0].data());
        auto input2 = make_view(std_shape, args[1].data());
        auto output = make_view(std_shape, result.data());
        std::transform(input1.begin(),
                       input1.end(),
                       input2.begin(),
                       output.begin(),
                       static_cast<const Derived&>(*this).apply());
    }
    else
    {
        // General path: iterate every logical index so broadcast/strided
        // inputs are read through their own shapes' index mapping.
        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
            shape_for_each(output.get_shape(), [&](const auto& idx) {
                output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
                    input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
            });
        });
    }
    return result;
}
......
......@@ -26,7 +26,9 @@ struct convert : unary<convert>
/// Reflection hook: exposes this operator's serializable fields
/// (target_type plus the new scale/shift parameters) through the
/// visitor `f`, packed for migraphx's reflection machinery.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
    return pack(f(self.target_type, "target_type"),
                f(self.scale, "scale"),
                f(self.shift, "shift"));
}
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -27,26 +27,31 @@ struct unary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
    argument result{output_shape};
    auto in_shape = args[0].get_shape();
    if(in_shape.packed())
    {
        // Fast path: a packed input lets us treat input and output as flat,
        // standard-shaped buffers and transform them in one linear sweep.
        shape std_in_shape{in_shape.type(), in_shape.lens()};
        shape std_out_shape{output_shape.type(), output_shape.lens()};
        // NOTE(review): `.cast()` with no template argument is kept as the
        // diff wrote it -- confirm the intended argument accessor (the
        // binary fast path uses `.data()` instead).
        auto input  = make_view(std_in_shape, args[0].cast());
        auto output = make_view(std_out_shape, result.cast());
        std::transform(input.begin(),
                       input.end(),
                       output.begin(),
                       static_cast<const Derived&>(*this).apply());
    }
    else
    {
        // General path: per-index access honours arbitrary strides in the
        // input, applying the derived operator at each logical position.
        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                shape_for_each(output.get_shape(), [&](const auto& idx) {
                    output(idx.begin(), idx.end()) =
                        static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
                });
            });
        });
    }
    return result;
}
......
......@@ -4,6 +4,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
......@@ -197,7 +198,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
shape ins_shape{};
// just to compute the output shape
if (ins->name() == "dot")
{
ins_shape = compute_shape(op::quant_dot{}, converted_inputs);
}
else
{
ins_shape = compute_shape(op::quant_convolution{}, converted_inputs);
}
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert
......@@ -239,17 +250,17 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
auto conv_lens = conv_res->get_shape().lens();
auto fl = prog.add_literal(literal(adjust_factor));
auto adj_fact = prog.insert_instruction(ins, op::multibroadcast{conv_lens}, fl);
prog.replace_instruction(ins, adj_fact);
auto conv_s = conv_res->get_shape();
std::vector<float> vec_fact(conv_s.elements(), adjust_factor);
auto fl = prog.add_literal(literal{conv_s, vec_fact});
auto ad_res = prog.insert_instruction(ins, op::mul{}, conv_res, fl);
prog.replace_instruction(ins, ad_res);
}
else
{
MIGRAPHX_THROW("INT8_QUANTIZE: does not support operator" + ins->name());
}
prog.replace_instruction(ins, op, converted_inputs);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment