Commit bfdba340 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the attribute s in the scalar operator from shape type to size_t vector type.

parent da77ce0b
...@@ -18,7 +18,14 @@ namespace op { ...@@ -18,7 +18,14 @@ namespace op {
struct scalar struct scalar
{ {
shape scalar_bcast; std::vector<std::size_t> scalar_bcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.scalar_bcast_lens, "scalar_bcst_dims"));
}
std::string name() const { return "scalar"; } std::string name() const { return "scalar"; }
...@@ -26,8 +33,8 @@ struct scalar ...@@ -26,8 +33,8 @@ struct scalar
{ {
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0); std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast_lens, strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -671,16 +671,16 @@ struct onnx_parser ...@@ -671,16 +671,16 @@ struct onnx_parser
auto&& bias_floats = attributes["bias"].floats(); auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end()); bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
} }
auto input_shape = args.front()->get_shape(); auto input_lens = args.front()->get_shape().lens();
auto scale_val = prog.add_literal(scale); auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal( auto bias_vals = prog.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val); auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = auto bias_bcast =
prog.add_instruction(migraphx::op::broadcast{1, input_shape.lens()}, bias_vals); prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
} }
......
...@@ -809,7 +809,7 @@ TEST_CASE(imagescaler_test) ...@@ -809,7 +809,7 @@ TEST_CASE(imagescaler_test)
0.35, 0.35,
0.45}}); 0.45}});
auto scale_val = p.add_literal(2.f); auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
......
...@@ -371,7 +371,7 @@ struct test_scale : verify_program<test_scale> ...@@ -371,7 +371,7 @@ struct test_scale : verify_program<test_scale>
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraphx::shape::float_type); auto y = p.add_parameter("y", migraphx::shape::float_type);
auto scale = p.add_instruction(migraphx::op::scalar{s}, y); auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y);
p.add_instruction(migraphx::op::mul{}, x, scale); p.add_instruction(migraphx::op::mul{}, x, scale);
return p; return p;
} }
......
...@@ -108,7 +108,7 @@ TEST_CASE(imagescaler_test) ...@@ -108,7 +108,7 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f); auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment