Commit 5dfeb457 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into squeeze_unsqeeze

parents 523a78c7 f9f4f713
sphinx
sphinx==1.6.2
breathe==4.9.1
# git+https://github.com/arximboldi/breathe@fix-node-parent
......@@ -7,6 +7,7 @@ add_library(migraph
fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
instruction.cpp
program.cpp
shape.cpp
simplify_reshapes.cpp
......
......@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
shape s = ins->result;
shape s = ins->get_shape();
if(not s.standard())
{
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
......
......@@ -17,16 +17,16 @@ void dead_code_elimination::apply(program& p) const
continue;
const auto i = std::prev(ins);
// Skip instruction with empty shape as output unless its a builtin
if(i->result.elements() == 0 and not(i->op.name().front() == '@'))
if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
continue;
// Skip the last instruction
if(i == last)
break;
fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf));
if(leaf->output.empty())
if(leaf->outputs().empty())
{
auto args = leaf->arguments;
auto args = leaf->inputs();
leaf->clear_arguments();
p.move_instruction(leaf, p.end());
for(auto arg : args)
......
......@@ -14,7 +14,7 @@ void eliminate_allocation::apply(program& p) const
std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p))
{
if(ins->op.name() != allocation_op)
if(ins->name() != allocation_op)
continue;
allocs.emplace_back(ins, n);
std::size_t size = ins->get_shape().bytes();
......
......@@ -27,19 +27,19 @@ void eliminate_contiguous::apply(program& p) const
for(auto ins : iterator_for(p))
{
// Make a copy so we can modify it while we iterate
auto args = ins->arguments;
for(auto arg : ins->arguments)
auto args = ins->inputs();
for(auto arg : ins->inputs())
{
// TODO: Pass in names for the operator in the constructor instead
// of using ends_with
if(ends_with(arg->op.name(), "contiguous"))
if(ends_with(arg->name(), "contiguous"))
{
auto new_args = args;
auto prev = arg->arguments.front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins->op, new_args))
if(try_compute_shape(ins->get_operator(), new_args))
{
replace_argument(ins, arg, prev);
instruction::replace_argument(ins, arg, prev);
}
}
}
......
......@@ -10,30 +10,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->op.name() != "batch_norm_inference")
if(ins->name() != "batch_norm_inference")
continue;
if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
return arg->op.name() == "@literal";
if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
continue;
auto conv_ins = ins->arguments[0];
if(conv_ins->op.name() != "convolution")
auto conv_ins = ins->inputs()[0];
if(conv_ins->name() != "convolution")
continue;
if(conv_ins->arguments[1]->op.name() != "@literal")
if(conv_ins->inputs()[1]->name() != "@literal")
continue;
// Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->arguments[1]->get_literal();
const auto& bias = ins->arguments[2]->get_literal();
const auto& mean = ins->arguments[3]->get_literal();
const auto& variance = ins->arguments[4]->get_literal();
const auto& gamma = ins->inputs()[1]->get_literal();
const auto& bias = ins->inputs()[2]->get_literal();
const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon
auto bn_op = any_cast<batch_norm_inference>(ins->op);
auto bn_op = any_cast<batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->arguments[1]->get_literal();
const auto& weights = conv_ins->inputs()[1]->get_literal();
// Get convolution op
auto conv_op = conv_ins->op;
auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
......@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b});
}
......
......@@ -12,7 +12,7 @@ constexpr T normalize(unsigned long z)
{
if(z == 0)
return 0;
const auto max = 2048;
const auto max = 32;
const double range = max / 2; // NOLINT
double result = (z % max) / range;
result -= 1;
......
......@@ -3,10 +3,8 @@
#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
......@@ -18,126 +16,60 @@ struct instruction
{
instruction() {}
instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
instruction(operation o, shape r, std::vector<instruction_ref> args);
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
instruction(literal l);
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
op = std::move(o);
replace(r);
replace(std::move(args));
}
void replace(const shape& r);
void replace(const shape& r)
{
if(r != result)
{
result = r;
for(auto&& ins : output)
{
assert(ins->op.name().front() != '@');
ins->recompute_shape();
}
}
}
void recompute_shape();
void recompute_shape() { replace(compute_shape(op, arguments)); }
void clear_arguments();
// internal
void replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
friend bool operator==(const instruction& i, instruction_ref ref);
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
bool valid(instruction_ref start) const;
void clear_arguments()
{
for(auto&& arg : arguments)
{
arg->remove_output(*this);
}
arguments.clear();
}
bool valid() const;
friend bool operator==(const instruction& i, instruction_ref ref)
{
return std::addressof(i) == std::addressof(*ref);
}
shape get_shape() const;
const literal& get_literal() const;
bool valid(instruction_ref start) const
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->output.begin(), i->output.end(), *this);
return self != i->output.end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
const operation& get_operator() const;
bool valid() const
{
shape computed;
if(op.name() == "@literal")
{
computed = lit.get_shape();
}
else if(op.name() == "@param")
{
computed = result;
}
else
{
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
i->arguments.end();
});
}
std::string name() const;
shape get_shape() const { return result; }
const literal& get_literal() const
{
assert(op.name() == "@literal");
return lit;
}
const std::vector<instruction_ref>& inputs() const;
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
const std::vector<instruction_ref>& outputs() const;
friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
friend bool operator==(instruction_ref ref, const instruction& i);
friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
friend bool operator!=(const instruction& i, instruction_ref ref);
void add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins);
}
friend bool operator!=(instruction_ref ref, const instruction& i);
void add_output(instruction_ref ins);
template <class T>
void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
void remove_output(const T& ins);
static void backreference(instruction_ref ref);
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
// internal
void replace(std::vector<instruction_ref> args);
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins);
operation op;
shape result;
......@@ -146,29 +78,6 @@ struct instruction
literal lit;
};
inline void backreference(instruction_ref ref)
{
for(auto&& arg : ref->arguments)
arg->add_output(ref);
}
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
// TODO: Move to a cpp file
// TODO: Use const ref for vector
inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
return op.compute_shape(shapes);
}
} // namespace migraph
namespace std {
......
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
......@@ -27,13 +28,16 @@ struct operation
/// exception.
shape compute_shape(const std::vector<shape>& input) const;
/**
* @brief This performs the operation's computation
* @brief This performs the operation's computation.
*
* This method can be optional when the operation is only used as a placeholder to be lowered
* later on.
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* @param input This is the `argument` result from the previous instuction's computation.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
*/
......@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
// Preferred overload (rank<1>): selected via SFINAE only when T provides a
// usable compute(context, shape, vector<argument>) member; forwards the call
// to it. `auto_any_cast` lets the op receive the target's concrete context
// type rather than the type-erased `context` interface.
template <class T>
auto compute_op(rank<1>,
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}
// Fallback overload (rank<0>): chosen when T has no matching compute() member.
// Such ops are placeholders to be lowered later, so evaluating one directly
// is an error.
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
    std::string name = x.name();
    MIGRAPH_THROW("Not computable: " + name);
}
template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
/*
......
......@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct convolution
......@@ -115,11 +110,6 @@ struct convolution
}
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{
os << op.name() << "[";
......@@ -169,11 +159,6 @@ struct im2col
auto channels_col = kernel_height * kernel_width * input_channels;
return {input.type(), {output_height * output_width, channels_col}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct pooling
......@@ -211,11 +196,6 @@ struct pooling
}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{
os << op.name() << "[";
......@@ -236,11 +216,6 @@ struct activation
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op)
{
os << op.name() << ":" << op.mode;
......@@ -305,10 +280,6 @@ struct contiguous
}
return {t, lens};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct slice
......@@ -497,12 +468,10 @@ struct reshape
MIGRAPH_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{
os << op.name() << "[";
......@@ -530,11 +499,6 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}};
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{
os << op.name() << "[";
......@@ -550,10 +514,6 @@ struct unary
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct identity : unary
......@@ -601,11 +561,6 @@ struct atan : unary
std::string name() const { return "atan"; }
};
struct softmax : unary
{
std::string name() const { return "softmax"; }
};
struct tanh : unary
{
std::string name() const { return "tanh"; }
......@@ -621,6 +576,16 @@ struct neg : unary
std::string name() const { return "neg"; }
};
// Softmax operator. Shape-inference only: it has no compute() member, so the
// rank-dispatched compute_op fallback throws if it is evaluated without being
// lowered first.
struct softmax
{
    std::string name() const { return "softmax"; }
    // Output shape equals the input shape; requires exactly one 4-dimensional
    // input (the onnx parser reshapes 2-d inputs to 4-d before applying this
    // op — see parse_softmax).
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1).only_dims(4);
        return inputs.at(0);
    }
};
struct flatten
{
uint64_t axis = 0;
......@@ -701,10 +666,6 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
};
struct add : binary
......
......@@ -3,19 +3,10 @@
#include <algorithm>
#include <initializer_list>
#include <migraph/rank.hpp>
namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
namespace detail {
template <class String, class T>
......
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP

namespace migraph {

// Overload-priority tag: rank<N> derives from rank<N-1>, so an argument of
// type rank<N> prefers an overload taking a higher-rank parameter and falls
// back down the chain (e.g. compute_op tries rank<1> before rank<0>).
template <int N>
struct rank : rank<N - 1>
{
};

// Base case terminating the inheritance chain.
template <>
struct rank<0>
{
};

} // namespace migraph

#endif
......@@ -6,14 +6,16 @@
namespace migraph {
inline void verify_args(const std::string& name,
inline bool verify_args(const std::string& name,
const argument& cpu_arg,
const argument& gpu_arg,
double tolerance = 80)
{
bool passed = true;
visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
double error;
if(not verify_range(cpu, gpu, tolerance, &error))
passed = verify_range(cpu, gpu, tolerance, &error);
if(not passed)
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
......@@ -27,6 +29,9 @@ inline void verify_args(const std::string& name,
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
auto mxdiff = max_diff(cpu, gpu);
std::cout << "Max diff: " << mxdiff << std::endl;
auto idx = mismatch_idx(cpu, gpu, float_equal);
if(idx < range_distance(cpu))
{
......@@ -45,7 +50,36 @@ inline void verify_args(const std::string& name,
<< gpu[gpu_nan_idx] << std::endl;
std::cout << std::endl;
}
else
{
if(range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
// auto mxdiff = max_diff(cpu, gpu);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(cpu, gpu, float_equal);
// if(idx < range_distance(cpu))
// {
// std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
// << std::endl;
// }
auto cpu_nan_idx = find_idx(cpu, not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
// std::cout << std::endl;
}
});
return passed;
}
} // namespace migraph
......
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
namespace migraph {
// Construct an instruction from an operation, its precomputed result shape,
// and the instructions producing its inputs.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}

// Construct a literal instruction: the builtin @literal op with the literal's
// own shape as the result and no input arguments.
instruction::instruction(literal l)
    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
// Set a new result shape. If the shape actually changed, propagate by
// recomputing the shape of every downstream (output) instruction, since
// their shapes were derived from this one.
void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        for(auto&& ins : output)
        {
            // Builtin ops (names starting with '@') cannot recompute a shape,
            // so they must never appear downstream of a shape change.
            assert(ins->name().front() != '@');
            ins->recompute_shape();
        }
    }
}
// Recompute this instruction's shape from its operator and current arguments,
// propagating to outputs via replace() when it changes.
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }

// Drop all input edges: unregister this instruction from each argument's
// output list, then clear the argument list itself.
void instruction::clear_arguments()
{
    for(auto&& arg : arguments)
    {
        arg->remove_output(*this);
    }
    arguments.clear();
}
// Two handles are equal when they refer to the same instruction object
// (identity comparison, not structural).
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* lhs = std::addressof(i);
    const instruction* rhs = std::addressof(*ref);
    return lhs == rhs;
}
// Validate this instruction relative to the start of the program: in addition
// to the checks in valid(), every argument must list this instruction among
// its outputs, and — measuring positions from `start` — each argument must
// come strictly before the output entry that refers back to it.
bool instruction::valid(instruction_ref start) const
{
    return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
        auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
        return self != i->outputs().end() &&
               std::distance(start, i) < std::distance(start, *self);
    });
}
// Structural validity check. The stored result shape must match the shape the
// op would compute (literals use the literal's own shape, params keep their
// recorded shape), and every output instruction must list this one among its
// inputs.
bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
    {
        // Literal instructions carry their shape in the literal itself.
        computed = lit.get_shape();
    }
    else if(op.name() == "@param")
    {
        // Parameters have no inputs; the recorded shape is authoritative.
        computed = result;
    }
    else
    {
        try
        {
            computed = compute_shape(op, arguments);
        }
        catch(migraph::exception&)
        {
            // Shape computation failing means the arguments no longer satisfy
            // the op's requirements.
            return false;
        }
    }
    return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
        return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
    });
}
// Result shape of this instruction.
shape instruction::get_shape() const { return result; }

// The literal payload; only meaningful for @literal instructions.
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return lit;
}

// The operation this instruction executes.
const operation& instruction::get_operator() const { return op; }

// Convenience accessor for the operation's name.
std::string instruction::name() const { return op.name(); }

// Instructions producing this instruction's inputs.
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }

// Instructions consuming this instruction's result.
const std::vector<instruction_ref>& instruction::outputs() const { return output; }

// Symmetric and negated forms of the identity comparison.
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
// Record `ins` as a consumer of this instruction's result, keeping the
// output list free of duplicates.
void instruction::add_output(instruction_ref ins)
{
    auto pos = std::find(output.begin(), output.end(), ins);
    if(pos == output.end())
        output.push_back(ins);
}
// Erase `ins` from the output list. Templated so it works with both
// instruction_ref and instruction& callers.
// NOTE(review): this template is defined in the .cpp file, so it is only
// instantiated for uses inside this translation unit — confirm no other TU
// needs an instantiation, or move the definition back to the header.
template <class T>
void instruction::remove_output(const T& ins)
{
    migraph::erase(output, ins);
}
// Register `ref` as an output on every instruction that feeds it, keeping the
// graph's input/output edges consistent.
void instruction::backreference(instruction_ref ref)
{
    const auto& producers = ref->inputs();
    std::for_each(producers.begin(), producers.end(), [&](instruction_ref producer) {
        producer->add_output(ref);
    });
}
// Replace `old` with `new_ins` in the input list of `ins`, then restore the
// back-edges and recompute the shape, which may propagate downstream.
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    ins->replace_argument(old, new_ins);
    // Re-register ins as an output of its (new) inputs.
    backreference(ins);
    ins->recompute_shape();
}
// Replace the operation, result shape, and arguments of `ins` in one step,
// then restore the back-edges from the new arguments to `ins`.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    ins->replace(std::move(o), r, std::move(args));
    backreference(ins);
}
// Internal: swap in a new op, shape, and argument list. The op is assigned
// first so that shape propagation triggered by replace(r) runs against the
// new operation.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    op = std::move(o);
    replace(r);
    replace(std::move(args));
}

// Internal: swap the argument list, detaching from the old arguments first so
// their output lists stay consistent.
void instruction::replace(std::vector<instruction_ref> args)
{
    clear_arguments();
    arguments = std::move(args);
}

// Internal: point every occurrence of `old` in the argument list at `new_ins`
// and drop the corresponding back-edge from `old`. The caller is responsible
// for adding the back-edge to `new_ins` (see the static replace_argument).
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    std::replace(arguments.begin(), arguments.end(), old, new_ins);
    old->remove_output(*this);
}
// Compute the output shape an operation would produce for the given argument
// instructions, by collecting their result shapes and asking the op.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes;
    shapes.reserve(args.size());
    for(auto&& arg : args)
        shapes.push_back(arg->get_shape());
    return op.compute_shape(shapes);
}
} // namespace migraph
......@@ -28,10 +28,6 @@ struct unknown
else
return input.front();
}
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
......@@ -58,6 +54,7 @@ struct onnx_parser
add_generic_op("Mul", mul{});
add_generic_op("Relu", activation{"relu"});
add_generic_op("Sub", sub{});
add_generic_op("Sum", add{});
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
......@@ -67,6 +64,7 @@ struct onnx_parser
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
}
template <class F>
......@@ -103,6 +101,15 @@ struct onnx_parser
});
}
// Parse an onnx Softmax node. The migraph softmax op only accepts 4-d input
// (its compute_shape requires only_dims(4)), so the [N, C] tensor is reshaped
// to [N, C, 1, 1], run through softmax, and reshaped back to [N, C].
// NOTE(review): assumes the input is 2-d — confirm higher-rank inputs cannot
// reach here, as any trailing dimensions would be dropped by the reshape.
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    auto dims = args.front()->get_shape().lens();
    auto r = prog.add_instruction(reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
    auto s = prog.add_instruction(softmax{}, r);
    return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -160,7 +167,7 @@ struct onnx_parser
}
if(args.size() == 2)
{
literal s = args[1]->lit;
literal s = args[1]->get_literal();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -344,11 +351,10 @@ struct onnx_parser
if(node.name().empty())
{
std::string generated = "migraph_unnamed_node";
for(auto&& output : node.output())
{
generated += "_" + output;
}
return generated;
return std::accumulate(node.output().begin(),
node.output().end(),
generated,
[](auto x, auto y) { return x + "_" + y; });
}
return node.name();
}
......@@ -481,11 +487,11 @@ struct onnx_parser
break; // throw std::runtime_error("Unsupported type COMPLEX128");
}
std::vector<std::size_t> dims;
// TODO: USe std::transform
for(auto&& d : t.tensor_type().shape().dim())
{
dims.push_back(d.dim_value());
}
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) { return d.dim_value(); });
return {shape_type, dims};
}
};
......
......@@ -51,48 +51,76 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
auto x = run_cpu(f);
auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const migraph::program& prog, double tolerance = 80)
{
for(auto&& ins : prog)
{
if(ins.op.name().front() == '@')
if(ins.name().front() == '@')
continue;
if(ins.op.name() == "broadcast")
if(ins.name() == "broadcast")
continue;
if(ins.op.name() == "transpose")
if(ins.name() == "transpose")
continue;
if(ins.op.name() == "reshape")
if(ins.name() == "reshape")
continue;
auto create_program = [&] {
migraph::program p;
std::vector<migraph::instruction_ref> inputs;
for(auto&& arg : ins.arguments)
for(auto&& arg : ins.inputs())
{
if(arg->op.name() == "@literal")
inputs.push_back(p.add_literal(arg->lit));
if(arg->name() == "@literal")
inputs.push_back(p.add_literal(arg->get_literal()));
else
inputs.push_back(
p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.op, inputs);
p.add_instruction(ins.get_operator(), inputs);
return p;
};
try
{
std::cout << "Verify: " << ins.op.name() << std::endl;
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << create_program() << std::endl;
verify_program(ins.op.name(), create_program, tolerance);
verify_program(ins.name(), create_program, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.op.name() << " threw an exception." << std::endl;
std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
throw;
}
}
}
template <class F>
void verify_reduced(F f, int n, double tolerance = 80)
{
auto create_program = [&] {
migraph::program p = f();
auto last = std::prev(p.end(), n + 1);
p.remove_instructions(last, p.end());
return p;
};
std::cout << "Verify: " << std::endl;
std::cout << create_program() << std::endl;
verify_program(std::to_string(n), create_program, tolerance);
}
template <class F>
void verify_reduced_program(F f, double tolerance = 80)
{
migraph::program p = f();
auto n = std::distance(p.begin(), p.end());
for(int i = 0; i < n; i++)
{
verify_reduced(f, i, tolerance);
}
}
int main(int argc, char const* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
......@@ -106,6 +134,10 @@ int main(int argc, char const* argv[])
{
verify_instructions(p);
}
else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
{
verify_reduced_program([&] { return migraph::parse_onnx(file); });
}
else
{
verify_program(file, [&] { return migraph::parse_onnx(file); });
......
......@@ -12,6 +12,7 @@
namespace migraph {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)
struct program_impl
{
......@@ -20,7 +21,7 @@ struct program_impl
context ctx;
};
const operation& get_operation(instruction_ref ins) { return ins->op; }
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
template <class F>
static void print_program(std::ostream& os, const program& p, F annonate)
......@@ -31,27 +32,27 @@ static void print_program(std::ostream& os, const program& p, F annonate)
for(auto ins : iterator_for(p))
{
std::string var_name = "@" + std::to_string(count);
if(ins->op.name() == "@param")
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->op).parameter;
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
os << var_name << " = ";
os << ins->op;
os << ins->get_operator();
if(ins->op.name() == "@literal")
if(ins->name() == "@literal")
{
if(ins->lit.get_shape().elements() > 10)
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->lit << "}";
os << "{" << ins->get_literal() << "}";
}
if(!ins->arguments.empty())
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->arguments)
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
os << delim << names.at(arg);
......@@ -60,7 +61,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os << ")";
}
os << " -> " << ins->result;
os << " -> " << ins->get_shape();
annonate(ins, names);
......@@ -92,8 +93,8 @@ instruction_ref program::insert_instruction(instruction_ref ins,
// TODO: Use move
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
backreference(result);
// assert(result->arguments == args);
instruction::backreference(result);
// assert(result->inputs() == args);
assert(result->valid(begin()));
return result;
}
......@@ -108,8 +109,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
ins->replace(op, r, std::move(args));
backreference(ins);
instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
......@@ -120,21 +120,21 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(rep));
assert(ins != rep);
// TODO: Should it be an error if the output is empty?
if(ins->output.empty())
if(ins->outputs().empty())
{
return rep;
}
for(auto&& out : ins->output)
for(auto&& out : ins->outputs())
{
// TODO: Check for possible cycles
if(out != rep)
{
replace_argument(out, ins, rep);
instruction::replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
}
// Replacement should not be dead code unless its the last instruction
assert(!rep->output.empty() or rep == std::prev(end()));
assert(!rep->outputs().empty() or rep == std::prev(end()));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
......@@ -143,7 +143,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
instruction_ref program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->output.empty());
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
}
......@@ -155,7 +155,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](instruction& ins) { return ins.output.empty(); }));
assert(std::all_of(first, last, [&](instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
}
......@@ -188,9 +188,9 @@ shape program::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.op.name() == "@param")
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.op).parameter == name;
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
......@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
}
});
if(ins != this->end())
return ins->result;
return ins->get_shape();
else
return {};
}
......@@ -208,10 +208,10 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
std::unordered_map<std::string, shape> result;
for(auto&& ins : impl->instructions)
{
if(ins.op.name() == "@param")
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.op).parameter;
result[name] = ins.result;
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
return result;
......@@ -229,7 +229,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; }
shape program::get_shape() const { return impl->instructions.back().get_shape(); }
instruction_ref program::validate() const
{
......@@ -258,7 +258,7 @@ void program::compile(const target& t, tracer trace)
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->op.name());
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
......@@ -284,32 +284,32 @@ argument generic_eval(const program& p,
values.reserve(16);
for(auto ins : iterator_for(p))
{
if(ins->op.name() == "@literal")
if(ins->name() == "@literal")
{
results.emplace(ins, trace(ins, [&] { return ins->lit.get_argument(); }));
results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); }));
}
else if(ins->op.name() == "@param")
else if(ins->name() == "@param")
{
results.emplace(ins, trace(ins, [&] {
return params.at(any_cast<builtin::param>(ins->op).parameter);
return params.at(
any_cast<builtin::param>(ins->get_operator()).parameter);
}));
}
else if(ins->op.name() == "@outline")
else if(ins->name() == "@outline")
{
results.emplace(ins, trace(ins, [&] { return argument{ins->result, nullptr}; }));
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
}
else
{
values.resize(ins->arguments.size());
std::transform(ins->arguments.begin(),
ins->arguments.end(),
values.begin(),
[&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
results.emplace(ins,
trace(ins, [&] { return ins->op.compute(ctx, ins->result, values); }));
values.resize(ins->inputs().size());
std::transform(
ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
results.emplace(ins, trace(ins, [&] {
return ins->get_operator().compute(ctx, ins->get_shape(), values);
}));
}
assert(results.find(ins) != results.end());
}
......@@ -318,8 +318,20 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const
{
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
if(enabled(MIGRAPH_TRACE_EVAL{}))
{
auto& ctx = this->impl->ctx;
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: " << ins->name() << std::endl;
return f();
});
}
else
{
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
}
}
double common_average(const std::vector<double>& v)
......@@ -385,7 +397,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
for(auto&& p : ins_vec)
{
double avg = common_average(p.second);
op_times[p.first->op.name()] += avg;
op_times[p.first->name()] += avg;
total_instruction_time += avg;
}
double calculate_overhead_time = total_time - total_instruction_time;
......
......@@ -25,26 +25,26 @@ void simplify_reshapes::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(not is_reshaper(ins->op.name()))
if(not is_reshaper(ins->name()))
continue;
if(ins->output.size() != 1)
if(ins->outputs().size() != 1)
continue;
if(is_reshaper(ins->output.front()->op.name()))
if(is_reshaper(ins->outputs().front()->name()))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()->op.name()))
while(is_reshaper(reshapes.back()->name()))
{
assert(!reshapes.back()->arguments.empty());
assert(p.has_instruction(reshapes.back()->arguments.front()));
reshapes.push_back(reshapes.back()->arguments.front());
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
reshapes.push_back(reshapes.back()->inputs().front());
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->result == (*start)->result and i != (*start);
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment