Unverified Commit b98308b8 authored by Charlie Lin, committed by GitHub

Merge branch 'develop' into dyn_onnx_matmul

parents b48c4cf6 56c43445
@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@0f38fb33f518b53b94b541feb9b079668c5518e8 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return a;
}
else if(a == one_dyn_dim or b == one_dyn_dim)
else if(a == 1 or b == 1)
{
// Setting opt to 0; this may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
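A self-contained sketch of the broadcast rule this hunk implements, with a stand-in dyn_dim struct in place of the real shape::dynamic_dimension (all names here are illustrative):

#include <algorithm>
#include <cstddef>
#include <stdexcept>

struct dyn_dim
{
    std::size_t min, max, opt;
};

bool operator==(const dyn_dim& a, const dyn_dim& b)
{
    return a.min == b.min and a.max == b.max and a.opt == b.opt;
}

// Mirrors the new size_t comparison: equal when the range collapses to n
bool operator==(const dyn_dim& a, std::size_t n) { return a.min == n and a.max == n; }

// Equal dims pass through unchanged; a fixed 1 broadcasts to the other
// dim's range, with opt reset to 0 as in the hunk above.
dyn_dim broadcast(const dyn_dim& a, const dyn_dim& b)
{
    if(a == b)
        return a;
    if(a == std::size_t{1} or b == std::size_t{1})
        return {std::max(a.min, b.min), std::max(a.max, b.max), 0};
    // (the real code's mismatch handling is outside this hunk)
    throw std::runtime_error("dimensions are not broadcastable");
}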
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instructions with an empty output shape unless they are [dynamic, builtin, undefined,
// identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name()))
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined())
continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
......
@@ -109,8 +109,12 @@ struct loader
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type,
{"--cpp"},
ap.help("Print out the program as cpp program."),
ap.help("Print out the program as C++ program."),
ap.set_value("cpp"));
ap(output_type,
{"--python", "--py"},
ap.help("Print out the program as python program."),
ap.set_value("py"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type,
{"--text"},
@@ -259,7 +263,9 @@ struct loader
type = "binary";
}
if(type == "cpp")
if(type == "py")
p.print_py(*os);
else if(type == "cpp")
p.print_cpp(*os);
else if(type == "graphviz")
p.print_graph(*os, brief);
......
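With the new flag, an invocation along the lines of "migraphx-driver read model.onnx --py" (illustrative; the exact driver subcommand may differ) prints the loaded program as Python API calls, following the same dispatch path as the existing --cpp output.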
@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
T generic_read_file(const std::string& filename)
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg();
if(size < 1)
if(nbytes == 0)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg);
is.seekg(offset, std::ios::beg);
T buffer(size, 0);
if(not is.read(&buffer[0], size))
T buffer(nbytes, 0);
if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename);
return buffer;
}
std::vector<char> read_buffer(const std::string& filename)
std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{
return generic_read_file<std::vector<char>>(filename);
return generic_read_file<std::vector<char>>(filename, offset, nbytes);
}
std::string read_string(const std::string& filename)
......
@@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename);
std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
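A hypothetical usage sketch of the extended overload (the file name and the file_buffer.hpp header location are assumptions):

#include <migraphx/file_buffer.hpp>

// Read 256 bytes starting at byte 64...
auto chunk = migraphx::read_buffer("weights.bin", 64, 256);
// ...or everything from the offset to the end; nbytes defaults to 0 and is
// then computed from the remaining file size (throws if offset > size).
auto rest = migraphx::read_buffer("weights.bin", 64);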
@@ -121,6 +121,8 @@ struct instruction
bool can_eval() const;
bool is_undefined() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill(start, end);
}
// Directly copies the buffer of x
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std::shared_ptr<char> buffer;
shape m_shape;
// Keeps the same data ordering as the given container
template <class Iterator>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
std::copy(start, end, output.begin());
});
}
};
......
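A minimal sketch of what the simplified fill() handles, assuming the iterator constructor shown earlier in this file; standard and strided shapes now share the single view-based copy path:

std::vector<float> data = {1, 2, 3, 4, 5, 6};
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::literal l{s, data.begin(), data.end()}; // fill() copies through the view's iterator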
@@ -205,6 +205,12 @@ struct module
void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const;
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os,
......
@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -56,12 +57,20 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
return {shape::int64_type, lens};
check_shapes{inputs, *this, true}.has(1);
const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1;
return {shape::int64_type, lens};
}
}
template <class T>
@@ -79,19 +88,18 @@ struct argmax
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
......
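For example, with axis = 2 and a dynamic input of dimensions {{1,4,0}, {3,3,0}, {5,8,0}}, the new branch returns {shape::int64_type, {{1,4,0}, {3,3,0}, {1,1,0}}}: the reduced axis collapses to a fixed 1 while the other dynamic dimensions pass through. The static branch is unchanged.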
@@ -55,17 +55,47 @@ struct flatten
std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto&& lens = inputs.front().lens();
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y =
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
check_shapes{inputs, *this, true}.has(1);
auto s = inputs[0];
if(s.dynamic())
{
auto min_lens = s.min_lens();
auto max_lens = s.max_lens();
auto opt_lens = s.opt_lens();
// If any of the opt values is 0, the output opt will be 0
shape::dynamic_dimension x = {
std::accumulate(
min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(opt_lens.begin(),
opt_lens.begin() + axis,
std::size_t{1},
std::multiplies<>{})};
shape::dynamic_dimension y = {
std::accumulate(
min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
};
return {s.type(), {x, y}};
}
else
{
check_shapes{inputs, *this}.standard();
auto&& lens = s.lens();
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {s.type(), {x, y}};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
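Worked example for the dynamic branch: flattening dimensions {{1,4,2}, {3,3,3}, {5,8,6}} at axis = 1 gives x = {1,4,2} (the product of everything before the axis) and y = {3*5, 3*8, 3*6} = {15,24,18} (the product from the axis on), so the output shape is {type, {x, y}}. A 0 in any opt value propagates to a 0 opt product, as the comment above notes.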
@@ -31,7 +31,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
#include <migraphx/dyn_output.hpp>
#include <cmath>
#include <utility>
@@ -49,6 +49,9 @@ struct pooling
bool ceil_mode = false;
int lp_order = 2;
// Global pooling with dynamic shape input
bool dyn_global = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
@@ -57,7 +60,8 @@ struct pooling
f(self.stride, "stride"),
f(self.lengths, "lengths"),
f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"));
f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global"));
}
std::string name() const { return "pooling"; }
@@ -65,51 +69,111 @@ struct pooling
void check_attribute_size() const
{
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != lengths.size())
(not dyn_global and stride.size() != lengths.size()))
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
}
size_t kdims() const
{
check_attribute_size();
return stride.size();
}
value attributes() const { return {{"normalize_padding", "padding"}}; }
std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const
{
std::vector<std::size_t> output_lens{};
for(size_t i = 0; i < kdims; ++i)
{
if(input_lens[i + 2] == 0)
{
// handle opt = 0
output_lens.push_back(0);
}
else
{
std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
assert(input_lens[i + 2] + padding_factor >= lengths[i]);
std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
std::size_t len =
(ceil_mode)
? dim_size / stride[i] + static_cast<std::size_t>((dim_size % stride[i] !=
0)) // ceil uint divide
: dim_size / stride[i]; // floor divide
output_lens.push_back(len + 1);
}
}
return output_lens;
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
check_attribute_size();
const shape& input = inputs.at(0);
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(input_size != padding_size / 2 + 2 and input_size != padding_size + 2)
auto padding_size = padding.size();
size_t kdims = input.ndim() - 2;
if(input.ndim() != padding_size / 2 + 2 and input.ndim() != padding_size + 2)
{
MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
}
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
for(size_t i = 0; i < kdims; i++)
if(input.dynamic())
{
std::ptrdiff_t dim_size;
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
assert(dim_size >= 0);
std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
: floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
auto input_dyn_dims = input.dyn_dims();
std::vector<shape::dynamic_dimension> output_dyn_dims(input_dyn_dims.begin(),
input_dyn_dims.begin() + 2);
if(dyn_global)
{
for(size_t i = 0; i < kdims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1});
}
return {input.type(), output_dyn_dims};
}
else
{
auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
for(size_t i = 0; i < kdims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return {input.type(), output_dyn_dims};
}
}
return inputs[0].with_lens(output_lens);
}
else
{
auto input_lens = input.lens();
size_t kdims() const
{
check_attribute_size();
return stride.size();
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
// Used when normalize_compute_shape() is called again at model eval time
// for an originally dynamic shape; the kernel shape is not used with dyn_global.
if(dyn_global)
{
for(size_t i = 0; i < kdims; ++i)
{
output_lens.push_back(1);
}
return {input.type(), output_lens};
}
else
{
auto output_spatial_lens = calc_spatial_dim_out(input_lens, kdims);
output_lens.insert(
output_lens.end(), output_spatial_lens.begin(), output_spatial_lens.end());
return inputs[0].with_lens(output_lens);
}
}
}
struct lpnorm_pool
@@ -158,7 +222,11 @@ struct pooling
};
template <class Type, class Out, class In, class Op>
void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const
void calc_pooling(const shape& output_shape,
Out& output,
const In& input,
const std::vector<std::size_t>& kernel_dims,
Op op) const
{
auto in_s = input.get_shape();
auto in_lens = in_s.lens();
@@ -172,7 +240,7 @@ struct pooling
auto d_2 = dim - 2;
int start =
static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
int end = std::min(start + lengths[d_2], in_lens[dim]);
int end = std::min(start + kernel_dims[d_2], in_lens[dim]);
start = std::max(start, 0);
win_start.push_back(start);
win_size.push_back(end - start);
@@ -198,21 +266,32 @@ struct pooling
});
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto input_lens = args[0].get_shape().lens();
std::vector<std::size_t> kernel_dims;
if(dyn_global)
{
kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
}
else
{
kernel_dims = this->lengths;
}
visit_all(result, args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
switch(mode)
{
case migraphx::op::pooling_mode::average:
calc_pooling<type>(output_shape, output, input, avg_pool{});
calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
break;
case migraphx::op::pooling_mode::max:
calc_pooling<type>(output_shape, output, input, max_pool{});
calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
break;
case migraphx::op::pooling_mode::lpnorm:
calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order});
calc_pooling<type>(
dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
break;
}
});
......
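Worked example for calc_spatial_dim_out: with input spatial length 32, padding 1 on each side, kernel length 3, and stride 2, dim_size = 32 + 2 - 3 = 31, giving 31/2 + 1 = 16 with floor division and 17 with ceil_mode (since 31 % 2 != 0). For dynamic shapes the same helper is applied separately to min_lens, max_lens, and opt_lens; with dyn_global the spatial outputs are pinned to 1 and compute() takes the kernel dims from the actual input rather than from lengths.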
@@ -53,15 +53,15 @@ struct softmax
std::string name() const { return "softmax"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.at(0).packed())
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs[0];
if(s0.dynamic() or s0.packed())
{
return inputs.at(0);
return s0;
}
else
{
auto lens = inputs.at(0).lens();
return {inputs.at(0).type(), lens};
return {s0.type(), s0.lens()};
}
}
......
@@ -59,9 +59,8 @@ struct squeeze
auto input_shape = inputs[0];
if(input_shape.dynamic())
{
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != one_dyn_dim;
return input_shape.dyn_dims()[axis] != 1;
}))
{
MIGRAPHX_THROW(
@@ -70,14 +69,10 @@ struct squeeze
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
for(auto i : range(input_shape.ndim()))
{
auto dd = input_shape.dyn_dims()[i];
if(dd != one_dyn_dim)
{
dyn_dims.push_back(dd);
}
}
std::copy_if(input_shape.dyn_dims().cbegin(),
input_shape.dyn_dims().cend(),
std::back_inserter(dyn_dims),
[&](auto dd) { return dd != 1; });
}
else
{
......
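For example, squeezing dimensions {{1,1,0}, {3,5,4}, {1,1,1}} with empty axes keeps only {3,5,4}. Assuming the new size_t comparison ignores opt, dd != 1 also drops {1,1,1}, which the old exact comparison against {1,1,0} would have kept.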
@@ -115,6 +115,7 @@ struct program
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const;
......
@@ -101,6 +101,12 @@ struct shape
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
// Comparisons against a fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
};
static const std::vector<type_t>& types();
......
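A presumed sketch of the new comparison's semantics, not the actual implementation: a dynamic_dimension equals a fixed size when its range collapses to that value.

bool equals_fixed(const migraphx::shape::dynamic_dimension& dd, std::size_t n)
{
    return dd.min == n and dd.max == n; // opt is intentionally not considered
}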
@@ -31,6 +31,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Invokes the given function on each multi-index of the shape, in order.
*/
template <class F>
void shape_for_each(const migraphx::shape& s, F f)
{
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
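A minimal usage sketch for shape_for_each:

migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape_for_each(s, [](const auto& idx) {
    // idx is the multi-index; visits {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} in order
});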
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{
return;
}
auto kdims = input->get_shape().lens().size() - 2;
auto kdims = input->get_shape().ndim() - 2;
if(std::equal(op.padding.begin(),
op.padding.begin() + kdims,
op.padding.begin() + kdims,
op.padding.end()))
return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
op.padding = std::vector<size_t>(kdims * 2, 0);
......
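The std::equal call here is a symmetry check: it compares the first kdims padding values against the last kdims and returns early when they match, since symmetric padding needs no separate pad instruction.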
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std::replace(module_args.begin(), module_args.end(), old, new_mod);
}
bool instruction::is_undefined() const
{
if(op.name() == "undefined")
{
return true;
}
else if(this->inputs().empty())
{
return false;
}
else
{
return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
return arg->is_undefined();
});
}
}
bool instruction::can_eval() const
{
if(op.name() == "@literal")
......
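In other words, an instruction is undefined when its operator is literally "undefined" or, recursively, when it has inputs and every one of them is undefined; any other leaf instruction is defined. The dead_code_elimination change above relies on this, making chains that only feed undefined values eligible for removal rather than just instructions named "undefined".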
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return to_c_id("x_" + replace_string(name, ":", "_module_"));
}
static void print_py_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
os << "migraphx.op(" << enclose_name(op.name());
auto default_values = make_op(op.name()).to_value();
for(auto&& x : v)
{
auto name = x.get_key();
if(default_values[name] == x)
continue;
os << ", " << name << "=" << to_json_string(x.without_key());
}
os << ")";
}
static void print_make_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os << ")";
}
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx::shape{migraphx::shape::" << s.type_string();
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}";
}
std::unordered_map<instruction_ref, std::string>
module::print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ")" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << mname << ".add_parameter(" << enclose_name(name) << ",";
print_py_shape(os, ins->get_shape());
os << ")" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
<< std::endl;
}
else
{
assert(ins->name().front() != '@');
os << mname << ".add_instruction(";
print_py_op(os, ins->get_operator());
os << ", [" << join_strings(input_vars, ", ") << "]";
os << ")" << std::endl;
}
},
names);
return names;
}
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os,
const std::string& mname,
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return names;
}
void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
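As an illustration (approximated, not captured from a real run), printing a module with parameters x and y and a single add emits lines like x_x = main.add_parameter("x", migraphx.shape(float_type, lens=[2, 3])) and main.add_instruction(migraphx.op("add"), [x_x, x_y]); literals are not serialized inline but regenerated from an incrementing seed via migraphx.generate_literal.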