Commit bfdba340 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the attribute s in the scalar operator from shape type to size_t vector type.

parent da77ce0b
......@@ -18,7 +18,14 @@ namespace op {
struct scalar
{
shape scalar_bcast;
std::vector<std::size_t> scalar_bcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.scalar_bcast_lens, "scalar_bcst_dims"));
}
std::string name() const { return "scalar"; }
......@@ -26,8 +33,8 @@ struct scalar
{
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
return {t, scalar_bcast.lens(), strides};
std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast_lens, strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -671,16 +671,16 @@ struct onnx_parser
auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
}
auto input_shape = args.front()->get_shape();
auto input_lens = args.front()->get_shape().lens();
auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast =
prog.add_instruction(migraphx::op::broadcast{1, input_shape.lens()}, bias_vals);
prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
}
......
......@@ -809,7 +809,7 @@ TEST_CASE(imagescaler_test)
0.35,
0.45}});
auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
......
......@@ -371,7 +371,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y);
auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale);
return p;
}
......
......@@ -108,7 +108,7 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment