Commit 5dfeb457 authored by Scott Thornton

Merge branch 'master' into squeeze_unsqeeze

parents 523a78c7 f9f4f713
-sphinx
+sphinx==1.6.2
 breathe==4.9.1
 # git+https://github.com/arximboldi/breathe@fix-node-parent
@@ -7,6 +7,7 @@ add_library(migraph
 fwd_conv_batchnorm_rewrite.cpp
 env.cpp
 generate.cpp
+instruction.cpp
 program.cpp
 shape.cpp
 simplify_reshapes.cpp
...
@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        shape s = ins->result;
+        shape s = ins->get_shape();
         if(not s.standard())
         {
             auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
...
@@ -17,16 +17,16 @@ void dead_code_elimination::apply(program& p) const
             continue;
         const auto i = std::prev(ins);
         // Skip instruction with empty shape as output unless its a builtin
-        if(i->result.elements() == 0 and not(i->op.name().front() == '@'))
+        if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
             continue;
         // Skip the last instruction
         if(i == last)
             break;
         fix([&](auto self, auto leaf) {
             assert(p.has_instruction(leaf));
-            if(leaf->output.empty())
+            if(leaf->outputs().empty())
             {
-                auto args = leaf->arguments;
+                auto args = leaf->inputs();
                 leaf->clear_arguments();
                 p.move_instruction(leaf, p.end());
                 for(auto arg : args)
...
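Note: the traversal above relies on migraph's fix combinator, which hands a lambda a callable reference to itself so it can recurse over dead leaves. A minimal standalone sketch of the idiom (my reconstruction for illustration, not the actual migraph header):

#include <iostream>

// fix(f) returns a callable that invokes f with itself as the first
// argument, letting a generic lambda recurse without naming itself.
template <class F>
auto fix(F f)
{
    return [=](auto&&... xs) { return f(fix(f), xs...); };
}

int main()
{
    // The dce pass uses the same shape: fix([&](auto self, auto leaf) { ... })
    auto factorial = fix([](auto self, int n) -> int { return n < 2 ? 1 : n * self(n - 1); });
    std::cout << factorial(5) << std::endl; // prints 120
}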
@@ -14,7 +14,7 @@ void eliminate_allocation::apply(program& p) const
     std::vector<std::pair<instruction_ref, std::size_t>> allocs;
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != allocation_op)
+        if(ins->name() != allocation_op)
             continue;
         allocs.emplace_back(ins, n);
         std::size_t size = ins->get_shape().bytes();
...
@@ -27,19 +27,19 @@ void eliminate_contiguous::apply(program& p) const
     for(auto ins : iterator_for(p))
     {
         // Make a copy so we can modify it while we iterate
-        auto args = ins->arguments;
-        for(auto arg : ins->arguments)
+        auto args = ins->inputs();
+        for(auto arg : ins->inputs())
         {
             // TODO: Pass in names for the operator in the constructor instead
             // of using ends_with
-            if(ends_with(arg->op.name(), "contiguous"))
+            if(ends_with(arg->name(), "contiguous"))
             {
                 auto new_args = args;
-                auto prev     = arg->arguments.front();
+                auto prev     = arg->inputs().front();
                 replace(new_args, arg, prev);
-                if(try_compute_shape(ins->op, new_args))
+                if(try_compute_shape(ins->get_operator(), new_args))
                 {
-                    replace_argument(ins, arg, prev);
+                    instruction::replace_argument(ins, arg, prev);
                 }
             }
         }
...
@@ -10,30 +10,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() != "batch_norm_inference")
+        if(ins->name() != "batch_norm_inference")
             continue;
-        if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
-               return arg->op.name() == "@literal";
+        if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
+               return arg->name() == "@literal";
            }))
             continue;
-        auto conv_ins = ins->arguments[0];
-        if(conv_ins->op.name() != "convolution")
+        auto conv_ins = ins->inputs()[0];
+        if(conv_ins->name() != "convolution")
             continue;
-        if(conv_ins->arguments[1]->op.name() != "@literal")
+        if(conv_ins->inputs()[1]->name() != "@literal")
             continue;
         // Get scale, bias, mean, variance from instruction_ref
-        const auto& gamma    = ins->arguments[1]->get_literal();
-        const auto& bias     = ins->arguments[2]->get_literal();
-        const auto& mean     = ins->arguments[3]->get_literal();
-        const auto& variance = ins->arguments[4]->get_literal();
+        const auto& gamma    = ins->inputs()[1]->get_literal();
+        const auto& bias     = ins->inputs()[2]->get_literal();
+        const auto& mean     = ins->inputs()[3]->get_literal();
+        const auto& variance = ins->inputs()[4]->get_literal();
         // Get epsilon
-        auto bn_op   = any_cast<batch_norm_inference>(ins->op);
+        auto bn_op   = any_cast<batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
         // Get convolution weights
-        const auto& weights = conv_ins->arguments[1]->get_literal();
+        const auto& weights = conv_ins->inputs()[1]->get_literal();
         // Get convolution op
-        auto conv_op      = conv_ins->op;
+        auto conv_op      = conv_ins->get_operator();
         auto weights_lens = weights.get_shape().lens();
         auto conv_lens    = conv_ins->get_shape().lens();
         argument new_weights{weights.get_shape()};
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         // Replace convolution instruction with updated weights
         auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
         auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
-        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
+        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
         auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
         p.replace_instruction(ins, add{}, {c, b});
     }
...
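Note: this pass folds inference-mode batch normalization into the preceding convolution. The algebra is the standard fold, stated here for reference since the hunk that fills new_weights/new_bias is collapsed above:

// y = gamma * (conv(x, w) - mean) / sqrt(variance + epsilon) + bias
//   = conv(x, w') + b'   with, per output channel c:
// w'[c] = w[c] * gamma[c] / sqrt(variance[c] + epsilon)
// b'[c] = bias[c] - gamma[c] * mean[c] / sqrt(variance[c] + epsilon)

The broadcast{1} instruction then spreads the folded bias along the channel axis of the convolution output before the final add{}.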
@@ -12,7 +12,7 @@ constexpr T normalize(unsigned long z)
 {
     if(z == 0)
         return 0;
-    const auto max     = 2048;
+    const auto max     = 32;
     const double range = max / 2; // NOLINT
     double result      = (z % max) / range;
     result -= 1;
...
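Note: with max dropped from 2048 to 32, generated values are still folded into [-1, 1), but in steps of 2/32 instead of 2/2048, so generated test data now cycles through only 32 distinct values. A self-contained copy of the function above with one worked value:

#include <iostream>

template <class T>
constexpr T normalize(unsigned long z)
{
    if(z == 0)
        return 0;
    const auto max     = 32;
    const double range = max / 2; // 16
    double result      = (z % max) / range;
    result -= 1;
    return result;
}

int main()
{
    std::cout << normalize<double>(40) << std::endl; // 40 % 32 = 8; 8/16 - 1 = -0.5
}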
@@ -3,10 +3,8 @@
 #include <migraph/literal.hpp>
 #include <migraph/shape.hpp>
-#include <migraph/builtin.hpp>
 #include <migraph/instruction_ref.hpp>
 #include <migraph/operation.hpp>
-#include <migraph/erase.hpp>
 #include <string>
 #include <utility>
@@ -18,126 +16,60 @@ struct instruction
 {
     instruction() {}
-    instruction(operation o, shape r, std::vector<instruction_ref> args)
-        : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
-    {
-    }
+    instruction(operation o, shape r, std::vector<instruction_ref> args);
-    instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {}
+    instruction(literal l);
-    // internal
-    void replace(operation o, const shape& r, std::vector<instruction_ref> args)
-    {
-        op = std::move(o);
-        replace(r);
-        replace(std::move(args));
-    }
-    void replace(const shape& r)
-    {
-        if(r != result)
-        {
-            result = r;
-            for(auto&& ins : output)
-            {
-                assert(ins->op.name().front() != '@');
-                ins->recompute_shape();
-            }
-        }
-    }
+    void replace(const shape& r);
-    void recompute_shape() { replace(compute_shape(op, arguments)); }
+    void recompute_shape();
-    // internal
-    void replace(std::vector<instruction_ref> args)
-    {
-        clear_arguments();
-        arguments = std::move(args);
-    }
-    // internal
-    void replace_argument(instruction_ref old, instruction_ref new_ins)
-    {
-        std::replace(arguments.begin(), arguments.end(), old, new_ins);
-        old->remove_output(*this);
-    }
-    void clear_arguments()
-    {
-        for(auto&& arg : arguments)
-        {
-            arg->remove_output(*this);
-        }
-        arguments.clear();
-    }
+    void clear_arguments();
-    friend bool operator==(const instruction& i, instruction_ref ref)
-    {
-        return std::addressof(i) == std::addressof(*ref);
-    }
+    friend bool operator==(const instruction& i, instruction_ref ref);
-    bool valid(instruction_ref start) const
-    {
-        return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
-                   auto self = std::find(i->output.begin(), i->output.end(), *this);
-                   return self != i->output.end() &&
-                          std::distance(start, i) < std::distance(start, *self);
-               });
-    }
+    bool valid(instruction_ref start) const;
-    bool valid() const
-    {
-        shape computed;
-        if(op.name() == "@literal")
-        {
-            computed = lit.get_shape();
-        }
-        else if(op.name() == "@param")
-        {
-            computed = result;
-        }
-        else
-        {
-            try
-            {
-                computed = compute_shape(op, arguments);
-            }
-            catch(migraph::exception&)
-            {
-                return false;
-            }
-        }
-        return result == computed &&
-               std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
-                   return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
-                          i->arguments.end();
-               });
-    }
+    bool valid() const;
-    shape get_shape() const { return result; }
-    const literal& get_literal() const
-    {
-        assert(op.name() == "@literal");
-        return lit;
-    }
+    shape get_shape() const;
+    const literal& get_literal() const;
+    const operation& get_operator() const;
+    std::string name() const;
+    const std::vector<instruction_ref>& inputs() const;
+    const std::vector<instruction_ref>& outputs() const;
-    friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
-    friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
-    friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
+    friend bool operator==(instruction_ref ref, const instruction& i);
+    friend bool operator!=(const instruction& i, instruction_ref ref);
+    friend bool operator!=(instruction_ref ref, const instruction& i);
-    void add_output(instruction_ref ins)
-    {
-        if(std::find(output.begin(), output.end(), ins) == output.end())
-            output.push_back(ins);
-    }
+    void add_output(instruction_ref ins);
     template <class T>
-    void remove_output(const T& ins)
-    {
-        migraph::erase(output, ins);
-    }
+    void remove_output(const T& ins);
+    static void backreference(instruction_ref ref);
+    static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
+    static void
+    replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
+
+    private:
+    // internal
+    void replace(operation o, const shape& r, std::vector<instruction_ref> args);
+    // internal
+    void replace(std::vector<instruction_ref> args);
+    // internal
+    void replace_argument(instruction_ref old, instruction_ref new_ins);
     operation op;
     shape result;
@@ -146,29 +78,6 @@ struct instruction
     literal lit;
 };
-inline void backreference(instruction_ref ref)
-{
-    for(auto&& arg : ref->arguments)
-        arg->add_output(ref);
-}
-inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
-{
-    ins->replace_argument(old, new_ins);
-    backreference(ins);
-    ins->recompute_shape();
-}
-// TODO: Move to a cpp file
-// TODO: Use const ref for vector
-inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
-{
-    std::vector<shape> shapes(args.size());
-    std::transform(
-        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
-    return op.compute_shape(shapes);
-}
 } // namespace migraph
 namespace std {
...
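Note: the header now hides arguments/output behind inputs()/outputs() accessors and turns the graph-mutating helpers (backreference, replace_argument, replace) into static members of instruction, so the bidirectional use-def links can only be edited through code that keeps them consistent. A simplified sketch (toy node type, not MIGraphX code) of the invariant that valid() checks:

#include <algorithm>
#include <cassert>
#include <vector>

struct node
{
    std::vector<node*> inputs;
    std::vector<node*> outputs;
};

// Mirrors instruction::backreference: after wiring inputs, register this
// node as an output of each of its inputs.
void backreference(node* n)
{
    for(auto* arg : n->inputs)
        if(std::find(arg->outputs.begin(), arg->outputs.end(), n) == arg->outputs.end())
            arg->outputs.push_back(n);
}

// Mirrors the check in instruction::valid(): a node must appear in the
// outputs list of each of its inputs.
bool valid(const node* n)
{
    return std::all_of(n->inputs.begin(), n->inputs.end(), [&](const node* arg) {
        return std::find(arg->outputs.begin(), arg->outputs.end(), n) != arg->outputs.end();
    });
}

int main()
{
    node a, b;
    b.inputs = {&a};
    assert(!valid(&b)); // the link is one-directional until backreference runs
    backreference(&b);
    assert(valid(&b));
}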
@@ -8,6 +8,7 @@
 #include <type_traits>
 #include <utility>
 #include <migraph/shape.hpp>
+#include <migraph/rank.hpp>
 #include <migraph/argument.hpp>
 #include <migraph/context.hpp>
 #include <migraph/auto_any_cast.hpp>
@@ -27,13 +28,16 @@ struct operation
     /// exception.
     shape compute_shape(const std::vector<shape>& input) const;
     /**
-     * @brief This performs the operation's computation
+     * @brief This performs the operation's computation.
+     *
+     * This method can be optional when the operation is only used as a placeholder to be lowered
+     * later on.
      *
      * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
-     * @param input This is the `argument` result from the previous instuction's computation.
+     * @param input This is the `argument` result from the previous instruction's computation.
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
     */
@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
 } // namespace operation_stream

+template <class T>
+auto compute_op(rank<1>,
+                const T& x,
+                context& ctx,
+                const shape& output_shape,
+                const std::vector<argument>& input)
+    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
+{
+    return x.compute(auto_any_cast(ctx), output_shape, input);
+}
+
+template <class T>
+argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
+{
+    std::string name = x.name();
+    MIGRAPH_THROW("Not computable: " + name);
+}
+
 template <class T>
 argument
 compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
 {
-    return x.compute(auto_any_cast(ctx), output_shape, input);
+    return compute_op(rank<1>{}, x, ctx, output_shape, input);
 }
 /*
...
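Note: the new compute_op overloads use rank-based tag dispatch. rank<1> derives from rank<0>, so overload resolution prefers the rank<1> overload, and SFINAE (the decltype return type) removes it whenever T has no matching compute member, falling back to the throwing rank<0> overload. This is what lets the "not computable" stubs be deleted from every placeholder operator below. A standalone sketch of the idiom (illustrative names, not MIGraphX code):

#include <iostream>
#include <stdexcept>
#include <string>

template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

struct computable
{
    std::string name() const { return "computable"; }
    int compute(int x) const { return x * 2; }
};
struct placeholder
{
    std::string name() const { return "placeholder"; }
};

// Preferred overload; removed by SFINAE when x.compute(v) is ill-formed.
template <class T>
auto try_compute(rank<1>, const T& x, int v) -> decltype(x.compute(v))
{
    return x.compute(v);
}
// Fallback for placeholder ops that have no compute member.
template <class T>
int try_compute(rank<0>, const T& x, int)
{
    throw std::runtime_error("Not computable: " + x.name());
}

int main()
{
    std::cout << try_compute(rank<1>{}, computable{}, 21) << std::endl; // 42
    try
    {
        try_compute(rank<1>{}, placeholder{}, 21);
    }
    catch(const std::exception& e)
    {
        std::cout << e.what() << std::endl; // Not computable: placeholder
    }
}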
@@ -41,11 +41,6 @@ struct batch_norm_inference
         check_shapes{inputs, *this}.has(5);
         return inputs.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
 };
 struct convolution
@@ -115,11 +110,6 @@ struct convolution
         }
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
     friend std::ostream& operator<<(std::ostream& os, const convolution& op)
     {
         os << op.name() << "[";
@@ -169,11 +159,6 @@ struct im2col
         auto channels_col = kernel_height * kernel_width * input_channels;
         return {input.type(), {output_height * output_width, channels_col}};
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
 };
 struct pooling
@@ -211,11 +196,6 @@ struct pooling
         }};
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
     friend std::ostream& operator<<(std::ostream& os, const pooling& op)
     {
         os << op.name() << "[";
@@ -236,11 +216,6 @@ struct activation
         check_shapes{inputs, *this}.has(1);
         return inputs.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
     friend std::ostream& operator<<(std::ostream& os, const activation& op)
     {
         os << op.name() << ":" << op.mode;
@@ -305,10 +280,6 @@ struct contiguous
         }
         return {t, lens};
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
 };
 struct slice
@@ -497,12 +468,10 @@ struct reshape
             MIGRAPH_THROW("Wrong number of elements for reshape");
         return s;
     }
-
     argument compute(context&, shape output_shape, std::vector<argument> args) const
     {
         return {std::move(output_shape), std::move(args.front().data)};
     }
-
     friend std::ostream& operator<<(std::ostream& os, const reshape& op)
     {
         os << op.name() << "[";
@@ -530,11 +499,6 @@ struct gemm
         return {t, {a.lens()[0], b.lens()[1]}};
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
     friend std::ostream& operator<<(std::ostream& os, const gemm& op)
     {
         os << op.name() << "[";
@@ -550,10 +514,6 @@ struct unary
         check_shapes{inputs}.has(1);
         return inputs.at(0);
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
 };
 struct identity : unary
@@ -601,11 +561,6 @@ struct atan : unary
     std::string name() const { return "atan"; }
 };
-struct softmax : unary
-{
-    std::string name() const { return "softmax"; }
-};
-
 struct tanh : unary
 {
     std::string name() const { return "tanh"; }
@@ -621,6 +576,16 @@ struct neg : unary
     std::string name() const { return "neg"; }
 };
+struct softmax
+{
+    std::string name() const { return "softmax"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs}.has(1).only_dims(4);
+        return inputs.at(0);
+    }
+};
+
 struct flatten
 {
     uint64_t axis = 0;
@@ -701,10 +666,6 @@ struct binary
         check_shapes{inputs}.has(2).same_type().same_dims();
         return inputs.at(0);
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
 };
 struct add : binary
...
@@ -3,19 +3,10 @@
 #include <algorithm>
 #include <initializer_list>
+#include <migraph/rank.hpp>
 namespace migraph {
-template <int N>
-struct rank : rank<N - 1>
-{
-};
-template <>
-struct rank<0>
-{
-};
 namespace detail {
 template <class String, class T>
...
+#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
+#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
+namespace migraph {
+template <int N>
+struct rank : rank<N - 1>
+{
+};
+template <>
+struct rank<0>
+{
+};
+} // namespace migraph
+#endif
@@ -6,14 +6,16 @@
 namespace migraph {
-inline void verify_args(const std::string& name,
+inline bool verify_args(const std::string& name,
                         const argument& cpu_arg,
                         const argument& gpu_arg,
                         double tolerance = 80)
 {
+    bool passed = true;
     visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
         double error;
-        if(not verify_range(cpu, gpu, tolerance, &error))
+        passed = verify_range(cpu, gpu, tolerance, &error);
+        if(not passed)
         {
             // TODO: Check for nans
             std::cout << "FAILED: " << name << std::endl;
@@ -27,6 +29,9 @@ inline void verify_args(const std::string& name,
             if(range_zero(gpu))
                 std::cout << "Gpu data is all zeros" << std::endl;
+            auto mxdiff = max_diff(cpu, gpu);
+            std::cout << "Max diff: " << mxdiff << std::endl;
             auto idx = mismatch_idx(cpu, gpu, float_equal);
             if(idx < range_distance(cpu))
             {
@@ -45,7 +50,36 @@ inline void verify_args(const std::string& name,
                           << gpu[gpu_nan_idx] << std::endl;
             std::cout << std::endl;
         }
+        else
+        {
+            if(range_zero(cpu))
+                std::cout << "Cpu data is all zeros" << std::endl;
+            if(range_zero(gpu))
+                std::cout << "Gpu data is all zeros" << std::endl;
+            // auto mxdiff = max_diff(cpu, gpu);
+            // std::cout << "Max diff: " << mxdiff << std::endl;
+            // auto idx = mismatch_idx(cpu, gpu, float_equal);
+            // if(idx < range_distance(cpu))
+            // {
+            //     std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
+            //               << std::endl;
+            // }
+            auto cpu_nan_idx = find_idx(cpu, not_finite);
+            if(cpu_nan_idx >= 0)
+                std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
+                          << cpu[cpu_nan_idx] << std::endl;
+            auto gpu_nan_idx = find_idx(gpu, not_finite);
+            if(gpu_nan_idx >= 0)
+                std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
+                          << gpu[gpu_nan_idx] << std::endl;
+            // std::cout << std::endl;
+        }
     });
+    return passed;
 }
 } // namespace migraph
...
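Note: since verify_args now reports success, callers can branch on the result instead of scraping stdout. A hypothetical caller sketch (the variable names are illustrative):

bool ok = migraph::verify_args("conv_block", cpu_result, gpu_result, 80);
if(not ok)
    std::cout << "CPU/GPU mismatch detected" << std::endl;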
+#include <migraph/instruction.hpp>
+#include <migraph/builtin.hpp>
+#include <migraph/erase.hpp>
+
+namespace migraph {
+
+instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
+    : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
+{
+}
+
+instruction::instruction(literal l)
+    : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
+{
+}
+
+void instruction::replace(const shape& r)
+{
+    if(r != result)
+    {
+        result = r;
+        for(auto&& ins : output)
+        {
+            assert(ins->name().front() != '@');
+            ins->recompute_shape();
+        }
+    }
+}
+
+void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
+
+void instruction::clear_arguments()
+{
+    for(auto&& arg : arguments)
+    {
+        arg->remove_output(*this);
+    }
+    arguments.clear();
+}
+
+bool operator==(const instruction& i, instruction_ref ref)
+{
+    return std::addressof(i) == std::addressof(*ref);
+}
+
+bool instruction::valid(instruction_ref start) const
+{
+    return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
+               auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
+               return self != i->outputs().end() &&
+                      std::distance(start, i) < std::distance(start, *self);
+           });
+}
+
+bool instruction::valid() const
+{
+    shape computed;
+    if(op.name() == "@literal")
+    {
+        computed = lit.get_shape();
+    }
+    else if(op.name() == "@param")
+    {
+        computed = result;
+    }
+    else
+    {
+        try
+        {
+            computed = compute_shape(op, arguments);
+        }
+        catch(migraph::exception&)
+        {
+            return false;
+        }
+    }
+    return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
+               return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
+           });
+}
+
+shape instruction::get_shape() const { return result; }
+const literal& instruction::get_literal() const
+{
+    assert(op.name() == "@literal");
+    return lit;
+}
+
+const operation& instruction::get_operator() const { return op; }
+
+std::string instruction::name() const { return op.name(); }
+
+const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
+
+const std::vector<instruction_ref>& instruction::outputs() const { return output; }
+
+bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
+bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
+bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
+
+void instruction::add_output(instruction_ref ins)
+{
+    if(std::find(output.begin(), output.end(), ins) == output.end())
+        output.push_back(ins);
+}
+
+template <class T>
+void instruction::remove_output(const T& ins)
+{
+    migraph::erase(output, ins);
+}
+
+void instruction::backreference(instruction_ref ref)
+{
+    for(auto&& arg : ref->inputs())
+        arg->add_output(ref);
+}
+
+void instruction::replace_argument(instruction_ref ins,
+                                   instruction_ref old,
+                                   instruction_ref new_ins)
+{
+    ins->replace_argument(old, new_ins);
+    backreference(ins);
+    ins->recompute_shape();
+}
+
+void instruction::replace(instruction_ref ins,
+                          operation o,
+                          const shape& r,
+                          std::vector<instruction_ref> args)
+{
+    ins->replace(std::move(o), r, std::move(args));
+    backreference(ins);
+}
+
+void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
+{
+    op = std::move(o);
+    replace(r);
+    replace(std::move(args));
+}
+
+void instruction::replace(std::vector<instruction_ref> args)
+{
+    clear_arguments();
+    arguments = std::move(args);
+}
+
+void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
+{
+    std::replace(arguments.begin(), arguments.end(), old, new_ins);
+    old->remove_output(*this);
+}
+
+shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
+{
+    std::vector<shape> shapes(args.size());
+    std::transform(
+        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
+    return op.compute_shape(shapes);
+}
+
+} // namespace migraph
@@ -28,10 +28,6 @@ struct unknown
         else
             return input.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPH_THROW("not computable");
-    }
     friend std::ostream& operator<<(std::ostream& os, const unknown& x)
     {
         os << x.name();
@@ -58,6 +54,7 @@ struct onnx_parser
         add_generic_op("Mul", mul{});
         add_generic_op("Relu", activation{"relu"});
         add_generic_op("Sub", sub{});
+        add_generic_op("Sum", add{});
         add_mem_op("Constant", &onnx_parser::parse_constant);
         add_mem_op("Conv", &onnx_parser::parse_conv);
@@ -67,6 +64,7 @@ struct onnx_parser
         add_mem_op("Flatten", &onnx_parser::parse_flatten);
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
         add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
+        add_mem_op("Softmax", &onnx_parser::parse_softmax);
     }
     template <class F>
@@ -103,6 +101,15 @@ struct onnx_parser
         });
     }
+    instruction_ref
+    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        auto dims = args.front()->get_shape().lens();
+        auto r = prog.add_instruction(reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
+        auto s = prog.add_instruction(softmax{}, r);
+        return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
+    }
+
     instruction_ref
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
@@ -160,7 +167,7 @@ struct onnx_parser
         }
         if(args.size() == 2)
         {
-            literal s = args[1]->lit;
+            literal s = args[1]->get_literal();
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
         return prog.add_instruction(op, args[0]);
@@ -344,11 +351,10 @@ struct onnx_parser
         if(node.name().empty())
         {
             std::string generated = "migraph_unnamed_node";
-            for(auto&& output : node.output())
-            {
-                generated += "_" + output;
-            }
-            return generated;
+            return std::accumulate(node.output().begin(),
+                                   node.output().end(),
+                                   generated,
+                                   [](auto x, auto y) { return x + "_" + y; });
         }
         return node.name();
     }
@@ -481,11 +487,11 @@ struct onnx_parser
             break; // throw std::runtime_error("Unsupported type COMPLEX128");
         }
         std::vector<std::size_t> dims;
-        // TODO: USe std::transform
-        for(auto&& d : t.tensor_type().shape().dim())
-        {
-            dims.push_back(d.dim_value());
-        }
+        auto&& tensor_dims = t.tensor_type().shape().dim();
+        std::transform(tensor_dims.begin(),
+                       tensor_dims.end(),
+                       std::back_inserter(dims),
+                       [](auto&& d) { return d.dim_value(); });
         return {shape_type, dims};
     }
 };
...
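Note: parse_softmax works around the new softmax operator's 4-D restriction (check_shapes{inputs}.has(1).only_dims(4) above) by padding the shape with trailing 1s and undoing it afterwards. A sketch of the round trip for a hypothetical (batch, classes) input; the literal dimensions and variable names are illustrative, not parser code:

// {64, 1000} --reshape--> {64, 1000, 1, 1} --softmax--> {64, 1000, 1, 1} --reshape--> {64, 1000}
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {64, 1000}});
auto r = p.add_instruction(migraph::reshape{{64, 1000, 1, 1}}, x);
auto s = p.add_instruction(migraph::softmax{}, r);
p.add_instruction(migraph::reshape{{64, 1000}}, s);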
@@ -51,48 +51,76 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
     auto x = run_cpu(f);
     auto y = run_gpu(f);
     migraph::verify_args(name, x, y, tolerance);
+    // std::cout << "cpu: " << x << std::endl;
+    // std::cout << "gpu: " << y << std::endl;
 }

 void verify_instructions(const migraph::program& prog, double tolerance = 80)
 {
     for(auto&& ins : prog)
     {
-        if(ins.op.name().front() == '@')
+        if(ins.name().front() == '@')
             continue;
-        if(ins.op.name() == "broadcast")
+        if(ins.name() == "broadcast")
             continue;
-        if(ins.op.name() == "transpose")
+        if(ins.name() == "transpose")
             continue;
-        if(ins.op.name() == "reshape")
+        if(ins.name() == "reshape")
             continue;
         auto create_program = [&] {
             migraph::program p;
             std::vector<migraph::instruction_ref> inputs;
-            for(auto&& arg : ins.arguments)
+            for(auto&& arg : ins.inputs())
             {
-                if(arg->op.name() == "@literal")
-                    inputs.push_back(p.add_literal(arg->lit));
+                if(arg->name() == "@literal")
+                    inputs.push_back(p.add_literal(arg->get_literal()));
                 else
                     inputs.push_back(
                         p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
             }
-            p.add_instruction(ins.op, inputs);
+            p.add_instruction(ins.get_operator(), inputs);
             return p;
         };
         try
         {
-            std::cout << "Verify: " << ins.op.name() << std::endl;
+            std::cout << "Verify: " << ins.name() << std::endl;
             std::cout << create_program() << std::endl;
-            verify_program(ins.op.name(), create_program, tolerance);
+            verify_program(ins.name(), create_program, tolerance);
         }
         catch(...)
         {
-            std::cout << "Instruction " << ins.op.name() << " threw an exception." << std::endl;
+            std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
             throw;
         }
     }
 }

+template <class F>
+void verify_reduced(F f, int n, double tolerance = 80)
+{
+    auto create_program = [&] {
+        migraph::program p = f();
+        auto last          = std::prev(p.end(), n + 1);
+        p.remove_instructions(last, p.end());
+        return p;
+    };
+    std::cout << "Verify: " << std::endl;
+    std::cout << create_program() << std::endl;
+    verify_program(std::to_string(n), create_program, tolerance);
+}
+
+template <class F>
+void verify_reduced_program(F f, double tolerance = 80)
+{
+    migraph::program p = f();
+    auto n             = std::distance(p.begin(), p.end());
+    for(int i = 0; i < n; i++)
+    {
+        verify_reduced(f, i, tolerance);
+    }
+}
+
 int main(int argc, char const* argv[])
 {
     std::vector<std::string> args(argv + 1, argv + argc);
@@ -106,6 +134,10 @@ int main(int argc, char const* argv[])
     {
         verify_instructions(p);
     }
+    else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
+    {
+        verify_reduced_program([&] { return migraph::parse_onnx(file); });
+    }
     else
     {
         verify_program(file, [&] { return migraph::parse_onnx(file); });
...
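Note: the new -r mode is a bisection aid. verify_reduced_program re-verifies copies of the program with the last 1, 2, ..., n instructions removed (verify_reduced(f, n) drops n + 1 trailing instructions per run), so the first truncated copy that passes brackets the instruction whose CPU and GPU results diverge. Hypothetical invocation, with an illustrative binary name:

./onnx_verify model.onnx -r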
@@ -12,6 +12,7 @@
 namespace migraph {

 MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
+MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)

 struct program_impl
 {
@@ -20,7 +21,7 @@ struct program_impl
     context ctx;
 };

-const operation& get_operation(instruction_ref ins) { return ins->op; }
+const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }

 template <class F>
 static void print_program(std::ostream& os, const program& p, F annonate)
@@ -31,27 +32,27 @@ static void print_program(std::ostream& os, const program& p, F annonate)
     for(auto ins : iterator_for(p))
     {
         std::string var_name = "@" + std::to_string(count);
-        if(ins->op.name() == "@param")
+        if(ins->name() == "@param")
         {
-            var_name = any_cast<builtin::param>(ins->op).parameter;
+            var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
         }
         os << var_name << " = ";
-        os << ins->op;
-        if(ins->op.name() == "@literal")
+        os << ins->get_operator();
+        if(ins->name() == "@literal")
         {
-            if(ins->lit.get_shape().elements() > 10)
+            if(ins->get_literal().get_shape().elements() > 10)
                 os << "{ ... }";
             else
-                os << "{" << ins->lit << "}";
+                os << "{" << ins->get_literal() << "}";
         }
-        if(!ins->arguments.empty())
+        if(!ins->inputs().empty())
         {
             char delim = '(';
-            for(auto&& arg : ins->arguments)
+            for(auto&& arg : ins->inputs())
             {
                 assert(p.has_instruction(arg) && "Instruction not found");
                 os << delim << names.at(arg);
@@ -60,7 +61,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
             os << ")";
         }
-        os << " -> " << ins->result;
+        os << " -> " << ins->get_shape();
         annonate(ins, names);
@@ -92,8 +93,8 @@ instruction_ref program::insert_instruction(instruction_ref ins,
     // TODO: Use move
     shape r     = compute_shape(op, args);
     auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
-    backreference(result);
-    // assert(result->arguments == args);
+    instruction::backreference(result);
+    // assert(result->inputs() == args);
     assert(result->valid(begin()));
     return result;
 }
@@ -108,8 +109,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
     assert(not starts_with(op.name(), "@"));
     shape r = compute_shape(op, args);
-    ins->replace(op, r, std::move(args));
-    backreference(ins);
+    instruction::replace(ins, op, r, std::move(args));
     assert(ins->valid(begin()));
     return ins;
 }
@@ -120,21 +120,21 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
     assert(has_instruction(rep));
     assert(ins != rep);
     // TODO: Should it be an error if the output is empty?
-    if(ins->output.empty())
+    if(ins->outputs().empty())
     {
         return rep;
     }
-    for(auto&& out : ins->output)
+    for(auto&& out : ins->outputs())
     {
         // TODO: Check for possible cycles
         if(out != rep)
         {
-            replace_argument(out, ins, rep);
+            instruction::replace_argument(out, ins, rep);
         }
         assert(out->valid(begin()));
     }
     // Replacement should not be dead code unless its the last instruction
-    assert(!rep->output.empty() or rep == std::prev(end()));
+    assert(!rep->outputs().empty() or rep == std::prev(end()));
     assert(ins->valid(begin()));
     assert(rep->valid(begin()));
     return rep;
@@ -143,7 +143,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
 instruction_ref program::remove_instruction(instruction_ref ins)
 {
     assert(has_instruction(ins));
-    assert(ins->output.empty());
+    assert(ins->outputs().empty());
     ins->clear_arguments();
     return impl->instructions.erase(ins);
 }
@@ -155,7 +155,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_
     // TODO: Check every element
     assert(has_instruction(first));
     std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
-    assert(std::all_of(first, last, [&](instruction& ins) { return ins.output.empty(); }));
+    assert(std::all_of(first, last, [&](instruction& ins) { return ins.outputs().empty(); }));
     return impl->instructions.erase(first, last);
 }
@@ -188,9 +188,9 @@ shape program::get_parameter_shape(std::string name) const
 {
     auto ins = std::find_if(
         impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
-            if(x.op.name() == "@param")
+            if(x.name() == "@param")
             {
-                return any_cast<builtin::param>(x.op).parameter == name;
+                return any_cast<builtin::param>(x.get_operator()).parameter == name;
             }
             else
             {
@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
             }
         });
     if(ins != this->end())
-        return ins->result;
+        return ins->get_shape();
     else
         return {};
 }
@@ -208,10 +208,10 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
     std::unordered_map<std::string, shape> result;
     for(auto&& ins : impl->instructions)
     {
-        if(ins.op.name() == "@param")
+        if(ins.name() == "@param")
         {
-            auto&& name  = any_cast<builtin::param>(ins.op).parameter;
-            result[name] = ins.result;
+            auto&& name  = any_cast<builtin::param>(ins.get_operator()).parameter;
+            result[name] = ins.get_shape();
         }
     }
     return result;
@@ -229,7 +229,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
 instruction_ref program::begin() const { return impl->instructions.begin(); }
 instruction_ref program::end() const { return impl->instructions.end(); }

-shape program::get_shape() const { return impl->instructions.back().result; }
+shape program::get_shape() const { return impl->instructions.back().get_shape(); }

 instruction_ref program::validate() const
 {
@@ -258,7 +258,7 @@ void program::compile(const target& t, tracer trace)
     {
         auto index = std::distance(impl->instructions.begin(), invalid);
         MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
-                      std::to_string(index) + ": " + invalid->op.name());
+                      std::to_string(index) + ": " + invalid->name());
     }
     trace();
 #endif
@@ -284,32 +284,32 @@ argument generic_eval(const program& p,
     values.reserve(16);
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() == "@literal")
+        if(ins->name() == "@literal")
         {
-            results.emplace(ins, trace(ins, [&] { return ins->lit.get_argument(); }));
+            results.emplace(ins, trace(ins, [&] { return ins->get_literal().get_argument(); }));
         }
-        else if(ins->op.name() == "@param")
+        else if(ins->name() == "@param")
         {
             results.emplace(ins, trace(ins, [&] {
-                                return params.at(any_cast<builtin::param>(ins->op).parameter);
+                                return params.at(
+                                    any_cast<builtin::param>(ins->get_operator()).parameter);
                             }));
         }
-        else if(ins->op.name() == "@outline")
+        else if(ins->name() == "@outline")
         {
-            results.emplace(ins, trace(ins, [&] { return argument{ins->result, nullptr}; }));
+            results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
         }
         else
         {
-            values.resize(ins->arguments.size());
-            std::transform(ins->arguments.begin(),
-                           ins->arguments.end(),
-                           values.begin(),
-                           [&](instruction_ref i) {
-                               assert(results.find(i) != results.end());
-                               return results[i];
-                           });
-            results.emplace(ins,
-                            trace(ins, [&] { return ins->op.compute(ctx, ins->result, values); }));
+            values.resize(ins->inputs().size());
+            std::transform(
+                ins->inputs().begin(), ins->inputs().end(), values.begin(), [&](instruction_ref i) {
+                    assert(results.find(i) != results.end());
+                    return results[i];
+                });
+            results.emplace(ins, trace(ins, [&] {
+                                return ins->get_operator().compute(ctx, ins->get_shape(), values);
+                            }));
         }
         assert(results.find(ins) != results.end());
     }
@@ -318,8 +318,20 @@ argument generic_eval(const program& p,
 argument program::eval(std::unordered_map<std::string, argument> params) const
 {
-    return generic_eval(
-        *this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
+    if(enabled(MIGRAPH_TRACE_EVAL{}))
+    {
+        auto& ctx = this->impl->ctx;
+        return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
+            ctx.finish();
+            std::cout << "Run instruction: " << ins->name() << std::endl;
+            return f();
+        });
+    }
+    else
+    {
+        return generic_eval(
+            *this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
+    }
 }

 double common_average(const std::vector<double>& v)
@@ -385,7 +397,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     for(auto&& p : ins_vec)
     {
         double avg = common_average(p.second);
-        op_times[p.first->op.name()] += avg;
+        op_times[p.first->name()] += avg;
         total_instruction_time += avg;
     }
     double calculate_overhead_time = total_time - total_instruction_time;
...
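Note: eval is now gated on a MIGRAPH_TRACE_EVAL environment variable. When it is set, each instruction's name is printed before it runs, with ctx.finish() first synchronizing the context so the printout lines up with what has actually finished executing, which helps locate the instruction a GPU crash belongs to. Assumed shell usage, following the usual MIGRAPH_* env-var convention and an illustrative binary name:

MIGRAPH_TRACE_EVAL=1 ./onnx_verify model.onnx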
@@ -25,26 +25,26 @@ void simplify_reshapes::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(not is_reshaper(ins->op.name()))
+        if(not is_reshaper(ins->name()))
             continue;
-        if(ins->output.size() != 1)
+        if(ins->outputs().size() != 1)
             continue;
-        if(is_reshaper(ins->output.front()->op.name()))
+        if(is_reshaper(ins->outputs().front()->name()))
             continue;
         // Gather reshapes
         std::vector<instruction_ref> reshapes{ins};
-        while(is_reshaper(reshapes.back()->op.name()))
+        while(is_reshaper(reshapes.back()->name()))
         {
-            assert(!reshapes.back()->arguments.empty());
-            assert(p.has_instruction(reshapes.back()->arguments.front()));
-            reshapes.push_back(reshapes.back()->arguments.front());
+            assert(!reshapes.back()->inputs().empty());
+            assert(p.has_instruction(reshapes.back()->inputs().front()));
+            reshapes.push_back(reshapes.back()->inputs().front());
         }
         std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
         for(auto start : iterator_for(reshapes))
         {
             auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                return i->result == (*start)->result and i != (*start);
+                return i->get_shape() == (*start)->get_shape() and i != (*start);
             });
             if(last != reshapes.rend())
             {
...