Commit 970ac115 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 82fee1e7
...@@ -37,12 +37,12 @@ struct binary : op_name<Derived> ...@@ -37,12 +37,12 @@ struct binary : op_name<Derived>
auto input2 = make_view(std_shape, args[1].data()); auto input2 = make_view(std_shape, args[1].data());
auto output = make_view(std_shape, result.data()); auto output = make_view(std_shape, result.data());
std::transform(input1.begin(), std::transform(input1.begin(),
input1.end(), input1.end(),
input2.begin(), input2.begin(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
} }
else else
{ {
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
......
...@@ -26,9 +26,8 @@ struct convert : unary<convert> ...@@ -26,9 +26,8 @@ struct convert : unary<convert>
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.target_type, "target_type"), return pack(
f(self.scale, "scale"), f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
f(self.shift, "shift"));
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -28,11 +28,11 @@ struct unary : op_name<Derived> ...@@ -28,11 +28,11 @@ struct unary : op_name<Derived>
{ {
argument result{output_shape}; argument result{output_shape};
auto in_shape = args[0].get_shape(); auto in_shape = args[0].get_shape();
if (in_shape.packed()) if(in_shape.packed())
{ {
shape std_in_shape{in_shape.type(), in_shape.lens()}; shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()}; shape std_out_shape{output_shape.type(), output_shape.lens()};
auto input = make_view(std_in_shape, args[0].cast()); auto input = make_view(std_in_shape, args[0].cast());
auto output = make_view(std_out_shape, result.cast()); auto output = make_view(std_out_shape, result.cast());
std::transform(input.begin(), std::transform(input.begin(),
input.end(), input.end(),
...@@ -44,8 +44,8 @@ struct unary : op_name<Derived> ...@@ -44,8 +44,8 @@ struct unary : op_name<Derived>
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end())); input(idx.begin(), idx.end()));
}); });
return result; return result;
......
...@@ -197,18 +197,18 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -197,18 +197,18 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
continue; continue;
} }
auto op = ins->get_operator(); auto op = ins->get_operator();
shape ins_shape{}; shape ins_shape{};
// just to compute the output shape // just to compute the output shape
if (ins->name() == "dot") if(ins->name() == "dot")
{ {
ins_shape = compute_shape(op::quant_dot{}, converted_inputs); ins_shape = compute_shape(op::quant_dot{}, converted_inputs);
} }
else else
{ {
ins_shape = compute_shape(op::quant_convolution{}, converted_inputs); ins_shape = compute_shape(op::quant_convolution{}, converted_inputs);
} }
if(ins_shape.type() != orig_type) if(ins_shape.type() != orig_type)
{ {
// check the dead code case to avoid assert // check the dead code case to avoid assert
...@@ -253,7 +253,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -253,7 +253,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
auto conv_s = conv_res->get_shape(); auto conv_s = conv_res->get_shape();
std::vector<float> vec_fact(conv_s.elements(), adjust_factor); std::vector<float> vec_fact(conv_s.elements(), adjust_factor);
auto fl = prog.add_literal(literal{conv_s, vec_fact}); auto fl = prog.add_literal(literal{conv_s, vec_fact});
auto ad_res = prog.insert_instruction(ins, op::mul{}, conv_res, fl); auto ad_res = prog.insert_instruction(ins, op::mul{}, conv_res, fl);
prog.replace_instruction(ins, ad_res); prog.replace_instruction(ins, ad_res);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment