Commit 89dfaaff authored by Khalique's avatar Khalique
Browse files

added parse onnx for imagescaler

parent 57444235
...@@ -306,10 +306,6 @@ struct contiguous ...@@ -306,10 +306,6 @@ struct contiguous
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
// if(lens.size() < 2)
// {
// MIGRAPH_THROW("Number of dimensions should exceed 1");
// }
return {t, lens}; return {t, lens};
} }
}; };
......
...@@ -56,7 +56,7 @@ struct onnx_parser ...@@ -56,7 +56,7 @@ struct onnx_parser
add_generic_op("Sub", op::sub{}); add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{}); add_generic_op("Sum", op::add{});
// add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
...@@ -325,10 +325,33 @@ struct onnx_parser ...@@ -325,10 +325,33 @@ struct onnx_parser
return prog.add_instruction(op, args.front()); return prog.add_instruction(op, args.front());
} }
// instruction_ref parse_imagescaler(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) instruction_ref parse_imagescaler(const std::string&,
// { attribute_map attributes,
std::vector<instruction_ref> args)
{
float scale = 1.0;
std::vector<float> bias{};
if(contains(attributes, "scale"))
{
scale = parse_value(attributes.at("scale")).at<float>();
}
if(contains(attributes, "bias"))
{
auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
}
auto input_shape = args.front()->get_shape();
// } auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal(
migraph::literal{migraph::shape{migraph::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraph::op::scalar{input_shape}, scale_val);
auto img_scaled = prog.add_instruction(migraph::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraph::op::broadcast{1, input_shape}, bias_vals);
return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
......
...@@ -5,9 +5,18 @@ namespace migraph { ...@@ -5,9 +5,18 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
void mul(const argument& result, const argument& arg1, const argument& arg2) void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(result, arg1, arg2)([](auto x, auto y) { return x * y; }); nary(stream, result, arg1, arg2)([](auto x, auto y) { return x * y; });
}
void mul(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x * y * z; });
} }
} // namespace device } // namespace device
......
...@@ -3,12 +3,19 @@ ...@@ -3,12 +3,19 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_MUL_HPP #define MIGRAPH_GUARD_RTGLIB_DEVICE_MUL_HPP
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
void mul(const argument& result, const argument& arg1, const argument& arg2); void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
void mul(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -14,9 +14,9 @@ shape hip_mul::compute_shape(const std::vector<shape>& inputs) const ...@@ -14,9 +14,9 @@ shape hip_mul::compute_shape(const std::vector<shape>& inputs) const
return inputs.at(0); return inputs.at(0);
} }
argument hip_mul::compute(context&, const shape&, const std::vector<argument>& args) const argument hip_mul::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::mul(args[2], args[0], args[1]); device::mul(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2]; return args[2];
} }
......
...@@ -100,6 +100,23 @@ void leaky_relu_test() ...@@ -100,6 +100,23 @@ void leaky_relu_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void imagescaler_test()
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 16, 16}};
auto l0 = p.add_parameter("0", s);
auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal(migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraph::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
auto prog = migraph::parse_onnx("imagescaler_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -107,4 +124,5 @@ int main() ...@@ -107,4 +124,5 @@ int main()
pytorch_conv_bn_relu_maxpool(); pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2(); pytorch_conv_relu_maxpool_x2();
leaky_relu_test(); leaky_relu_test();
imagescaler_test();
} }
...@@ -93,7 +93,7 @@ void contiguous_shape() ...@@ -93,7 +93,7 @@ void contiguous_shape()
throws_shape(migraph::op::contiguous{}, input, input); throws_shape(migraph::op::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}}; migraph::shape single{migraph::shape::float_type, {2}};
throws_shape(migraph::op::contiguous{}, single); expect_shape(single, migraph::op::contiguous{}, single);
} }
void reshape_shape() void reshape_shape()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment