"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "392ed8e545f72d514269d821cd982f42414d88bc"
Commit 7702c20d authored by Paul

Merge

parents c362e7fa 9afce86d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct sqlite_impl;

struct sqlite
{
    sqlite() = default;
    static sqlite read(const fs::path& p);
    static sqlite write(const fs::path& p);
    std::vector<std::unordered_map<std::string, std::string>> execute(const std::string& s);

    private:
    std::shared_ptr<sqlite_impl> impl;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_MIGRAPHX_SQLITE_HPP
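A minimal usage sketch of the interface declared above (illustrative only; the database path, table, and column names are placeholders, not part of this commit):

// Open an existing database read-only and run a query; execute() returns one
// unordered_map per row, keyed by column name. (Sketch assumes <iostream>.)
auto db   = migraphx::sqlite::read(migraphx::fs::path{"perf.db"});
auto rows = db.execute("SELECT solution FROM perf_db;");
for(const auto& row : rows)
    std::cout << row.at("solution") << std::endl;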
@@ -174,27 +174,27 @@ inline std::string interpolate_string(const std::string& input,
 }

 template <class Iterator>
-inline std::string to_string_range(Iterator start, Iterator last)
+inline std::string to_string_range(Iterator start, Iterator last, const char* delim = ", ")
 {
     std::stringstream ss;
     if(start != last)
     {
         ss << *start;
-        std::for_each(std::next(start), last, [&](auto&& x) { ss << ", " << x; });
+        std::for_each(std::next(start), last, [&](auto&& x) { ss << delim << x; });
     }
     return ss.str();
 }

 template <class Range>
-inline std::string to_string_range(const Range& r)
+inline std::string to_string_range(const Range& r, const char* delim = ", ")
 {
-    return to_string_range(r.begin(), r.end());
+    return to_string_range(r.begin(), r.end(), delim);
 }

 template <class T>
-inline std::string to_string_range(const std::initializer_list<T>& r)
+inline std::string to_string_range(const std::initializer_list<T>& r, const char* delim = ", ")
 {
-    return to_string_range(r.begin(), r.end());
+    return to_string_range(r.begin(), r.end(), delim);
 }

 template <class T>
...
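A quick usage sketch of the new delimiter parameter (illustrative, not part of the diff):

std::vector<int> v = {1, 2, 3};
migraphx::to_string_range(v);        // "1, 2, 3" (default delimiter)
migraphx::to_string_range(v, " x "); // "1 x 2 x 3"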
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

enum class support_metric
{
    latency,
    throughput
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
@@ -37,6 +37,8 @@
 #include <migraphx/compile_options.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/support_metric.hpp>
+#include <migraphx/instruction_ref.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -61,6 +63,13 @@ struct target
      * @return The context to be used during compilation and execution.
      */
     context get_context() const;
+    /**
+     * @brief Check how well an instruction is supported on a target with the given metric
+     * @param ins Instruction to check if it's supported
+     * @param metric Used to define how the return value should be interpreted
+     * @return The value based on the chosen metric. Negative numbers mean unsupported
+     */
+    float is_supported(T&, instruction_ref ins, support_metric m) const;
     /**
     * @brief copy an argument to the current target.
     *
@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
     return arg;
 }

+template <class T>
+float target_is_supported(T&, instruction_ref, support_metric)
+{
+    return 0;
+}
+
 #ifdef TYPE_ERASED_DECLARATION
 // Type-erased interface for:
@@ -117,6 +132,8 @@ struct target
     //
     context get_context() const;
     // (optional)
+    float is_supported(instruction_ref ins, support_metric m) const;
+    // (optional)
     argument copy_to(const argument& input) const;
     // (optional)
     argument copy_from(const argument& input) const;
@@ -207,6 +224,12 @@ struct target
         return (*this).private_detail_te_get_handle().get_context();
     }

+    float is_supported(instruction_ref ins, support_metric m) const
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        return (*this).private_detail_te_get_handle().is_supported(ins, m);
+    }
+
     argument copy_to(const argument& input) const
     {
         assert((*this).private_detail_te_handle_mem_var);
@@ -242,11 +265,31 @@ struct target
         virtual std::vector<pass> get_passes(context& ctx,
                                              const compile_options& options) const = 0;
         virtual context get_context() const                                        = 0;
+        virtual float is_supported(instruction_ref ins, support_metric m) const    = 0;
         virtual argument copy_to(const argument& input) const                      = 0;
         virtual argument copy_from(const argument& input) const                    = 0;
         virtual argument allocate(const shape& s) const                            = 0;
     };

+    template <class T>
+    static auto private_detail_te_default_is_supported(char,
+                                                       T&& private_detail_te_self,
+                                                       instruction_ref ins,
+                                                       support_metric m)
+        -> decltype(private_detail_te_self.is_supported(ins, m))
+    {
+        return private_detail_te_self.is_supported(ins, m);
+    }
+
+    template <class T>
+    static float private_detail_te_default_is_supported(float,
+                                                        T&& private_detail_te_self,
+                                                        instruction_ref ins,
+                                                        support_metric m)
+    {
+        return target_is_supported(private_detail_te_self, ins, m);
+    }
+
     template <class T>
     static auto
     private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
@@ -329,6 +372,12 @@ struct target
         context get_context() const override { return private_detail_te_value.get_context(); }

+        float is_supported(instruction_ref ins, support_metric m) const override
+        {
+            return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
+        }
+
         argument copy_to(const argument& input) const override
         {
...
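A brief calling sketch for the new optional member (illustrative; t is any type-erased migraphx::target and ins a valid instruction_ref):

float score = t.is_supported(ins, migraphx::support_metric::latency);
// score < 0 means the instruction is unsupported on this target; targets that
// do not implement is_supported fall back to target_is_supported, which
// returns a neutral 0.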
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <string>
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct target_assignments
{
    void add_assignment(instruction_ref ins, const std::string& target);

    auto begin() const { return assignments.cbegin(); }
    auto end() const { return assignments.cend(); }

    private:
    std::unordered_map<instruction_ref, std::string> assignments;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
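The struct is range-like, so the result of program::get_target_assignments (further below) can be iterated directly; an illustrative sketch:

for(const auto& [ins, target_name] : assignments)
    std::cout << target_name << std::endl; // instruction_ref mapped to a target name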
@@ -40,6 +40,12 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
     auto val        = op.to_value();
     auto op_padding = val.at("padding").to_vector<size_t>();

+    // skip if shape is dynamic
+    if(input->get_shape().dynamic())
+    {
+        return;
+    }
+
     auto kdims = input->get_shape().lens().size() - 2;
     if(std::equal(op_padding.begin(),
                   op_padding.begin() + kdims,
...
@@ -445,8 +445,8 @@ operation instruction::normalized_operator() const
     operation o = this->get_operator();
     if(this->need_normalization())
     {
-        auto lens = this->inputs().front()->get_shape().lens();
-        if(!normalize_attributes(o, lens))
+        auto s = this->inputs().front()->get_shape();
+        if(!normalize_attributes(o, s.max_lens()))
             return this->get_operator();
     }
     return o;
...
@@ -399,7 +399,8 @@ module::add_instructions(const std::vector<instruction_ref>& instructions,
 }

 std::vector<instruction_ref>
-module::add_instructions(module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
+module::add_instructions(const_module_ref m,
+                         std::unordered_map<instruction_ref, instruction_ref> map_ins)
 {
     return this->insert_instructions(this->end(), m, std::move(map_ins));
 }
@@ -420,8 +421,10 @@ module::insert_instructions(instruction_ref ins,
     return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
 }

-std::vector<instruction_ref> module::insert_instructions(
-    instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
+std::vector<instruction_ref>
+module::insert_instructions(instruction_ref ins,
+                            const_module_ref m,
+                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
 {
     return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
 }
@@ -436,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
     return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
 }

-instruction_ref module::add_literal(literal l)
-{
-    impl->emplace_front(std::move(l));
-    return impl->instructions.begin();
-}
+instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }

 instruction_ref module::add_outline(const shape& s)
 {
@@ -450,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)

 instruction_ref module::add_parameter(std::string name, shape s)
 {
-    assert(get_parameter_shape(name) == shape{});
-    impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
-    impl->nparams++;
-    return impl->instructions.begin();
+    return insert_parameter(begin(), std::move(name), std::move(s));
 }

 instruction_ref module::add_return(std::vector<instruction_ref> args)
@@ -466,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
     return result;
 }

+instruction_ref module::insert_literal(instruction_ref ins, literal l)
+{
+    impl->emplace(ins, std::move(l));
+    return std::prev(ins);
+}
+
+instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
+{
+    assert(get_parameter_shape(name) == shape{});
+    impl->insert(ins, {builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
+    impl->nparams++;
+    return std::prev(ins);
+}
+
 instruction_ref module::replace_return(std::vector<instruction_ref> args)
 {
     auto last = std::prev(this->end());
...
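A short sketch of the refactored insertion path (illustrative; the shape and values are placeholders):

migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {2}};
auto x   = m.add_parameter("x", s);                           // delegates to insert_parameter(begin(), ...)
auto one = m.add_literal(migraphx::literal{s, {1.0f, 1.0f}}); // delegates to insert_literal(begin(), ...)
// Both insert_* helpers return std::prev(ins), an iterator to the new instruction.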
@@ -43,9 +43,9 @@ void normalize_ops::apply(module& m) const
         if(inputs.empty())
             continue;

-        auto lens                    = inputs[0]->get_shape().lens();
+        auto s                       = inputs[0]->get_shape();
         migraphx::operation tuned_op = ins->get_operator();
-        if(normalize_attributes(tuned_op, lens))
+        if(normalize_attributes(tuned_op, s.max_lens()))
         {
             m.replace_instruction(ins, tuned_op, inputs);
             ins->set_normalized();
...
@@ -93,9 +93,10 @@ struct onnx_parser
         onnx_parser&, const node_info&, std::vector<instruction_ref>)>;

     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
-    std::size_t default_dim_value = 1;
+    shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
+    std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
     bool skip_unknown_operators = false;
     int64_t max_loop_iterations = 10;
     int64_t opset_version       = 13;
@@ -118,6 +119,7 @@ struct onnx_parser
 };

 shape::type_t get_type(int dtype);
+bool is_type_float(shape::type_t dtype);

 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -41,8 +41,25 @@ template <class... Ts>
 program parse_onnx_from(const onnx_options& options, Ts&&... xs)
 {
     onnx::onnx_parser parser;
     parser.map_input_dims     = options.map_input_dims;
-    parser.default_dim_value  = options.default_dim_value;
+    parser.map_dyn_input_dims = options.map_dyn_input_dims;
+
+    auto dim_val = options.default_dim_value;
+    if(dim_val != 0)
+    {
+        if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
+        {
+            MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value "
+                           "set to non-default values");
+        }
+        else
+        {
+            parser.default_dyn_dim_value = {dim_val, dim_val, 0};
+        }
+    }
+    else
+    {
+        parser.default_dyn_dim_value = options.default_dyn_dim_value;
+    }
     parser.skip_unknown_operators = options.skip_unknown_operators;
     parser.max_loop_iterations    = options.max_loop_iterations;
...
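An illustrative configuration sketch ("model.onnx" is a placeholder file name):

migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0}; // unspecified dims may range from 1 to 4
// Also setting options.default_dim_value != 0 would trigger the PARSE_ONNX_FROM
// error above; the static and dynamic defaults are mutually exclusive.
auto prog = migraphx::parse_onnx("model.onnx", options);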
@@ -28,16 +28,17 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/pad_calc.hpp>
 #include <migraphx/common.hpp>
 #include <migraphx/type_traits.hpp>
 #include <migraphx/float_equal.hpp>
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/filesystem.hpp>
 #include <migraphx/op/unknown.hpp>
+#include <migraphx/env.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
@@ -58,7 +59,7 @@ create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // in case of scalar constants in onnx file, use dims=1 to fill initializer data
@@ -75,7 +76,7 @@ static literal create_literal(shape::type_t shape_type, const std::vector<size_t
         std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies<std::size_t>());
     if(elem_num == 0)
     {
-        return {};
+        return literal{shape_type};
     }

     // scalar input
@@ -255,6 +256,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)

 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
+    if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
+    {
+        MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only "
+                       "one should be used");
+    }
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
@@ -268,7 +274,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
         // input not in initializer_data, so it is a real input
         if(!contains(mod_insts, name))
         {
-            // ONNX specification does not specify hwo to deal with the
+            // ONNX specification does not specify how to deal with the
             // scenario that a nested subgraph contains a parameter with the
             // name existed in its parent graph.
             // In the current implementation, MIGraphX throws an exception for that.
@@ -278,13 +284,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                                "\" existing in parent graph!");
             }

+            shape s;
             std::vector<std::size_t> dims;
             if(map_input_dims.count(name) > 0)
             {
                 dims = map_input_dims.at(name);
+                s    = parse_type(input.type(), dims);
+            }
+            else if(map_dyn_input_dims.count(name) > 0)
+            {
+                shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
+                s = {shape_type, map_dyn_input_dims.at(name)};
+            }
+            else
+            {
+                s = parse_type(input.type(), dims);
             }
-
-            shape s = parse_type(input.type(), dims);
             mod_insts[name] = mod->add_parameter(name, s);
         }
     }
...
@@ -439,30 +454,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
         return {shape_type, input_dims};
     }

-    std::vector<std::size_t> dims;
+    std::vector<shape::dynamic_dimension> dynamic_dims;
     auto&& tensor_dims = t.tensor_type().shape().dim();
     std::transform(tensor_dims.begin(),
                    tensor_dims.end(),
-                   std::back_inserter(dims),
-                   [&](auto&& d) -> std::size_t {
+                   std::back_inserter(dynamic_dims),
+                   [&](auto&& d) -> shape::dynamic_dimension {
                        if(d.has_dim_value())
                        {
                            if(static_cast<int>(d.dim_value()) <= 0)
                            {
-                               return default_dim_value;
+                               return default_dyn_dim_value;
                            }
-                           return d.dim_value();
+                           std::size_t tmp = d.dim_value();
+                           return {tmp, tmp, 0};
                        }
                        else
                        {
-                           return default_dim_value;
+                           return default_dyn_dim_value;
                        }
                    });

-    if(dims.empty())
+    if(dynamic_dims.empty())
+    {
         return {shape_type};
+    }

-    return {shape_type, dims};
+    if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
+    {
+        std::vector<std::size_t> dims;
+        std::transform(dynamic_dims.begin(),
+                       dynamic_dims.end(),
+                       std::back_inserter(dims),
+                       [](auto d) { return d.max; });
+        return {shape_type, dims};
+    }
+    return {shape_type, dynamic_dims};
 }

 shape::type_t get_type(int dtype)
@@ -487,6 +513,16 @@ shape::type_t get_type(int dtype)
     }
 }

+bool is_type_float(shape::type_t dtype)
+{
+    bool r = false;
+    if(dtype == shape::float_type || dtype == shape::double_type || dtype == shape::half_type)
+    {
+        r = true;
+    }
+    return r;
+}
+
 } // namespace onnx
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
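A short note on the fixed-dimension fold above (illustrative):

// A dynamic_dimension {min, max, opt} with min == max is_fixed(), so ONNX dims
// [3, 5] still parse to the static shape {shape_type, {3, 5}}, and an unknown
// dim under the default default_dyn_dim_value {1, 1, 0} folds to a static 1.
// Only a wider default (e.g. {1, 4, 0}) or a map_dyn_input_dims entry produces
// a truly dynamic shape.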
@@ -43,7 +43,7 @@ struct parse_constant : op_parser<parse_constant>
         // return empty literal
         if(v.get_shape().elements() == 0)
         {
-            return info.add_literal(literal{});
+            return info.add_literal(literal{v.get_shape().type()});
         }

         auto dim_size = info.attributes.at("value").t().dims_size();
...
@@ -47,15 +47,17 @@ struct parse_convolution : op_parser<parse_convolution>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
         auto op      = make_op(opd.op_name);
         auto values  = op.to_value();
         auto l0      = args[0];
         auto weights = args[1];
-        auto in_lens = l0->get_shape().lens();
+        auto l0_shape = l0->get_shape();
+        auto w_shape  = weights->get_shape();
+        auto in_lens  = l0_shape.max_lens();
         assert(in_lens.size() > 2);
         auto kdims = in_lens.size() - 2;

-        // ensure pads availabe only when auto_pad is "NOT_SET"
+        // ensure pads available only when auto_pad is "NOT_SET"
         check_padding_mode(info, "CONV");

         if(contains(info.attributes, "strides"))
@@ -79,21 +81,65 @@ struct parse_convolution : op_parser<parse_convolution>
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
             check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
         }

         if(contains(info.attributes, "auto_pad"))
         {
-            auto weight_lens = weights->get_shape().lens();
-            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
-            cal_auto_padding_size(info,
-                                  values,
-                                  k_lens,
-                                  values["dilation"].to_vector<std::size_t>(),
-                                  in_lens,
-                                  padding);
-            auto auto_pad = info.attributes["auto_pad"].s();
+            bool is_same_padding = false;
+            auto auto_pad        = info.attributes["auto_pad"].s();
             if(auto_pad.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                is_same_padding = true;
+            }
+
+            // check if image shape is dynamic
+            bool image_shape_dynamic = false;
+            if(l0_shape.dynamic())
+            {
+                auto dyn_dims = l0_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        image_shape_dynamic = true;
+                    }
+                });
+            }
+
+            // check if kernel shape is dynamic
+            bool kernel_shape_dynamic = false;
+            if(w_shape.dynamic())
+            {
+                auto dyn_dims = w_shape.dyn_dims();
+                std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
+                    if(not dyn_dim.is_fixed())
+                    {
+                        kernel_shape_dynamic = true;
+                    }
+                });
+            }
+
+            if(is_same_padding)
+            {
+                if(image_shape_dynamic or kernel_shape_dynamic)
+                {
+                    // must calculate "same" padding with input shape data
+                    bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
+                    values["padding_mode"] = is_same_upper
+                                                 ? to_value(op::padding_mode_t::same_upper)
+                                                 : to_value(op::padding_mode_t::same_lower);
+                    values["use_dynamic_same_auto_pad"] = true;
+                }
+                else
+                {
+                    values["padding_mode"] = to_value(op::padding_mode_t::same);
+                    // kernel shape is fixed, so max_lens() == min_lens() for the kernel lengths
+                    auto weight_lens = weights->get_shape().max_lens();
+                    std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
+                    cal_auto_padding_size(info,
+                                          values,
+                                          k_lens,
+                                          values["dilation"].to_vector<std::size_t>(),
+                                          in_lens,
+                                          padding);
+                }
             }
         }

         values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
...
@@ -47,7 +47,8 @@ struct parse_if : op_parser<parse_if>
         if(args.front()->get_shape().elements() != 1)
         {
-            MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
+            MIGRAPHX_THROW("PARSE_IF: " + info.name +
+                           " condition input can have only one element!");
         }

         std::string then_name = info.name + "_if";
@@ -69,7 +70,8 @@ struct parse_if : op_parser<parse_if>
                       else_out_shapes.begin(),
                       else_out_shapes.end()))
         {
-            MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
+            MIGRAPHX_THROW("PARSE_IF: " + info.name +
+                           " then and else subgraphs must have same output shapes!");
         }

         auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
...
@@ -32,9 +32,12 @@ namespace onnx {

 struct parse_instancenorm : op_parser<parse_instancenorm>
 {
+    const std::set<shape::type_t> valid_types = {
+        shape::float_type, shape::half_type, shape::double_type};
+
     std::vector<op_desc> operators() const { return {{"InstanceNormalization"}}; }

-    instruction_ref parse(const op_desc& /*opd*/,
+    instruction_ref parse(const op_desc& opd,
                           const onnx_parser& parser,
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
@@ -52,6 +55,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto scale = args[1];
         auto bias  = args[2];
         auto dims  = x->get_shape().lens();
+        auto dtype = x->get_shape().type();
+        if(not contains(valid_types, dtype))
+            MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
+                           ". Valid types are 1 (float), 10 (half), and 11 (double).");
+
         auto ndims = dims.size();
         assert(ndims >= 2);
         auto kdims = ndims - 2;
@@ -65,7 +73,7 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto l0       = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
         auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
         auto l1       = info.add_instruction(make_op("sub"), x, mean_bcast);
-        auto epsilon_literal = info.add_literal(epsilon);
+        auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
         auto epsilon_bcast =
             info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
         auto variance_bcast =
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mod : op_parser<parse_mod>
{
    std::vector<op_desc> operators() const { return {{"Mod"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        std::string mod = "mod";
        // Integer "mod" is not defined for floating point inputs; ONNX requires
        // the fmod attribute to be set to 1 in that case.
        if(is_type_float(args[0]->get_shape().type()) ||
           is_type_float(args[1]->get_shape().type()))
        {
            if(!contains(info.attributes, "fmod"))
            {
                MIGRAPHX_THROW("Mod operator with float args and fmod=0 is invalid");
            }
        }

        if(contains(info.attributes, "fmod"))
        {
            if(parser.parse_value(info.attributes.at("fmod")).at<int>() == 1)
            {
                mod = "fmod";
            }
        }

        return info.add_common_op(mod, args[0], args[1]);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
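Behavior sketch for the fmod attribute (illustrative):

// Mod(a, b)             -> migraphx "mod"  (integer modulus; rejected for float inputs)
// Mod(a, b) with fmod=1 -> migraphx "fmod" (C-style fmod, valid for float inputs)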
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pad_calc.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void calculate_padding(int64_t idx,
                       std::vector<int64_t>& pads,
                       int64_t input_dim,
                       int64_t stride,
                       int64_t dilation,
                       int64_t weight_dim,
                       bool is_same_upper)
{
    int64_t output_dim     = (input_dim + stride - 1) / stride; // round up result
    int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
    int64_t pad =
        std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
    auto pad_ndims = pads.size() / 2;

    if(is_same_upper)
    {
        pads[idx]             = pad / 2;
        pads[idx + pad_ndims] = pad - pad / 2;
    }
    else
    {
        pads[idx + pad_ndims] = pad / 2;
        pads[idx]             = pad - pad / 2;
    }
}

std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
                                           std::vector<std::size_t> k_lens,
                                           std::vector<std::size_t> strides,
                                           std::vector<std::size_t> dilations,
                                           bool use_upper)
{
    std::vector<std::size_t> padding;
    padding.resize(2 * k_lens.size());
    for(size_t i = 0; i < padding.size() / 2; i++)
    {
        std::ptrdiff_t input_dim      = tensor_lens[i];
        std::ptrdiff_t stride         = strides[i];
        std::ptrdiff_t weight_dim     = k_lens[i];
        std::ptrdiff_t dilation      = dilations[i];
        std::ptrdiff_t output_dim     = (input_dim + stride - 1) / stride; // round up result
        std::ptrdiff_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
        std::size_t pad               = std::max(static_cast<std::ptrdiff_t>(0),
                                   (output_dim - 1) * stride + new_weight_dim - input_dim);
        auto pad_ndims                = padding.size() / 2;

        if(use_upper)
        {
            padding[i]             = pad / 2;
            padding[i + pad_ndims] = pad - pad / 2;
        }
        else
        {
            padding[i + pad_ndims] = pad / 2;
            padding[i]             = pad - pad / 2;
        }
    }
    return padding;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
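A worked example of the SAME-padding arithmetic above (illustrative numbers):

// input_dim = 7, stride = 2, weight_dim = 3, dilation = 1:
//   output_dim     = (7 + 2 - 1) / 2 = 4             (round up)
//   new_weight_dim = 3 + (3 - 1) * (1 - 1) = 3
//   pad            = max(0, (4 - 1) * 2 + 3 - 7) = 2
// SAME_UPPER places pad / 2 = 1 before and pad - pad / 2 = 1 after; an odd
// total such as pad = 3 would split as 1 before and 2 after.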
@@ -159,6 +159,25 @@ instruction_ref program::validate() const
     return mm->validate();
 }

+target_assignments program::get_target_assignments(const std::vector<target>& targets,
+                                                   assignment_options options)
+{
+    const auto m = options.metric;
+
+    target_assignments p;
+
+    const auto* mod = get_main_module();
+    for(auto it : iterator_for(*mod))
+    {
+        auto t = std::max_element(
+            targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
+                return lhs.is_supported(it, m) < rhs.is_supported(it, m);
+            });
+
+        p.add_assignment(it, t->name());
+    }
+    return p;
+}
+
 bool program::is_compiled() const { return not this->impl->target_name.empty(); }

 void program::compile(const target& t, compile_options options)
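An illustrative call sketch (assumes a target registered under the name "ref" via make_target, and that assignment_options can be brace-initialized with a metric; both are assumptions, not shown in this hunk; p is a parsed program):

std::vector<migraphx::target> targets = {migraphx::make_target("ref")};
auto assignments = p.get_target_assignments(targets, {migraphx::support_metric::latency});
// every instruction of the main module is assigned to the target reporting the
// highest is_supported() score for the chosen metric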
@@ -288,9 +307,12 @@ std::vector<argument> generic_eval(const module* mod,
                     if(not contains(params, param_name))
                         MIGRAPHX_THROW("Parameter not found: " + param_name);
                     auto param = params[param_name];
-                    if(param.get_shape() != ins->get_shape())
+                    // TODO: may want to check for the correct number of dimensions
+                    // and/or that each dimension is within bounds
+                    if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
+                    {
                         MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
                                        "} for parameter: " + param_name);
+                    }
                     return param;
                 }));
...
@@ -333,7 +355,10 @@ std::vector<argument> generic_eval(const module* mod,
                 }));
         }
         assert(results.find(ins) != results.end());
-        assert(results.at(ins).get_shape() == ins->get_shape());
+        if(not ins->get_shape().dynamic())
+        {
+            assert(results.at(ins).get_shape() == ins->get_shape());
+        }
     }
     return {results.at(std::prev(mod->end()))};
 }
@@ -504,12 +529,14 @@ static void mod_from_val(module_ref mod,
         if(name == "@param")
         {
-            output = mod->add_parameter(fields["parameter"].to<std::string>(),
-                                        migraphx::from_value<shape>(node.at("shape")));
+            output = mod->insert_parameter(mod->end(),
+                                           fields["parameter"].to<std::string>(),
+                                           migraphx::from_value<shape>(node.at("shape")));
         }
         else if(name == "@literal")
         {
-            output = mod->add_literal(migraphx::from_value<literal>(node.at("literal")));
+            output =
+                mod->insert_literal(mod->end(), migraphx::from_value<literal>(node.at("literal")));
         }
         else
         {
@@ -544,11 +571,11 @@ static void mod_from_val(module_ref mod,
         }
         else if(module_inputs.empty())
         {
-            output = mod->add_instruction(op, inputs);
+            output = mod->insert_instruction(mod->end(), op, inputs);
         }
         else
         {
-            output = mod->add_instruction(op, inputs, module_inputs);
+            output = mod->insert_instruction(mod->end(), op, inputs, module_inputs);
         }
     }
     output->set_normalized(normalized);
@@ -681,11 +708,13 @@ void program::perf_report(std::ostream& os,
     double overhead_percent = overhead_time * 100.0 / total_time;
     double total_instruction_time = 0.0;
     std::unordered_map<std::string, double> op_times;
+    std::unordered_map<std::string, std::size_t> op_n;
     for(auto&& p : ins_vec)
     {
         double avg = common_average(p.second);
         op_times[perf_group(p.first->get_operator())] += avg;
         total_instruction_time += avg;
+        op_n[perf_group(p.first->get_operator())]++;
     }
     double calculate_overhead_time    = total_time - total_instruction_time;
     double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
@@ -706,18 +735,19 @@ void program::perf_report(std::ostream& os,
     os << std::endl;

     os << "Summary:" << std::endl;
-    std::vector<std::pair<double, std::string>> op_times_sorted;
-    std::transform(op_times.begin(),
-                   op_times.end(),
-                   std::back_inserter(op_times_sorted),
-                   [](auto p) { return std::make_pair(p.second, p.first); });
+    std::vector<std::tuple<double, std::size_t, std::string>> op_times_sorted;
+    std::transform(
+        op_times.begin(), op_times.end(), std::back_inserter(op_times_sorted), [&](auto p) {
+            auto&& name = p.first;
+            return std::make_tuple(p.second, op_n.at(name), name);
+        });
     std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{});
-    for(auto&& p : op_times_sorted)
+    for(auto&& [avg, nn, name] : op_times_sorted)
     {
-        auto&& name    = p.second;
-        double avg     = p.first;
         double percent = std::ceil(100.0 * avg / total_instruction_time);
-        os << name << ": " << avg << "ms, " << percent << "%" << std::endl;
+        double per_ins = avg / nn;
+        os << name << ": " << avg << "ms / " << nn << " = " << per_ins << "ms, " << percent << "%"
+           << std::endl;
     }
     os << std::endl;
...
@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
     result["shape"] = migraphx::to_value(rd.get_shape());
     if(rd.get_shape().type() == shape::tuple_type)
         result["sub"] = migraphx::to_value(rd.get_sub_objects());
-    else
+    else if(not rd.empty())
         result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
     v = result;
 }
@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
         literal l = migraphx::from_value<literal>(v);
         a         = l.get_argument();
     }
-    else
+    else if(v.contains("sub"))
     {
         a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
     }
...