Commit 3a848f0d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
...@@ -87,13 +87,15 @@ struct program ...@@ -87,13 +87,15 @@ struct program
instruction_ref add_parameter(std::string name, shape s); instruction_ref add_parameter(std::string name, shape s);
instruction_ref add_return(std::vector<instruction_ref> args);
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
instruction_ref get_parameter(std::string name) const; instruction_ref get_parameter(std::string name) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const; std::unordered_map<std::string, shape> get_parameter_shapes() const;
argument eval(parameter_map params) const; std::vector<argument> eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
...@@ -101,7 +103,7 @@ struct program ...@@ -101,7 +103,7 @@ struct program
instruction_ref begin() const; instruction_ref begin() const;
instruction_ref end() const; instruction_ref end() const;
shape get_shape() const; std::vector<shape> get_output_shapes() const;
context& get_context() const; context& get_context() const;
......
...@@ -69,11 +69,17 @@ struct schedule_model ...@@ -69,11 +69,17 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
schedule_model& operator=(PrivateDetailTypeErasedT value) schedule_model& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) using std::swap;
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
else if(!private_detail_te_handle_mem_var) if(derived and private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>( {
std::forward<PrivateDetailTypeErasedT>(value)); *derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
schedule_model rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this; return *this;
} }
...@@ -81,7 +87,7 @@ struct schedule_model ...@@ -81,7 +87,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast() PrivateDetailTypeErasedT* any_cast()
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type< ? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -92,7 +98,7 @@ struct schedule_model ...@@ -92,7 +98,7 @@ struct schedule_model
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type< ? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
......
...@@ -115,11 +115,17 @@ struct target ...@@ -115,11 +115,17 @@ struct target
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
target& operator=(PrivateDetailTypeErasedT value) target& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) using std::swap;
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
else if(!private_detail_te_handle_mem_var) if(derived and private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>( {
std::forward<PrivateDetailTypeErasedT>(value)); *derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
target rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this; return *this;
} }
...@@ -127,7 +133,7 @@ struct target ...@@ -127,7 +133,7 @@ struct target
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast() PrivateDetailTypeErasedT* any_cast()
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type< ? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -138,7 +144,7 @@ struct target ...@@ -138,7 +144,7 @@ struct target
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type< ? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
......
...@@ -7,8 +7,15 @@ ...@@ -7,8 +7,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in tf options to parser
struct tf_options
{
bool is_nhwc = false;
unsigned int batch_size = 1;
};
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, bool is_nhwc); program parse_tf(const std::string& name, tf_options = tf_options{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -22,6 +22,9 @@ void instruction::replace(const shape& r) ...@@ -22,6 +22,9 @@ void instruction::replace(const shape& r)
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return")
continue;
assert(ins->name().front() != '@'); assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape();
} }
...@@ -70,6 +73,10 @@ bool instruction::valid() const ...@@ -70,6 +73,10 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@return")
{
computed = {};
}
else else
{ {
try try
...@@ -81,6 +88,7 @@ bool instruction::valid() const ...@@ -81,6 +88,7 @@ bool instruction::valid() const
return false; return false;
} }
} }
return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) { return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end(); return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
}); });
......
...@@ -73,8 +73,9 @@ int main(int argc, char const* argv[]) ...@@ -73,8 +73,9 @@ int main(int argc, char const* argv[])
for(int i = 0; i < 10; i++) for(int i = 0; i < 10; i++)
{ {
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> "; std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]}); m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m)); auto gpu_result = prog.eval(m).back();
auto result = migraphx::gpu::from_gpu(gpu_result);
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits); std::vector<float> probs = softmax<float>(logits);
...@@ -95,7 +96,7 @@ int main(int argc, char const* argv[]) ...@@ -95,7 +96,7 @@ int main(int argc, char const* argv[])
{ {
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> "; std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
auto input3 = migraphx::argument{s, &ptr[3072 * i]}; auto input3 = migraphx::argument{s, &ptr[3072 * i]};
auto result = prog.eval({{"0", input3}}); auto result = prog.eval({{"0", input3}}).back();
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits); std::vector<float> probs = softmax<float>(logits);
......
...@@ -130,8 +130,9 @@ int main(int argc, char const* argv[]) ...@@ -130,8 +130,9 @@ int main(int argc, char const* argv[])
for(int i = 0; i < 20; i++) for(int i = 0; i < 20; i++)
{ {
std::cout << "label: " << labels[i] << " ----> "; std::cout << "label: " << labels[i] << " ----> ";
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]}); m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m)); auto results = prog.eval(m).back();
auto result = migraphx::gpu::from_gpu(results);
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits); std::vector<float> probs = softmax(logits);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -28,87 +29,102 @@ struct onnx_parser ...@@ -28,87 +29,102 @@ struct onnx_parser
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>; std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
bool is_pytorch = false; bool is_pytorch = false;
unsigned int batch_size = 1;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs; std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser() onnx_parser()
{ {
add_generic_op("Relu", op::relu{}); // sort onnx operator alphabetically through name
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{}); add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{}); add_generic_op("Acos", op::acos{});
add_generic_op("Acosh", op::acosh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Asinh", op::asinh{});
add_generic_op("Atan", op::atan{});
add_generic_op("Atanh", op::atanh{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Cos", op::cos{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Erf", op::erf{}); add_generic_op("Erf", op::erf{});
add_generic_op("Log", op::log{}); add_generic_op("Exp", op::exp{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
add_generic_op("Log", op::log{});
add_generic_op("Floor", op::floor{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Round", op::round{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Sign", op::sign{});
add_generic_op("Sin", op::sin{}); add_generic_op("Sin", op::sin{});
add_generic_op("Cos", op::cos{});
add_generic_op("Tan", op::tan{});
add_generic_op("Sinh", op::sinh{}); add_generic_op("Sinh", op::sinh{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{}); add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{}); add_generic_op("Tan", op::tan{});
add_generic_op("Sign", op::sign{}); add_generic_op("Tanh", op::tanh{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Floor", op::floor{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{}); add_binary_op("Pow", op::pow{});
add_binary_op("PRelu", op::prelu{});
add_binary_op("Sub", op::sub{});
add_variadic_op("Sum", op::add{}); add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>); add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>); add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Cast", &onnx_parser::parse_cast); add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
add_mem_op("Elu", &onnx_parser::parse_elu); add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand); add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten); add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul); add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>); add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>); add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
add_mem_op("Concat", &onnx_parser::parse_concat); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Gather", &onnx_parser::parse_gather); add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("Shape", &onnx_parser::parse_shape); add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape); add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm); add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
// init the activation function map // init the activation function map
init_actv_func(); init_actv_func();
...@@ -230,8 +246,15 @@ struct onnx_parser ...@@ -230,8 +246,15 @@ struct onnx_parser
auto s0 = arg0->get_shape().lens(); auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens(); auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1); auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1); auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1); return prog.add_instruction(x, l0, l1);
} }
else else
...@@ -261,6 +284,43 @@ struct onnx_parser ...@@ -261,6 +284,43 @@ struct onnx_parser
}); });
} }
template <class T>
std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
{
std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
return output_vector;
}
instruction_ref
add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
{
if(args.size() == 3)
{
auto bias_bcast =
prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
}
return curr_ins;
}
template <class Op>
void check_asym_padding(instruction_ref& ins,
std::vector<int64_t>& padding,
Op& op,
float pad_val = 0)
{
if(padding[0] != padding[2] || padding[1] != padding[3])
{
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
ins = prog.add_instruction(op::pad{padding, pad_val}, ins);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
instruction_ref parse_clip(const std::string&, instruction_ref parse_clip(const std::string&,
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
...@@ -282,7 +342,7 @@ struct onnx_parser ...@@ -282,7 +342,7 @@ struct onnx_parser
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
int axis = 1; int64_t axis = 1;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
...@@ -319,11 +379,72 @@ struct onnx_parser ...@@ -319,11 +379,72 @@ struct onnx_parser
} }
} }
template <class Op>
instruction_ref process_auto_pad_attribute(instruction_ref ins,
attribute_map& attributes,
Op& op,
const std::vector<std::size_t>& in_lens)
{
if(!contains(attributes, "auto_pad"))
{
return ins;
}
auto auto_pad = attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
// calculate the padding
std::array<std::size_t, 2> out_lens;
out_lens[0] = (in_lens[2] + op.stride[0] - 1) / op.stride[0];
out_lens[1] = (in_lens[3] + op.stride[1] - 1) / op.stride[1];
std::array<std::size_t, 2> explicit_pads;
explicit_pads[0] = (out_lens[0] - 1) * op.stride[0] + op.lengths[0] - in_lens[2];
explicit_pads[1] = (out_lens[1] - 1) * op.stride[1] + op.lengths[1] - in_lens[3];
op.padding[0] = explicit_pads[0] / 2;
op.padding[1] = explicit_pads[1] / 2;
explicit_pads[0] -= 2 * op.padding[0];
explicit_pads[1] -= 2 * op.padding[1];
std::vector<std::int64_t> pads(8, 0);
if(explicit_pads[0] != 0 or explicit_pads[1] != 0)
{
if(auto_pad == "SAME_UPPER")
{
pads[6] = explicit_pads[0];
pads[7] = explicit_pads[1];
}
else if(auto_pad == "SAME_LOWER")
{
pads[2] = explicit_pads[0];
pads[3] = explicit_pads[1];
}
// MaxPool
if(op.mode == "max")
{
ins = prog.add_instruction(op::pad{pads, std::numeric_limits<float>::lowest()},
ins);
}
// AveragePool
else
{
ins = prog.add_instruction(op::pad{pads}, ins);
}
}
op.padding_mode = op::padding_mode_t::same;
}
return ins;
}
template <class Op>
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
op::convolution op; Op op;
auto l0 = args[0]; auto l0 = args[0];
auto weights = args[1];
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
if(contains(attributes, "auto_pad")) if(contains(attributes, "auto_pad"))
...@@ -340,11 +461,76 @@ struct onnx_parser ...@@ -340,11 +461,76 @@ struct onnx_parser
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("padding should have 4 values");
} }
check_asym_padding(l0, padding, op);
}
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
std::vector<int64_t> padding(input_dims.size());
calculate_padding(
0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(
1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);
check_asym_padding(l0, padding, op);
}
}
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
auto l1 = prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
}
instruction_ref parse_conv_transpose(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::deconvolution op;
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asymm_padding = false;
if(contains(attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
}
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3]) if(padding[0] != padding[2] || padding[1] != padding[3])
{ {
// insert zeros for pad op (args[0] has 4 dims) asymm_padding = true;
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
} }
else else
{ {
...@@ -373,18 +559,55 @@ struct onnx_parser ...@@ -373,18 +559,55 @@ struct onnx_parser
op.padding_mode = op::padding_mode_t::same; op.padding_mode = op::padding_mode_t::same;
} }
} }
if(contains(attributes, "group")) if(contains(attributes, "group"))
{ {
op.group = parse_value(attributes.at("group")).at<int>(); op.group = parse_value(attributes.at("group")).at<int>();
} }
if(args.size() == 3)
auto l1 = prog.add_instruction(op, l0, args[1]);
std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
std::vector<int64_t> curr_shape{dims[2], dims[3]};
if(asymm_padding)
{ {
uint64_t axis = 1; op::slice slice_op;
auto l1 = prog.add_instruction(op, l0, args[1]); slice_op.axes = {0, 1, 2, 3};
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]); slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
return prog.add_instruction(op::add{}, l1, l2); slice_op.ends = {
dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};
l1 = prog.add_instruction(slice_op, l1);
}
if(contains(attributes, "output_padding"))
{
std::vector<int64_t> output_padding;
copy(attributes["output_padding"].ints(), std::back_inserter(output_padding));
output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
l1 = prog.add_instruction(op::pad{output_padding}, l1);
}
if(contains(attributes, "output_shape"))
{
std::vector<int64_t> output_shape;
copy(attributes["output_shape"].ints(), std::back_inserter(output_shape));
dims = to_int64_vector(l1->get_shape().lens());
curr_shape = {dims[2], dims[3]};
if(curr_shape != output_shape)
{
std::vector<int64_t> target_padding = {0,
0,
0,
0,
0,
0,
output_shape[0] - curr_shape[0],
output_shape[1] - curr_shape[1]};
l1 = prog.add_instruction(op::pad{target_padding}, l1);
}
} }
return prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref parse_pooling(const std::string& name,
...@@ -398,27 +621,31 @@ struct onnx_parser ...@@ -398,27 +621,31 @@ struct onnx_parser
auto lens = args.front()->get_shape().lens(); auto lens = args.front()->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
} }
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW(
"PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
}
}
std::vector<std::int64_t> padding; std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding)); copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
} }
float pad_val = 0;
if(op.mode == "max")
pad_val = std::numeric_limits<float>::lowest();
check_asym_padding(l0, padding, op, pad_val);
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
copy(attributes["strides"].ints(), op.stride.begin()); copy(attributes["strides"].ints(), op.stride.begin());
...@@ -427,14 +654,11 @@ struct onnx_parser ...@@ -427,14 +654,11 @@ struct onnx_parser
{ {
copy(attributes["kernel_shape"].ints(), op.lengths.begin()); copy(attributes["kernel_shape"].ints(), op.lengths.begin());
} }
if(contains(attributes, "auto_pad")) if(contains(attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto in_lens = args[0]->get_shape().lens();
if(s.find("SAME_UPPER") == std::string::npos) l0 = process_auto_pad_attribute(l0, attributes, op, in_lens);
{
MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
}
op.padding_mode = op::padding_mode_t::same;
} }
return prog.add_instruction(op, l0); return prog.add_instruction(op, l0);
...@@ -462,7 +686,7 @@ struct onnx_parser ...@@ -462,7 +686,7 @@ struct onnx_parser
instruction_ref instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
uint64_t axis = 1; int64_t axis = 1;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
...@@ -616,6 +840,7 @@ struct onnx_parser ...@@ -616,6 +840,7 @@ struct onnx_parser
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
template <class Op>
instruction_ref instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
...@@ -664,7 +889,7 @@ struct onnx_parser ...@@ -664,7 +889,7 @@ struct onnx_parser
} }
} }
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1); auto dot_res = prog.add_instruction(Op{1, 0}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended) if(is_a_prepended)
{ {
...@@ -703,6 +928,42 @@ struct onnx_parser ...@@ -703,6 +928,42 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref parse_instancenorm(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({H, W}, x)
// variance = reduce_mean({H, W}, (x - mean)^2)
float epsilon = 1e-5f;
if(contains(attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
}
auto x = args[0];
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
auto mean = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
auto mean_bcast = prog.add_instruction(op::multibroadcast{dims}, mean);
auto l0 = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
auto variance = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
auto l1 = prog.add_instruction(op::sub{}, x, mean_bcast);
auto epsilon_literal = prog.add_literal(epsilon);
auto epsilon_bcast = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = prog.add_instruction(op::multibroadcast{dims}, variance);
auto l2 = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
auto l3 = prog.add_instruction(op::rsqrt{}, l2);
auto l4 = prog.add_instruction(op::mul{}, l1, l3);
auto scale_bcast = prog.add_instruction(op::broadcast{1, dims}, scale);
;
auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
auto l5 = prog.add_instruction(op::mul{}, l4, scale_bcast);
return prog.add_instruction(op::add{}, l5, bias_bcast);
}
instruction_ref parse_leaky_relu(const std::string&, instruction_ref parse_leaky_relu(const std::string&,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
...@@ -763,11 +1024,12 @@ struct onnx_parser ...@@ -763,11 +1024,12 @@ struct onnx_parser
auto&& bias_floats = attributes["bias"].floats(); auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end()); bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
} }
auto input_lens = args.front()->get_shape().lens(); auto input_shape = args.front()->get_shape();
auto const& input_lens = input_shape.lens();
auto input_type = input_shape.type();
auto scale_val = prog.add_literal(scale); auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = prog.add_literal( auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val); auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
...@@ -1396,6 +1658,47 @@ struct onnx_parser ...@@ -1396,6 +1658,47 @@ struct onnx_parser
} }
} }
instruction_ref
parse_reduce_l1(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {abs_ins});
}
instruction_ref
parse_reduce_l2(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins});
return prog.add_instruction(op::sqrt{}, sum_ins);
}
instruction_ref parse_reduce_log_sum(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
auto sum_ins =
parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), std::move(args));
return prog.add_instruction(op::log{}, sum_ins);
}
instruction_ref parse_reduce_log_sum_exp(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {exp_ins});
return prog.add_instruction(op::log{}, sum_ins);
}
instruction_ref parse_reduce_sum_square(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins});
}
instruction_ref instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -1425,6 +1728,22 @@ struct onnx_parser ...@@ -1425,6 +1728,22 @@ struct onnx_parser
} }
} }
void parse_from(const void* data, std::size_t size)
{
onnx::ModelProto model;
if(model.ParseFromArray(data, size))
{
if(model.has_graph())
{
this->parse_graph(model.graph());
}
}
else
{
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
void parse_graph(const onnx::GraphProto& graph) void parse_graph(const onnx::GraphProto& graph)
{ {
nodes = get_nodes(graph); nodes = get_nodes(graph);
...@@ -1438,7 +1757,7 @@ struct onnx_parser ...@@ -1438,7 +1757,7 @@ struct onnx_parser
if(!contains(instructions, name)) if(!contains(instructions, name))
{ {
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
shape s = parse_type(input.type()); shape s = parse_type(input.type(), batch_size);
instructions[name] = prog.add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
} }
} }
...@@ -1446,6 +1765,29 @@ struct onnx_parser ...@@ -1446,6 +1765,29 @@ struct onnx_parser
{ {
this->parse_node(output.name()); this->parse_node(output.name());
} }
// Find instructions corresponding to the output
auto prog_output = graph.output();
std::vector<std::string> all_output_names;
std::vector<std::string> prog_output_names;
std::transform(prog_output.begin(),
prog_output.end(),
std::back_inserter(all_output_names),
[](auto& node) { return node.name(); });
std::copy_if(
all_output_names.begin(),
all_output_names.end(),
std::back_inserter(prog_output_names),
[&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });
std::vector<instruction_ref> output_ins;
std::transform(prog_output_names.begin(),
prog_output_names.end(),
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instuction
prog.add_return(output_ins);
} }
void parse_undefined(const std::string& name) void parse_undefined(const std::string& name)
...@@ -1464,14 +1806,14 @@ struct onnx_parser ...@@ -1464,14 +1806,14 @@ struct onnx_parser
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
if(nodes.count(input) > 0) if(input.empty())
{ {
assert(name != input); this->parse_undefined(input);
this->parse_node(input);
} }
else if(input.empty()) else if(nodes.count(input) > 0)
{ {
this->parse_undefined(input); assert(name != input);
this->parse_node(input);
} }
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
...@@ -1491,12 +1833,12 @@ struct onnx_parser ...@@ -1491,12 +1833,12 @@ struct onnx_parser
} }
else else
{ {
assert(node.output().size() >= result.size()); auto output_num = std::min<std::size_t>(node.output().size(), result.size());
std::transform(result.begin(), std::transform(node.output().begin(),
result.end(), node.output().begin() + output_num,
node.output().begin(), result.begin(),
std::inserter(instructions, instructions.end()), std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); }); [](auto&& x, auto&& y) { return std::make_pair(x, y); });
} }
} }
} }
...@@ -1572,6 +1914,8 @@ struct onnx_parser ...@@ -1572,6 +1914,8 @@ struct onnx_parser
case onnx::AttributeProto::STRING: case onnx::AttributeProto::STRING:
case onnx::AttributeProto::STRINGS: case onnx::AttributeProto::STRINGS:
case onnx::AttributeProto::TENSORS: case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("Invalid attribute type"); MIGRAPHX_THROW("Invalid attribute type");
...@@ -1658,7 +2002,7 @@ struct onnx_parser ...@@ -1658,7 +2002,7 @@ struct onnx_parser
return literal{{shape_type, dims}, data.begin(), data.end()}; return literal{{shape_type, dims}, data.begin(), data.end()};
} }
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
{ {
shape::type_t shape_type{}; shape::type_t shape_type{};
switch(t.tensor_type().elem_type()) switch(t.tensor_type().elem_type())
...@@ -1686,14 +2030,18 @@ struct onnx_parser ...@@ -1686,14 +2030,18 @@ struct onnx_parser
std::transform(tensor_dims.begin(), std::transform(tensor_dims.begin(),
tensor_dims.end(), tensor_dims.end(),
std::back_inserter(dims), std::back_inserter(dims),
[](auto&& d) -> std::size_t { [&](auto&& d) -> std::size_t {
if(not d.has_dim_value()) if(d.has_dim_value())
{ {
long default_batch_size = 1; // FIXME if(static_cast<int>(d.dim_value()) <= 0)
return default_batch_size; return batch_size;
return d.dim_value();
} }
return d.dim_value(); return batch_size;
}); });
if(dims.empty())
return {shape_type};
return {shape_type, dims}; return {shape_type, dims};
} }
...@@ -1728,15 +2076,16 @@ struct onnx_parser ...@@ -1728,15 +2076,16 @@ struct onnx_parser
} }
}; };
program parse_onnx(const std::string& name) template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{ {
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
onnx_parser parser; onnx_parser parser;
parser.batch_size = options.batch_size;
#ifndef NDEBUG #ifndef NDEBUG
// Log the program when it can't be parsed // Log the program when it can't be parsed
try try
{ {
parser.parse_from(input); parser.parse_from(std::forward<Ts>(xs)...);
} }
catch(...) catch(...)
{ {
...@@ -1744,10 +2093,26 @@ program parse_onnx(const std::string& name) ...@@ -1744,10 +2093,26 @@ program parse_onnx(const std::string& name)
throw; throw;
} }
#else #else
parser.parse_from(input); parser.parse_from(std::forward<Ts>(xs)...);
#endif #endif
return std::move(parser.prog); return std::move(parser.prog);
} }
program parse_onnx(const std::string& name, onnx_options options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
return parse_onnx_from(options, input);
}
program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
return parse_onnx_from(options, buffer.data(), buffer.size());
}
program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
return parse_onnx_from(options, data, size);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -3,24 +3,42 @@ ...@@ -3,24 +3,42 @@
// //
// Copyright (c) Facebook Inc. and Microsoft Corporation. // Copyright (c) ONNX Project Contributors.
// Licensed under the MIT license. // Licensed under the MIT license.
syntax = "proto2"; syntax = "proto2";
package onnx; package onnx;
// Note [Release] // Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
// Release
//
// We are still in the very early stage of defining ONNX. The current // We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working // version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved // towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX. // by sharing our working version of ONNX.
//
// Note [Protobuf compatibility] // Protobuf compatibility
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// Based on experience working with downstream vendors, we generally can't // To simplify framework compatibility, ONNX is defined using the subset of protobuf
// assume recent versions of protobufs. This means that we do not use any // that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in proto3. // protobuf features that are only available in one of the two versions.
// //
// Here are the most notable contortions we have to carry out to work around // Here are the most notable contortions we have to carry out to work around
// these limitations: // these limitations:
...@@ -29,30 +47,11 @@ package onnx; ...@@ -29,30 +47,11 @@ package onnx;
// of key-value pairs, where order does not matter and duplicates // of key-value pairs, where order does not matter and duplicates
// are not allowed. // are not allowed.
// Note [Namespaces]
// ~~~~~~~~~~~~~~~~~ // Versioning
// ONNX gives explicit names to graphs, intermediate values and
// serialized tensors. To make it easier to generate names, we organize
// these into separate namespaces (so, e.g., a graph can have the same
// name as a serialized tensor.) The namespaces are as follows:
//
// - Node: These names identify specific nodes in the graph (but not, necessarily
// any particular input or output of the node.
// - Graph: These names identify graphs in the protobuf.
// - Attribute: These names identify attribute names for extra attributes that
// are passed to operators.
// - Operator: These names identify particular operators.
// - Value: These names identify intermediate values (typically tensors) flowing through
// the computation of a graph.
// - Shape: These names represent parameters for unknown shape dimensions.
// //
// We specify the namespace of a name in ONNX as comments in the form // ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
// of "namespace {Node,Graph,Operator,Attribute,Value,Shape}". Framework is responsible
// for supporting the namespaces.
// //
// Naming things is hard. Every element with a name has an optional doc_string associated
// with it, providing a human-readable description in text markdown.
// To be compatible with both proto2 and proto3, we will use a version number // To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number. // that is not defined by the default value but an explicit enum number.
enum Version { enum Version {
...@@ -61,26 +60,53 @@ enum Version { ...@@ -61,26 +60,53 @@ enum Version {
_START_VERSION = 0; _START_VERSION = 0;
// The version field is always serialized and we will use it to store the // The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version // version that the graph is generated from. This helps us set up version
// control. We should use version as // control.
// xx(major) - xx(minor) - xxxx(bugfix) // For the IR, we are using simple numbers starting with 0x00000001,
// and we are starting with 0x00000001 (0.0.1), which was the // which was the version we published on Oct 10, 2017.
// version we published on Oct 10, 2017. IR_VERSION_2017_10_10 = 0x0000000000000001;
IR_VERSION_2017_10_10 = 0x00000001;
// IR_VERSION 0.0.2 published on Oct 30, 2017 // IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users // - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x00000002; IR_VERSION_2017_10_30 = 0x0000000000000002;
// IR VERSION 0.0.3 published on Nov 3, 2017 // IR VERSION 3 published on Nov 3, 2017
// - For operator versioning: // - For operator versioning:
// - Added new message OperatorSetIdProto // - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto // - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto // - For vendor extensions, added domain in NodeProto
IR_VERSION = 0x00000003; IR_VERSION_2017_11_3 = 0x0000000000000003;
// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;
// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION_2019_3_18 = 0x0000000000000005;
// IR VERSION 6 published on Sep 19, 2019
// - Add support for sparse tensor constants stored in model.
// - Add message SparseTensorProto
// - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD>
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator.
IR_VERSION = 0x0000000000000007;
} }
// A named attribute containing either singular float, integer, string // Attributes
// and tensor values, or repeated float, integer, string and tensor values. //
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the // An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent. // following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto { message AttributeProto {
...@@ -94,26 +120,34 @@ message AttributeProto { ...@@ -94,26 +120,34 @@ message AttributeProto {
STRING = 3; STRING = 3;
TENSOR = 4; TENSOR = 4;
GRAPH = 5; GRAPH = 5;
SPARSE_TENSOR = 11;
FLOATS = 6; FLOATS = 6;
INTS = 7; INTS = 7;
STRINGS = 8; STRINGS = 8;
TENSORS = 9; TENSORS = 9;
GRAPHS = 10; GRAPHS = 10;
SPARSE_TENSORS = 12;
} }
// The name field MUST be present for this version of the IR. // The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
optional string ref_attr_name = 21;
// A human-readable documentation for this attribute. Markdown is allowed. // A human-readable documentation for this attribute. Markdown is allowed.
optional string doc_string = 13; optional string doc_string = 13;
// The type field MUST be present for this version of the IR. // The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and // For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field hueristics to determine // implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this // which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This // field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accomodate proto3 implementations. // change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR // Exactly ONE of the following fields must be present for this version of the IR
...@@ -122,6 +156,7 @@ message AttributeProto { ...@@ -122,6 +156,7 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph optional GraphProto g = 6; // graph
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated. // Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph // optional ValueProto v = 12; // value - subsumes everything but graph
...@@ -130,6 +165,7 @@ message AttributeProto { ...@@ -130,6 +165,7 @@ message AttributeProto {
repeated bytes strings = 9; // list of UTF-8 strings repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
} }
// Defines information on value, including the name, the type, and // Defines information on value, including the name, the type, and
...@@ -137,16 +173,20 @@ message AttributeProto { ...@@ -137,16 +173,20 @@ message AttributeProto {
message ValueInfoProto { message ValueInfoProto {
// This field MUST be present in this version of the IR. // This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value optional string name = 1; // namespace Value
// This field MUST be present in this version of the IR. // This field MUST be present in this version of the IR for
// inputs and outputs of the top-level graph.
optional TypeProto type = 2; optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed. // A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3; optional string doc_string = 3;
} }
// NodeProto stores a node that is similar to the notion of "layer" // Nodes
// or "operator" in many deep learning frameworks. For example, it can be a //
// node of type "Conv" that takes in an image, a filter tensor and a bias // Computation graphs are made up of a DAG of nodes, which represent what is
// tensor, and produces the convolved output. // commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto { message NodeProto {
repeated string input = 1; // namespace Value repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value repeated string output = 2; // namespace Value
...@@ -161,18 +201,125 @@ message NodeProto { ...@@ -161,18 +201,125 @@ message NodeProto {
optional string domain = 7; // namespace Domain optional string domain = 7; // namespace Domain
// Additional named attributes. // Additional named attributes.
// NOTE: Simply using ValueProto.NameValuePairProto is the most general
// solution. I kept AttributeProto to minimize churn on CI results.
repeated AttributeProto attribute = 5; repeated AttributeProto attribute = 5;
// A human-readable documentation for this node. Markdown is allowed. // A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6; optional string doc_string = 6;
} }
// ModelProto is a top-level file/container format for bundling a ML model. // Training information
// The semantics of the model are described by the GraphProto that represents // TrainingInfoProto stores information for training a model.
// a parameterized computation graph against a set of named operators that are // In particular, this defines two functionalities: an initialization-step
// defined independently from the graph. // and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed.
// Training algorithm improves the model based on input data.
//
// The semantics of the initialization-step is that the initializers
// in ModelProto.graph and in TrainingInfoProto.algorithm are first
// initialized as specified by the initializers in the graph, and then
// updated by the "initialization_binding" in every instance in
// ModelProto.training_info.
//
// The field "algorithm" defines a computation graph which represents a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage.
message TrainingInfoProto {
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// and can have multiple outputs. Usually, trainable tensors in neural
// networks are randomly initialized. To achieve that, for each tensor,
// the user can put a random number operator such as RandomNormal or
// RandomUniform in TrainingInfoProto.initialization.node and assign its
// random output to the specific tensor using "initialization_binding".
// This graph can also set the initializers in "algorithm" in the same
// TrainingInfoProto; a use case is resetting the number of training
// iteration to zero.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output.
optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference
// graph.
//
// The field algorithm.node is the only place the user can use GraphCall
// operator. The only callable graph is the one stored in ModelProto.graph.
//
// By default, this field is an empty graph and its evaluation does not
// produce any output.
optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details.
//
// By default, this field is empty and no initializer would be changed
// by the execution of "initialization".
repeated StringStringEntryProto initialization_binding = 3;
// Gradient-based training is usually an iterative procedure. In one gradient
// descent iteration, we apply
//
// x = x - r * g
//
// where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
// gradient of "x" with respect to a chosen loss. To avoid adding assignments
// into the training graph, we split the update equation into
//
// y = x - r * g
// x = y
//
// The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
// tell that "y" should be assigned to "x", the field "update_binding" may
// contain a key-value pair of strings, "x" (key of StringStringEntryProto)
// and "y" (value of StringStringEntryProto).
// For a neural network with multiple trainable (mutable) tensors, there can
// be multiple key-value pairs in "update_binding".
//
// The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors
// as described below.
//
// 1. We have only unique keys in all "update_binding"s so that two global
// variables may not have the same name. This ensures that one
// global variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm".
// 4. If an optional input of a graph is omitted when using GraphCall, the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
// This field usually contains names of trainable tensors
// (in ModelProto.graph), optimizer states such as momentums in advanced
// stochastic gradient methods (in TrainingInfoProto.graph),
// and number of training iterations (in TrainingInfoProto.graph).
//
// By default, this field is empty and no initializer would be changed
// by the execution of "algorithm".
repeated StringStringEntryProto update_binding = 4;
}
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto's.
message ModelProto { message ModelProto {
// The version of the IR this model targets. See Version enum above. // The version of the IR this model targets. See Version enum above.
// This field MUST be present. // This field MUST be present.
...@@ -217,6 +364,17 @@ message ModelProto { ...@@ -217,6 +364,17 @@ message ModelProto {
// Named metadata values; keys should be distinct. // Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14; repeated StringStringEntryProto metadata_props = 14;
// Training-specific information. Sequentially executing all stored
// `TrainingInfoProto.algorithm`s and assigning their outputs following
// the corresponding `TrainingInfoProto.update_binding`s is one training
// iteration. Similarly, to initialize the model
// (as if training hasn't happened), the user should sequentially execute
// all stored `TrainingInfoProto.initialization`s and assigns their outputs
// using `TrainingInfoProto.initialization_binding`s.
//
// If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20;
}; };
// StringStringEntryProto follows the pattern for cross-proto-version maps. // StringStringEntryProto follows the pattern for cross-proto-version maps.
...@@ -226,25 +384,38 @@ message StringStringEntryProto { ...@@ -226,25 +384,38 @@ message StringStringEntryProto {
optional string value= 2; optional string value= 2;
}; };
// GraphProto defines a parameterized series of nodes to form a directed acyclic graph. message TensorAnnotation {
// This is the equivalent of the "network" and "graph" in many deep learning optional string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}
// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks. // frameworks.
message GraphProto { message GraphProto {
// The nodes in the graph. // The nodes in the graph, sorted topologically.
repeated NodeProto node = 1; repeated NodeProto node = 1;
// The name of the graph. // The name of the graph.
optional string name = 2; // namespace Graph optional string name = 2; // namespace Graph
// A list of named tensor values (constants), used to specify default // A list of named tensor values, used to specify constant inputs of the graph.
// values for some of the inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that // Each TensorProto entry must have a distinct name (within the list) that
// also appears in the input list. // MAY also appear in the input list.
// In an evaluation, the default value specified here is used if and only if
// user specifies no value for the corresponding input parameter.
// May be used to pass serialized parameters for networks.
repeated TensorProto initializer = 5; repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format.
repeated SparseTensorProto sparse_initializer = 15;
// A human-readable documentation for this graph. Markdown is allowed. // A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10; optional string doc_string = 10;
...@@ -256,7 +427,13 @@ message GraphProto { ...@@ -256,7 +427,13 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list. // must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13; repeated ValueInfoProto value_info = 13;
// DO NOT USE the following fields, they were deprecated before // This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions.
// repeated string input = 3; // repeated string input = 3;
// repeated string output = 4; // repeated string output = 4;
// optional int64 ir_version = 6; // optional int64 ir_version = 6;
...@@ -265,7 +442,9 @@ message GraphProto { ...@@ -265,7 +442,9 @@ message GraphProto {
// optional string domain = 9; // optional string domain = 9;
} }
// A message defined to store a tensor in its serialized format. // Tensors
//
// A serialized tensor value.
message TensorProto { message TensorProto {
enum DataType { enum DataType {
UNDEFINED = 0; UNDEFINED = 0;
...@@ -280,13 +459,21 @@ message TensorProto { ...@@ -280,13 +459,21 @@ message TensorProto {
STRING = 8; // string STRING = 8; // string
BOOL = 9; // bool BOOL = 9; // bool
// Advanced types // IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10; FLOAT16 = 10;
DOUBLE = 11; DOUBLE = 11;
UINT32 = 12; UINT32 = 12;
UINT64 = 13; UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here. // Future extensions go here.
} }
...@@ -294,7 +481,8 @@ message TensorProto { ...@@ -294,7 +481,8 @@ message TensorProto {
repeated int64 dims = 1; repeated int64 dims = 1;
// The data type of the tensor. // The data type of the tensor.
optional DataType data_type = 2; // This field MUST have a valid TensorProto.DataType value
optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which // For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in // case the following fields will specify the segment that is stored in
...@@ -305,7 +493,7 @@ message TensorProto { ...@@ -305,7 +493,7 @@ message TensorProto {
} }
optional Segment segment = 3; optional Segment segment = 3;
// Tensor content must be in the row major order. // Tensor content must be organized in row-major order.
// //
// Depending on the data_type field, exactly one of the fields below with // Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor. // name ending in _data is used to store the elements of the tensor.
...@@ -313,7 +501,7 @@ message TensorProto { ...@@ -313,7 +501,7 @@ message TensorProto {
// For float and complex64 values // For float and complex64 values
// Complex64 tensors are encoded as a single array of floats, // Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions, // with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0] // is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
...@@ -323,7 +511,7 @@ message TensorProto { ...@@ -323,7 +511,7 @@ message TensorProto {
// float16 values must be bit-wise converted to an uint16_t prior // float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer. // to writing to the buffer.
// When this field is present, the data_type field MUST be // When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
repeated int32 int32_data = 5 [packed = true]; repeated int32 int32_data = 5 [packed = true];
// For strings. // For strings.
...@@ -360,10 +548,32 @@ message TensorProto { ...@@ -360,10 +548,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9; optional bytes raw_data = 9;
// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;
// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}
// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
optional DataLocation data_location = 14;
// For double // For double
// Complex64 tensors are encoded as a single array of doubles, // Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions, // with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0] // is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
...@@ -375,6 +585,28 @@ message TensorProto { ...@@ -375,6 +585,28 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true]; repeated uint64 uint64_data = 11 [packed = true];
} }
// A serialized sparse-tensor value
message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats.
// (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
// corresponding to the j-th index of the i-th value (in the values tensor).
// (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
// must be the linearized-index of the i-th value (in the values tensor).
// The linearized-index can be converted into an index tuple (k_1,...,k_rank)
// using the shape provided below.
// The indices must appear in ascending order without duplication.
// In the first format, the ordering is lexicographic-ordering:
// e.g., index-value [1,4] must appear before [2,1]
optional TensorProto indices = 2;
// The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
repeated int64 dims = 3;
}
// Defines a tensor shape. A dimension can be either an integer value // Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown // or a symbolic variable. A symbolic variable represents an unknown
// dimension. // dimension.
...@@ -384,28 +616,73 @@ message TensorShapeProto { ...@@ -384,28 +616,73 @@ message TensorShapeProto {
int64 dim_value = 1; int64 dim_value = 1;
string dim_param = 2; // namespace Shape string dim_param = 2; // namespace Shape
}; };
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
optional string denotation = 3;
}; };
repeated Dimension dim = 1; repeated Dimension dim = 1;
} }
// Define the types. // Types
//
// The standard ONNX data types.
message TypeProto { message TypeProto {
message Tensor { message Tensor {
// This field MUST NOT have the value of UNDEFINED // This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR. // This field MUST be present for this version of the IR.
optional TensorProto.DataType elem_type = 1; optional int32 elem_type = 1;
optional TensorShapeProto shape = 2; optional TensorShapeProto shape = 2;
} }
// repeated T
message Sequence {
// The type and optional shape of each element of the sequence.
// This field MUST be present for this version of the IR.
optional TypeProto elem_type = 1;
};
// map<K,V>
message Map {
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
optional int32 key_type = 1;
// This field MUST be present for this version of the IR.
optional TypeProto value_type = 2;
};
oneof value { oneof value {
// The type of a tensor. // The type of a tensor.
Tensor tensor_type = 1; Tensor tensor_type = 1;
// NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
// as input and output to graphs and nodes. These types are needed to naturally
// support classical ML operators. DNN operators SHOULD restrict their input
// and output types to tensors.
// The type of a sequence.
Sequence sequence_type = 4;
// The type of a map.
Map map_type = 5;
} }
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
optional string denotation = 6;
} }
// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair. // OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto { message OperatorSetIdProto {
// The domain of the operator set being identified. // The domain of the operator set being identified.
...@@ -418,3 +695,8 @@ message OperatorSetIdProto { ...@@ -418,3 +695,8 @@ message OperatorSetIdProto {
// This field MUST be present in this version of the IR. // This field MUST be present in this version of the IR.
optional int64 version = 2; optional int64 version = 2;
} }
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
...@@ -52,7 +52,9 @@ static void print_instruction(std::ostream& os, ...@@ -52,7 +52,9 @@ static void print_instruction(std::ostream& os,
os << ")"; os << ")";
} }
os << " -> " << ins->get_shape(); // skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
} }
template <class F> template <class F>
...@@ -147,7 +149,14 @@ void program::assign(const program& p) ...@@ -147,7 +149,14 @@ void program::assign(const program& p)
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) { std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i]; return ins_map[i];
}); });
copy_ins = add_instruction(ins->get_operator(), copy_inputs); if(ins->name() == "@return")
{
copy_ins = add_return(copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
} }
ins_map[ins] = copy_ins; ins_map[ins] = copy_ins;
...@@ -270,6 +279,18 @@ instruction_ref program::add_parameter(std::string name, shape s) ...@@ -270,6 +279,18 @@ instruction_ref program::add_parameter(std::string name, shape s)
return impl->instructions.begin(); return impl->instructions.begin();
} }
instruction_ref program::add_return(std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
impl->instructions.push_back({builtin::returns{}, {}, args});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
shape program::get_parameter_shape(std::string name) const shape program::get_parameter_shape(std::string name) const
{ {
auto ins = std::find_if( auto ins = std::find_if(
...@@ -334,7 +355,26 @@ std::size_t program::size() const { return impl->instructions.size(); } ...@@ -334,7 +355,26 @@ std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); } instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); } instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); } std::vector<shape> program::get_output_shapes() const
{
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
auto& output_ins = last_ins.inputs();
std::vector<shape> output_shapes;
std::transform(output_ins.begin(),
output_ins.end(),
std::back_inserter(output_shapes),
[](auto& ins) { return ins->get_shape(); });
return output_shapes;
}
// The else branch is to provide backward compatibility
else
{
return {last_ins.get_shape()};
}
}
context& program::get_context() const { return impl->ctx; } context& program::get_context() const { return impl->ctx; }
...@@ -372,10 +412,10 @@ void program::finalize() ...@@ -372,10 +412,10 @@ void program::finalize()
} }
template <class F> template <class F>
argument generic_eval(const program& p, std::vector<argument> generic_eval(const program& p,
context& ctx, context& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
F trace) F trace)
{ {
assert(p.validate() == p.end()); assert(p.validate() == p.end());
std::unordered_map<instruction_ref, argument> results; std::unordered_map<instruction_ref, argument> results;
...@@ -407,6 +447,19 @@ argument generic_eval(const program& p, ...@@ -407,6 +447,19 @@ argument generic_eval(const program& p,
{ {
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; })); results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
} }
else if(name == "@return")
{
std::vector<argument> prog_outputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(prog_outputs),
[&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
return prog_outputs;
}
else else
{ {
values.resize(ins->inputs().size()); values.resize(ins->inputs().size());
...@@ -421,10 +474,11 @@ argument generic_eval(const program& p, ...@@ -421,10 +474,11 @@ argument generic_eval(const program& p,
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
} }
return results.at(std::prev(p.end()));
return {results.at(std::prev(p.end()))};
} }
argument program::eval(std::unordered_map<std::string, argument> params) const std::vector<argument> program::eval(parameter_map params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
#ifndef NDEBUG #ifndef NDEBUG
...@@ -531,6 +585,11 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -531,6 +585,11 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
print_program(*this, [&](auto ins, const auto& names) { print_program(*this, [&](auto ins, const auto& names) {
print_instruction(std::cout, ins, names); print_instruction(std::cout, ins, names);
// skip return instruction
if(ins->name() == "@return")
return;
double avg = common_average(ins_vec[ins]); double avg = common_average(ins_vec[ins]);
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << ": " << avg << "ms, " << percent << "%"; os << ": " << avg << "ms, " << percent << "%";
......
...@@ -158,7 +158,7 @@ PYBIND11_MODULE(migraphx, m) ...@@ -158,7 +158,7 @@ PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape) .def("get_output_shapes", &migraphx::program::get_output_shapes)
.def("compile", .def("compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy) { [](migraphx::program& p, const migraphx::target& t, bool offload_copy) {
migraphx::compile_options options; migraphx::compile_options options;
...@@ -173,11 +173,20 @@ PYBIND11_MODULE(migraphx, m) ...@@ -173,11 +173,20 @@ PYBIND11_MODULE(migraphx, m)
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_tf", m.def("parse_tf",
&migraphx::parse_tf, [](const std::string& filename, bool is_nhwc, unsigned int batch_size) {
return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size});
},
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true); py::arg("is_nhwc") = true,
m.def("parse_onnx", &migraphx::parse_onnx); py::arg("batch_size") = 1);
m.def("parse_onnx",
[](const std::string& filename, unsigned int batch_size) {
return migraphx::parse_onnx(filename, migraphx::onnx_options{batch_size});
},
"Parse onnx file",
py::arg("filename"),
py::arg("batch_size") = 1);
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
......
...@@ -105,6 +105,9 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -105,6 +105,9 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
std::unordered_map<instruction_ref, instruction_ref> map_fp16; std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
if(ins->name() == "@return")
break;
// all indicates every instruction is converted // all indicates every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name()))) if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{ {
...@@ -335,6 +338,9 @@ void quantize_int8_impl(program& prog, ...@@ -335,6 +338,9 @@ void quantize_int8_impl(program& prog,
std::unordered_map<instruction_ref, std::size_t> map_ins_index; std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
if(ins->name() == "@return")
break;
if(not contains(ins_names, ins->name())) if(not contains(ins_names, ins->name()))
{ {
continue; continue;
......
...@@ -27,6 +27,15 @@ auto conv_const_weights() ...@@ -27,6 +27,15 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w"))); match::args(match::any(), match::is_constant().bind("w")));
} }
MIGRAPHX_PRED_MATCHER(args_has_same_ops, instruction_ref ins)
{
if(ins->inputs().empty())
return true;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto j) {
return j->get_operator() == ins->inputs().front()->get_operator();
});
}
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
...@@ -167,6 +176,73 @@ struct find_inner_broadcast ...@@ -167,6 +176,73 @@ struct find_inner_broadcast
} }
}; };
struct find_concat_unary
{
auto matcher() const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(1),
match::name("relu", "broadcast").bind("x"),
match::used_once()));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
// Adjust broadcast lens
if(op.name() == "broadcast")
{
auto b = any_cast<op::broadcast>(op);
if(b.axis != axis)
return;
b.broadcast_lens = ins->get_shape().lens();
op = b;
axis = 0;
}
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto concat = p.insert_instruction(ins, op::concat{axis}, inputs);
p.replace_instruction(ins, op, concat);
}
};
struct find_concat_binary
{
auto matcher() const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(2),
match::name("add", "multiply").bind("x"),
match::used_once()));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto concat_op = ins->get_operator();
auto xinputs = ins->inputs();
std::transform(xinputs.begin(), xinputs.end(), xinputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto yinputs = ins->inputs();
std::transform(yinputs.begin(), yinputs.end(), yinputs.begin(), [&](auto i) {
return i->inputs().back();
});
auto xconcat = p.insert_instruction(ins, concat_op, xinputs);
auto yconcat = p.insert_instruction(ins, concat_op, yinputs);
p.replace_instruction(ins, op, xconcat, yconcat);
}
};
bool axis_equal(const std::vector<std::size_t>& x, bool axis_equal(const std::vector<std::size_t>& x,
const std::vector<std::size_t>& y, const std::vector<std::size_t>& y,
std::size_t axis) std::size_t axis)
...@@ -281,7 +357,9 @@ void simplify_algebra::apply(program& p) const ...@@ -281,7 +357,9 @@ void simplify_algebra::apply(program& p) const
find_add_lit_broadcast{}, find_add_lit_broadcast{},
find_add_convs{}, find_add_convs{},
find_mul_conv{}, find_mul_conv{},
find_mul_add{}); find_mul_add{},
find_concat_unary{},
find_concat_binary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(p);
} }
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
...@@ -144,13 +145,14 @@ struct cpu_lrn ...@@ -144,13 +145,14 @@ struct cpu_lrn
int height = output_shape.lens()[2]; int height = output_shape.lens()[2];
int width = output_shape.lens()[3]; int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / float(op.size); float alphaoverarea = op.alpha / float(op.size);
int radius = (op.size - 1) / 2; int radius_lower = (op.size - 1) / 2;
int radius_upper = op.size / 2 + 1;
par_dfor(n_batch, height, width)([&](int b, int h, int w) { par_dfor(n_batch, height, width)([&](int b, int h, int w) {
float scale = 0; float scale = 0;
dfor(channels)([&](int c) { dfor(channels)([&](int c) {
auto start = (c - radius) < 0 ? 0 : (c - radius); auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower);
auto end = (c + radius) > channels ? channels : (c + radius); auto end = (c + radius_upper) > channels ? channels : (c + radius_upper);
for(auto k = start; k < end; ++k) for(auto k = start; k < end; ++k)
{ {
scale += std::pow(input(b, k, h, w), 2); scale += std::pow(input(b, k, h, w), 2);
...@@ -220,6 +222,67 @@ struct cpu_convolution ...@@ -220,6 +222,67 @@ struct cpu_convolution
} }
}; };
template <class Op>
struct cpu_deconvolution
{
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "cpu::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto out_lens = output_shape.lens();
auto out_h = out_lens[2];
auto out_w = out_lens[3];
auto in = input.get_shape().lens();
auto in_n = in[0];
auto in_c = in[1];
auto in_h = in[2];
auto in_w = in[3];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto wei_h = wei[2];
auto wei_w = wei[3];
par_dfor(in_n, wei_c)([&](std::size_t o, std::size_t k) {
dfor(in_c, in_h, in_w, wei_h, wei_w)(
[&](std::size_t w, std::size_t i, std::size_t j, std::size_t x, std::size_t y) {
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
const int out_x = start_x + x * op.dilation[0];
const int out_y = start_y + y * op.dilation[1];
const auto group_id = w / (wei_n / op.group);
const auto in_ch = group_id * wei_c + k;
if(out_x >= 0 && out_x < out_h && out_y >= 0 && out_y < out_w)
{
output(o, in_ch, out_x, out_y) +=
input(o, w, i, j) * weights(w, k, x, y);
}
});
});
});
return result;
}
};
struct cpu_im2col struct cpu_im2col
{ {
op::im2col op; op::im2col op;
...@@ -598,9 +661,10 @@ struct cpu_softmax ...@@ -598,9 +661,10 @@ struct cpu_softmax
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
std::size_t n_dims = batch_lens[op.axis]; int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis;
batch_lens[op.axis] = 1; std::size_t n_dims = batch_lens[tuned_axis];
batch_lens[tuned_axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
...@@ -612,26 +676,26 @@ struct cpu_softmax ...@@ -612,26 +676,26 @@ struct cpu_softmax
auto idx = batch_shape.multi(i); auto idx = batch_shape.multi(i);
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[tuned_axis] = j;
batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end())); batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[tuned_axis] = j;
std::size_t index = output_shape.index(idx); std::size_t index = output_shape.index(idx);
output[index] = std::exp(input[index] - batch_max[i]); output[index] = std::exp(input[index] - batch_max[i]);
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[tuned_axis] = j;
batch_sum[i] += output(idx.begin(), idx.end()); batch_sum[i] += output(idx.begin(), idx.end());
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[tuned_axis] = j;
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) =
op.output()(output(idx.begin(), idx.end()), batch_sum[i]); op.output()(output(idx.begin(), idx.end()), batch_sum[i]);
} }
...@@ -664,8 +728,10 @@ struct cpu_apply ...@@ -664,8 +728,10 @@ struct cpu_apply
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["convolution"] = extend_op<cpu_convolution<op::convolution>, op::convolution>(); apply_map["convolution"] = extend_op<cpu_convolution<op::convolution>, op::convolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>(); apply_map["deconvolution"] =
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>(); extend_op<cpu_deconvolution<op::deconvolution>, op::deconvolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] = apply_map["quant_convolution"] =
extend_op<cpu_convolution<op::quant_convolution>, op::quant_convolution>(); extend_op<cpu_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>(); apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
......
...@@ -12,6 +12,7 @@ endif() ...@@ -12,6 +12,7 @@ endif()
add_library(migraphx_device add_library(migraphx_device
device/acos.cpp device/acos.cpp
device/acosh.cpp
device/add.cpp device/add.cpp
device/add_clip.cpp device/add_clip.cpp
device/add_relu.cpp device/add_relu.cpp
...@@ -20,7 +21,9 @@ add_library(migraphx_device ...@@ -20,7 +21,9 @@ add_library(migraphx_device
device/argmax.cpp device/argmax.cpp
device/argmin.cpp device/argmin.cpp
device/asin.cpp device/asin.cpp
device/asinh.cpp
device/atan.cpp device/atan.cpp
device/atanh.cpp
device/ceil.cpp device/ceil.cpp
device/clip.cpp device/clip.cpp
device/concat.cpp device/concat.cpp
...@@ -43,10 +46,12 @@ add_library(migraphx_device ...@@ -43,10 +46,12 @@ add_library(migraphx_device
device/mul_add_relu.cpp device/mul_add_relu.cpp
device/pad.cpp device/pad.cpp
device/pow.cpp device/pow.cpp
device/prelu.cpp
device/reduce_max.cpp device/reduce_max.cpp
device/reduce_mean.cpp device/reduce_mean.cpp
device/reduce_min.cpp device/reduce_min.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
device/reduce_prod.cpp
device/relu.cpp device/relu.cpp
device/round.cpp device/round.cpp
device/rsqrt.cpp device/rsqrt.cpp
...@@ -79,6 +84,7 @@ add_library(migraphx_gpu ...@@ -79,6 +84,7 @@ add_library(migraphx_gpu
lowering.cpp lowering.cpp
pooling.cpp pooling.cpp
convolution.cpp convolution.cpp
deconvolution.cpp
quant_convolution.cpp quant_convolution.cpp
softmax.cpp softmax.cpp
logsoftmax.cpp logsoftmax.cpp
......
...@@ -14,7 +14,9 @@ shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -14,7 +14,9 @@ shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
...@@ -14,7 +14,9 @@ shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const ...@@ -14,7 +14,9 @@ shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_deconvolution::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument miopen_deconvolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Running deconvolution failed");
return args[3];
}
shape miopen_deconvolution::compile(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0]);
auto w_desc = make_tensor(inputs[1]);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0]));
auto w = to_gpu(generate_argument(inputs[1]));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find deconvolution failed");
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}};
}
void miopen_deconvolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
if(handle == ctx.get_stream().get_miopen())
return;
// Check that workspace hasn't changed
auto size = inputs.at(2).bytes();
auto ws = compile(ctx, output_shape, std::move(inputs));
if(ws.bytes() > size)
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void acos(hipStream_t stream, const argument& result, const argument& arg) void acos(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) { return ::acos(to_hip_type(x)); }); nary(stream, result, arg)([](auto x) __device__ { return ::acos(to_hip_type(x)); });
} }
} // namespace device } // namespace device
......
#include <migraphx/gpu/device/acosh.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void acosh(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::acosh(to_hip_type(x)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment