Commit f9a06df3 authored by Cagri Eryilmaz
Browse files

Merge branch 'develop' into unet

parents 07189c21 0b04fc80
...@@ -39,9 +39,7 @@ struct reverse ...@@ -39,9 +39,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto lens = inputs[0].lens(); return inputs[0].with_lens(inputs[0].lens());
auto type = inputs[0].type();
return shape{type, lens};
} }
argument compute(const shape& s, std::vector<argument> args) const argument compute(const shape& s, std::vector<argument> args) const
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -39,7 +40,7 @@ struct scalar ...@@ -39,7 +40,7 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -77,7 +78,7 @@ struct squeeze ...@@ -77,7 +78,7 @@ struct squeeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,8 +27,15 @@ struct step ...@@ -25,8 +27,15 @@ struct step
return pack(f(self.axes, "axes"), f(self.steps, "steps")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
// Returns this operator's attribute map. It contains a single
// "normalize_axes" entry whose "axes" field is an array holding
// normalize_attribute::include_min.
// NOTE(review): presumably consumed by the axis-normalization machinery that
// calls normalize_compute_shape -- confirm against normalize_attribute.hpp.
value attributes() const
{
    value axes_rules;
    axes_rules["axes"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axes_rules}};
}
std::string name() const { return "step"; } std::string name() const { return "step"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
...@@ -63,7 +72,7 @@ struct step ...@@ -63,7 +72,7 @@ struct step
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -63,7 +64,7 @@ struct transpose ...@@ -63,7 +64,7 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -70,7 +71,7 @@ struct unsqueeze ...@@ -70,7 +71,7 @@ struct unsqueeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <migraphx/module_ref.hpp> #include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -435,9 +436,9 @@ void from_value_op(T& x, const value& v) ...@@ -435,9 +436,9 @@ void from_value_op(T& x, const value& v)
} }
template <class T> template <class T>
bool is_borrowed_op(const T&) lifetime get_lifetime_op(const T&)
{ {
return false; return lifetime::local;
} }
} // namespace detail } // namespace detail
...@@ -451,7 +452,7 @@ bool is_borrowed_op(const T&) ...@@ -451,7 +452,7 @@ bool is_borrowed_op(const T&)
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const; * bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* bool is_borrowed() const; * lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; * value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
...@@ -559,10 +560,10 @@ struct operation ...@@ -559,10 +560,10 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
bool is_borrowed() const lifetime get_lifetime() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed(); return (*this).private_detail_te_get_handle().get_lifetime();
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
...@@ -678,7 +679,7 @@ struct operation ...@@ -678,7 +679,7 @@ struct operation
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0; virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0; virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0; compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
...@@ -750,16 +751,16 @@ struct operation ...@@ -750,16 +751,16 @@ struct operation
} }
template <class T> template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self) static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed()) -> decltype(private_detail_te_self.get_lifetime())
{ {
return private_detail_te_self.is_borrowed(); return private_detail_te_self.get_lifetime();
} }
template <class T> template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self) static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{ {
return detail::is_borrowed_op(private_detail_te_self); return detail::get_lifetime_op(private_detail_te_self);
} }
template <class T> template <class T>
...@@ -1044,10 +1045,10 @@ struct operation ...@@ -1044,10 +1045,10 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value); return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
} }
bool is_borrowed() const override lifetime get_lifetime() const override
{ {
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value); return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
......
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP #define MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/allocation_model.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu { struct module;
struct preallocate_param struct preallocate_param
{ {
std::string param{}; std::string param;
context* ctx = nullptr; allocation_model model;
std::string name() const { return "preallocate_param"; } std::string name() const { return "preallocate_param"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP
#endif
...@@ -440,17 +440,19 @@ bool is_borrowed(instruction_ref ins) ...@@ -440,17 +440,19 @@ bool is_borrowed(instruction_ref ins)
auto alias = instruction::get_output_alias(ins, true); auto alias = instruction::get_output_alias(ins, true);
if(alias == ins) if(alias == ins)
return false; return false;
if(alias->get_operator().is_borrowed()) lifetime l = alias->get_operator().get_lifetime();
if(l == lifetime::borrow)
return true; return true;
return is_borrowed(alias); return is_borrowed(alias);
} }
bool is_param_alias(instruction_ref ins) bool is_global(instruction_ref ins)
{ {
return instruction::get_output_alias(ins)->name() == "@param"; const auto& op = instruction::get_output_alias(ins)->get_operator();
return op.name() == "@param" or op.get_lifetime() == lifetime::global;
} }
bool is_dangling(instruction_ref ins) { return not is_param_alias(ins) and is_borrowed(ins); } bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); }
instruction_ref module::find_dangling_reference() const instruction_ref module::find_dangling_reference() const
{ {
......
...@@ -84,9 +84,6 @@ struct onnx_parser ...@@ -84,9 +84,6 @@ struct onnx_parser
shape::type_t get_type(int dtype); shape::type_t get_type(int dtype);
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
} // namespace onnx } // namespace onnx
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp> #include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
...@@ -91,66 +92,11 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r ...@@ -91,66 +92,11 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
return curr_ins; return curr_ins;
} }
// Computes the broadcast output lengths of two dimension vectors, aligning
// them on their trailing dimensions.
//
// The shorter vector is matched against the tail of the longer one; leading
// dimensions of the longer vector are copied through unchanged. For each
// aligned pair the larger length is taken. A pair is rejected (via
// MIGRAPHX_THROW) when the two lengths differ and neither of them is 1.
//
// Examples:
//   s0 = (3,2,4,5), s1 = (2,1,1)  ->  (3,2,4,5)
//   s0 = (3,2,1,5), s1 = (2,7,5)  ->  (3,2,7,5)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    // From here on s0 is the lower-rank (or equal-rank) operand.
    if(s0.size() > s1.size())
        s0.swap(s1);

    // Start from the higher-rank operand; only the trailing part is updated.
    std::vector<std::size_t> result(s1);
    const auto shift = s1.size() - s0.size();
    for(std::size_t i = 0; i < s0.size(); i++)
    {
        const auto lhs = s0[i];
        const auto rhs = s1[i + shift];
        const bool compatible = (lhs == rhs) or (lhs == 1) or (rhs == 1);
        if(not compatible)
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
                           to_string_range(s1) + "} mismatch!");
        }
        result[i + shift] = std::max(lhs, rhs);
    }
    return result;
}
instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name, instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const instruction_ref arg1) const
{ {
if(arg0->get_shape().lens() != arg1->get_shape().lens()) return add_common_op(*mod, make_op(op_name), {arg0, arg1});
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
return add_instruction(make_op(op_name), l0, l1);
}
else
{
return add_instruction(make_op(op_name), {arg0, arg1});
}
} }
instruction_ref instruction_ref
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
namespace migraphx { namespace migraphx {
......
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -29,10 +29,6 @@ struct parse_slice : op_parser<parse_slice> ...@@ -29,10 +29,6 @@ struct parse_slice : op_parser<parse_slice>
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice"); check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); }); step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return abs(s) == 1; }))
{
MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1 or -1");
}
} }
if(args.size() >= 4) if(args.size() >= 4)
...@@ -98,7 +94,16 @@ struct parse_slice : op_parser<parse_slice> ...@@ -98,7 +94,16 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]); auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty()) if(not raxes.empty())
return info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
return std::abs(s);
});
return ins = info.add_instruction(
make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
}
else else
return ins; return ins;
} }
......
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
namespace migraphx { namespace migraphx {
......
#include <migraphx/gpu/preallocate_param.hpp> #include <migraphx/preallocate_param.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void preallocate_param::apply(module& p) const void preallocate_param::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "@param") if(ins->name() != "@param")
continue; continue;
if(param != any_cast<builtin::param>(ins->get_operator()).parameter) if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
continue; continue;
std::string id = p.name() + ":" + param; std::string id = m.name() + ":" + param;
auto r = p.insert_instruction(ins, hip_allocate_memory{ins->get_shape(), id}); auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
p.replace_instruction(ins, r); m.replace_instruction(ins, r);
} }
} }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -19,6 +19,7 @@ add_library(migraphx_cpu ...@@ -19,6 +19,7 @@ add_library(migraphx_cpu
logsoftmax.cpp logsoftmax.cpp
lowering.cpp lowering.cpp
lrn.cpp lrn.cpp
preallocate.cpp
pooling.cpp pooling.cpp
reduction.cpp reduction.cpp
reorder.cpp reorder.cpp
......
...@@ -11,6 +11,11 @@ operation cpu_allocation_model::allocate(const shape& s) const ...@@ -11,6 +11,11 @@ operation cpu_allocation_model::allocate(const shape& s) const
return make_op(name(), {{"shape", to_value(s)}}); return make_op(name(), {{"shape", to_value(s)}});
} }
// Creates the "cpu::preallocate" operation via make_op, passing the shape `s`
// (converted with to_value) as the "shape" attribute and `id` as the "id"
// attribute.
operation cpu_allocation_model::preallocate(const shape& s, const std::string& id) const
{
return make_op("cpu::preallocate", {{"shape", to_value(s)}, {"id", id}});
}
std::string cpu_allocation_model::copy() const { return "cpu::copy"; } std::string cpu_allocation_model::copy() const { return "cpu::copy"; }
} // namespace cpu } // namespace cpu
......
...@@ -14,6 +14,7 @@ struct cpu_allocation_model ...@@ -14,6 +14,7 @@ struct cpu_allocation_model
std::string name() const; std::string name() const;
std::string copy() const; std::string copy() const;
operation allocate(const shape& s) const; operation allocate(const shape& s) const;
operation preallocate(const shape& s, const std::string& id) const;
}; };
} // namespace cpu } // namespace cpu
......
...@@ -440,7 +440,7 @@ struct cpu_apply ...@@ -440,7 +440,7 @@ struct cpu_apply
} }
} }
instruction_ref apply_pow(instruction_ref ins) instruction_ref apply_pow(instruction_ref ins) const
{ {
auto beta = read_scalar<float>(ins->inputs()[1]); auto beta = read_scalar<float>(ins->inputs()[1]);
if(beta.empty()) if(beta.empty())
...@@ -451,7 +451,7 @@ struct cpu_apply ...@@ -451,7 +451,7 @@ struct cpu_apply
{ins->inputs().front()}); {ins->inputs().front()});
} }
instruction_ref apply_pooling(instruction_ref ins) instruction_ref apply_pooling(instruction_ref ins) const
{ {
auto&& op = ins->get_operator(); auto&& op = ins->get_operator();
auto v = op.to_value(); auto v = op.to_value();
...@@ -479,30 +479,20 @@ struct cpu_apply ...@@ -479,30 +479,20 @@ struct cpu_apply
return {r.at<T>()}; return {r.at<T>()};
} }
instruction_ref replace(instruction_ref ins, const operation& op) instruction_ref replace(instruction_ref ins, const operation& op) const
{ {
return replace(ins, op, ins->inputs()); return replace(ins, op, ins->inputs());
} }
instruction_ref instruction_ref
replace(instruction_ref ins, const operation& op, std::vector<instruction_ref> inputs) replace(instruction_ref ins, const operation& op, std::vector<instruction_ref> inputs) const
{ {
inputs.push_back(insert_allocation(ins, ins->get_shape())); inputs.push_back(insert_allocation(ins, ins->get_shape()));
return modl->replace_instruction(ins, op, inputs); return modl->replace_instruction(ins, op, inputs);
} }
instruction_ref insert_allocation(instruction_ref ins, const shape& s) instruction_ref insert_allocation(instruction_ref ins, const shape& s) const
{ {
auto ins_alias = instruction::get_output_alias(ins);
if(last->name() == "@return" and prog_output_names.count(ins_alias) > 0)
{
return modl->add_parameter(prog_output_names[ins_alias], s);
}
else if(ins == last)
{
return modl->add_parameter("output", s);
}
return modl->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}})); return modl->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}}));
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment