"docs/vscode:/vscode.git/clone" did not exist on "eb39749fa235446ef7960f400bbf4c5de903000f"
Commit 15385fb1 authored by Shucai Xiao

merge changes from develop branch

parents f7f02979 b606ed4f
@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
 struct program
 {
     program();
+    // move constructor
     program(program&&) noexcept;
-    program& operator=(program&&) noexcept;
+    // copy constructor
+    program(const program&);
+    // copy assignment operator
+    program& operator=(program);
     ~program() noexcept;

     using parameter_map = std::unordered_map<std::string, argument>;
@@ -108,6 +116,7 @@ struct program
     void debug_print() const;
     void debug_print(instruction_ref ins) const;
     void debug_print(const std::vector<instruction_ref>& inss) const;
+    void print_graph(std::ostream& os) const;

     void dry_run(parameter_map params) const;
@@ -117,6 +126,9 @@ struct program
     friend bool operator==(const program& x, const program& y);
     friend bool operator!=(const program& x, const program& y) { return !(x == y); }

+    private:
+    void assign(const program& p);
+
     private:
     std::unique_ptr<program_impl> impl;
 };
...
-#ifndef MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
-#define MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
+#define MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
 #include <string>
 #include <migraphx/config.hpp>
@@ -12,9 +12,9 @@ struct program;
 /**
  * Replace instructions which take all literals with a literal of the computation.
  */
-struct constant_propagate
+struct propagate_constant
 {
-    std::string name() const { return "constant_propagate"; }
+    std::string name() const { return "propagate_constant"; }
     void apply(program& p) const;
 };
...
@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace detail {

 template <class String, class T>
-auto generic_find_impl(rank<2>, String&& s, const T& x) -> decltype(s.begin() + s.find(x), s.npos)
+auto generic_find_impl(rank<2>, String&& s, const T& x) -> decltype(s.npos, s.begin() + s.find(x))
 {
     auto index = s.find(x);
     if(index == s.npos)
...
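The swapped decltype operands in generic_find_impl are not cosmetic: with the comma operator, decltype(a, b) discards a and yields the type of b, while both expressions must still be well-formed for SFINAE. The old order deduced the type of s.npos; the new order deduces the iterator type that the function body actually returns. A minimal standalone sketch of the rule (std::string here is just an example type):

    #include <string>
    #include <type_traits>
    #include <utility>

    // decltype of a comma expression is the type of the *last* operand, so
    // putting the iterator expression last makes it the deduced return type
    // while s.npos is still checked for validity.
    using it_t = decltype(std::declval<std::string&>().npos,
                          std::declval<std::string&>().begin());
    static_assert(std::is_same<it_t, std::string::iterator>::value,
                  "the comma operand order decides the deduced type");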
@@ -38,8 +38,9 @@ inline std::string join_strings(Strings strings, const std::string& delim)
         return "";
     auto nit = std::next(it);
-    return std::accumulate(
-        nit, strings.end(), *it, [&](std::string x, std::string y) { return x + delim + y; });
+    return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
+        return std::move(x) + delim + std::move(y);
+    });
 }

 template <class F>
...
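The join_strings change is more than reformatting: returning std::move(x) + delim + std::move(y) lets operator+ append into the existing string buffers instead of allocating a fresh temporary at every fold step. A standalone sketch of the same fold (the free function name join is made up for illustration):

    #include <iostream>
    #include <numeric>
    #include <string>
    #include <vector>

    // Same shape as join_strings above: fold the range with accumulate,
    // moving the accumulated string through each step.
    std::string join(const std::vector<std::string>& strings, const std::string& delim)
    {
        auto it = strings.begin();
        if(it == strings.end())
            return "";
        return std::accumulate(std::next(it), strings.end(), *it,
                               [&](std::string x, std::string y) {
                                   return std::move(x) + delim + std::move(y);
                               });
    }

    int main() { std::cout << join({"a", "b", "c"}, ", ") << std::endl; } // a, b, c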
@@ -162,7 +162,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
     old->remove_output(*this);
 }

-argument instruction::eval() const
+bool instruction::can_eval() const
+{
+    if(op.name() == "@literal")
+    {
+        return true;
+    }
+    else if(is_context_free(op))
+    {
+        return std::all_of(
+            this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
+    }
+    else
+    {
+        return false;
+    }
+}
+
+argument instruction::eval(bool check_eval) const
 {
     if(op.name() == "@literal")
     {
@@ -170,14 +187,13 @@ argument instruction::eval() const
     }
     if(is_context_free(op))
     {
+        if(check_eval and not this->can_eval())
+            return {};
         std::vector<argument> args;
-        for(auto&& arg : this->inputs())
-        {
-            argument a = arg->eval();
-            if(a.empty())
-                return {};
-            args.push_back(a);
-        }
+        std::transform(this->inputs().begin(),
+                       this->inputs().end(),
+                       std::back_inserter(args),
+                       [](auto arg) { return arg->eval(false); });
         return op.compute(result, args);
     }
     return {};
...
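Splitting eval into can_eval plus eval(check_eval) separates "is this subtree computable from literals?" from the computation itself: can_eval walks the inputs once up front, and the recursive eval(false) calls then fold without re-testing every intermediate result for emptiness, as the old loop did. A toy standalone sketch of the same structure (node, int values, and the sum "computation" are stand-ins, not migraphx types):

    #include <algorithm>
    #include <memory>
    #include <vector>

    struct node
    {
        bool is_literal = false;
        int value       = 0; // stands in for migraphx::argument
        std::vector<std::shared_ptr<node>> inputs;

        bool can_eval() const
        {
            if(is_literal)
                return true;
            // toy simplification: a non-literal leaf is not evaluable
            return not inputs.empty() and
                   std::all_of(inputs.begin(), inputs.end(),
                               [](const auto& n) { return n->can_eval(); });
        }

        int eval(bool check_eval = true) const
        {
            if(check_eval and not can_eval())
                return 0; // stands in for returning an empty argument
            if(is_literal)
                return value;
            int sum = 0; // toy "computation": sum of the inputs
            for(const auto& n : inputs)
                sum += n->eval(false); // inputs were already checked once
            return sum;
        }
    };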
@@ -32,7 +32,7 @@ auto read_cifar10_images(const std::string& full_path)
         labels[i] = *pimage++;
         for(size_t j = 0; j < nbytes_per_image; j++)
         {
-            float v                        = *(pimage + j) / 255.0f;
+            float v                        = float(*(pimage + j)) / 255.0f;
             data[i * nbytes_per_image + j] = v;
         }
     }
...
@@ -141,8 +141,8 @@ struct onnx_parser
         if(broadcasted != 0)
         {
             uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
-            auto l =
-                prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
+            auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
+                                          args[1]);
             return prog.add_instruction(x, args[0], l);
         }
         return prog.add_instruction(x, args);
@@ -207,7 +207,7 @@ struct onnx_parser
     template <class T>
     void add_generic_op(std::string name, T x)
     {
-        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
             return prog.add_instruction(x, args);
         });
     }
@@ -215,7 +215,7 @@ struct onnx_parser
     template <class T>
     void add_variadic_op(std::string name, T x)
     {
-        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
             return std::accumulate(std::next(args.begin()),
                                    args.end(),
                                    args.front(),
@@ -306,7 +306,7 @@ struct onnx_parser
     {
         uint64_t axis = 1;
         auto l1       = prog.add_instruction(op, args[0], args[1]);
-        auto l2       = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
+        auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
         return prog.add_instruction(op::add{}, l1, l2);
     }
     return prog.add_instruction(op, l0, args[1]);
@@ -670,15 +670,15 @@ struct onnx_parser
             auto&& bias_floats = attributes["bias"].floats();
             bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
         }
-        auto input_shape = args.front()->get_shape();
+        auto input_lens = args.front()->get_shape().lens();

         auto scale_val = prog.add_literal(scale);
         auto bias_vals = prog.add_literal(
             migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

-        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
+        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
         auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
-        auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals);
+        auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
         return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
     }
@@ -1360,28 +1360,26 @@ struct onnx_parser
     static literal parse_tensor(const onnx::TensorProto& t)
     {
         std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
-        if(dims.empty())
-        {
-            dims = {1};
-        }
         if(t.has_raw_data())
         {
             const std::string& s = t.raw_data();
             switch(t.data_type())
             {
             case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
-            case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
+            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
             case onnx::TensorProto::UINT8: throw std::runtime_error("");
-            case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
+            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
-            case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
+            case onnx::TensorProto::UINT16:
+                return create_literal(shape::int32_type, dims, s.data());
-            case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
+            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
-            case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
+            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
-            case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
+            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
             case onnx::TensorProto::STRING: throw std::runtime_error("");
-            case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
+            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
-            case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
+            case onnx::TensorProto::FLOAT16:
+                return create_literal(shape::half_type, dims, s.data());
-            case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
+            case onnx::TensorProto::DOUBLE:
+                return create_literal(shape::double_type, dims, s.data());
             case onnx::TensorProto::UINT32: throw std::runtime_error("");
             case onnx::TensorProto::UINT64: throw std::runtime_error("");
             case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
@@ -1393,21 +1391,21 @@ struct onnx_parser
         {
         case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
         case onnx::TensorProto::FLOAT:
-            return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+            return create_literal(shape::float_type, dims, t.float_data());
         case onnx::TensorProto::UINT8: throw std::runtime_error("");
         case onnx::TensorProto::INT8:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::UINT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT32:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::INT64:
-            return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+            return create_literal(shape::int64_type, dims, t.int64_data());
         case onnx::TensorProto::STRING: throw std::runtime_error("");
         case onnx::TensorProto::BOOL:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return create_literal(shape::int32_type, dims, t.int32_data());
         case onnx::TensorProto::FLOAT16:
         {
             std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
@@ -1416,11 +1414,10 @@ struct onnx_parser
                            data_uint16.end(),
                            std::back_inserter(data_half),
                            [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
-            return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
+            return create_literal(shape::half_type, dims, data_half);
         }
         case onnx::TensorProto::DOUBLE:
-            return literal{
-                {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
+            return create_literal(shape::double_type, dims, t.double_data());
         case onnx::TensorProto::UINT32: throw std::runtime_error("");
         case onnx::TensorProto::UINT64: throw std::runtime_error("");
         case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
@@ -1429,6 +1426,23 @@ struct onnx_parser
         MIGRAPHX_THROW("Invalid tensor type");
     }

+    static literal
+    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
+    {
+        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
+        if(dims.empty())
+            return literal{{shape_type}, data};
+        return literal{{shape_type, dims}, data};
+    }
+
+    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
+    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
+    {
+        if(dims.empty())
+            return literal{{shape_type}, data.begin(), data.end()};
+        return literal{{shape_type, dims}, data.begin(), data.end()};
+    }
+
     static shape parse_type(const onnx::TypeProto& t)
     {
         shape::type_t shape_type{};
...
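The dims = {1} workaround for scalar initializers moves out of parse_tensor and into create_literal, which now builds a genuine rank-0 shape when dims is empty. The second overload is constrained with MIGRAPHX_REQUIRES so containers (protobuf repeated fields, std::vector) go through the begin/end literal constructor while raw bytes keep the char* overload. A hedged sketch of how the overloads resolve (as if called from inside onnx_parser; the values are made up):

    std::vector<std::size_t> dims = {}; // a scalar constant in the onnx file
    float raw                     = 3.0f;
    // char* overload: empty dims -> rank-0 scalar shape instead of dims = {1}
    auto scalar_lit =
        create_literal(shape::float_type, dims, reinterpret_cast<const char*>(&raw));

    // container overload: anything with begin()/end(); selected because
    // MIGRAPHX_REQUIRES rejects pointer types for T
    std::vector<float> vals = {1.0f, 2.0f};
    auto vector_lit         = create_literal(shape::float_type, {2}, vals);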
@@ -63,11 +63,11 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
         }
     }

-    long long offset = 0;
+    std::size_t offset = 0;
     while(!conflict_queue.empty())
     {
         live_range* range = conflict_queue.top();
-        long long iter_offset = range->offset;
+        std::size_t iter_offset = range->offset;
         if(offset > iter_offset)
         {
             offset = std::max(offset, iter_offset + range->size);
@@ -97,7 +97,7 @@ void memory_coloring_impl::build()
     if(num_of_instrs == 0)
         return;

-    int cur_points        = num_of_instrs * 2;
+    auto cur_points       = num_of_instrs * 2;
     instruction_ref iter  = p_program->end();
     instruction_ref begin = p_program->begin();
     std::vector<instruction_ref> dead_instrs;
@@ -193,13 +193,13 @@ void memory_coloring_impl::rewrite()
             continue;

         std::size_t offset = 0;
-        if(interval->get_offset() == invalid_offset)
+        if(interval->get_offset() != invalid_offset)
         {
-            assert(interval->result.bytes() == 0);
+            offset = interval->get_offset();
         }
         else
         {
-            offset = interval->get_offset();
+            assert(interval->result.bytes() == 0);
         }

         if(is_allocate(ins))
@@ -207,15 +207,6 @@ void memory_coloring_impl::rewrite()
             p_program->replace_instruction(
                 ins, op::load{ins->get_shape(), offset}, scratch_param);
         }
-        else if(is_literal(ins))
-        {
-#if 0
-            auto pre = p_program->add_literal(ins->lit);
-            bool pre_copy = (interval->get_begin() < earliest_end_point);
-            p_program->replace_instruction(
-                ins, write_literal{offset, pre_copy}, scratch_param, pre);
-#endif
-        }
     }
     MIGRAPHX_DEBUG(dump("---After rewrite---"));
...
@@ -21,15 +21,15 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-static const int invalid_offset = -1;
+static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();

 struct live_range
 {
-    int begin;        // begin point in the instruction stream.
-    int end;          // end point in the instruction stream.
-    long long offset; // offset to base pointer of allocated memory trunk.
-    int vn;           // value number that identifies this live_range.
-    long long size;   // size of required memory in bytes
+    std::size_t begin;  // begin point in the instruction stream.
+    std::size_t end;    // end point in the instruction stream.
+    std::size_t offset; // offset to base pointer of allocated memory trunk.
+    std::size_t vn;     // value number that identifies this live_range.
+    std::size_t size;   // size of required memory in bytes
 #ifdef MIGRAPHX_DEBUG_OPT
     void dump();
 #endif
@@ -45,9 +45,9 @@ struct live_interval
         is_live_on_entry = false;
     }

-    void add_use(int use) { use_points.push_front(use); }
-    int get_begin() const { return segment.begin; }
-    int get_end() const { return segment.end; }
+    void add_use(std::size_t use) { use_points.push_front(use); }
+    std::size_t get_begin() const { return segment.begin; }
+    std::size_t get_end() const { return segment.end; }
     long long get_offset() const { return segment.offset; }

 #ifdef MIGRAPHX_DEBUG_OPT
@@ -55,9 +55,9 @@ struct live_interval
 #endif
     live_range segment;
-    int id;
-    std::list<int> use_points;
-    int def_point;
+    std::size_t id;
+    std::list<std::size_t> use_points;
+    std::size_t def_point;
     shape result;
     bool is_literal;
     bool is_live_on_entry;
@@ -111,8 +111,8 @@ struct memory_coloring_impl
     {
         if((range1.size == 0) || (range2.size == 0))
             return false;
-        long long end1 = range1.offset + range1.size - 1;
-        long long end2 = range2.offset + range2.size - 1;
+        auto end1 = range1.offset + range1.size - 1;
+        auto end2 = range2.offset + range2.size - 1;
         return ((end1 < range2.offset) || (end2 < range1.offset));
     }
     void verify();
@@ -125,8 +125,8 @@ struct memory_coloring_impl
    {
        bool operator()(const interval_ptr i1, const interval_ptr i2) const
        {
-            int len1 = i1->get_end() - i1->get_begin();
-            int len2 = i2->get_end() - i2->get_begin();
+            auto len1 = i1->get_end() - i1->get_begin();
+            auto len2 = i2->get_end() - i2->get_begin();
            if(len1 != len2)
            {
                return (len1 < len2);
@@ -158,7 +158,7 @@ struct memory_coloring_impl
     int num_of_lives;
     int max_value_number;
-    long long required_bytes;
+    std::size_t required_bytes;
     // The earliest program point where an live interval ends.
     int earliest_end_point;
     // The latest program point where an live interval ends.
...
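With every offset and size now std::size_t, the old sentinel of -1 no longer exists, so invalid_offset becomes the largest representable value; the inverted condition in rewrite() above keeps the same meaning under the new sentinel. A minimal standalone sketch of the pattern:

    #include <cstddef>
    #include <limits>

    // The unsigned analogue of "offset == -1": no valid offset can equal
    // SIZE_MAX here, so the comparison stays unambiguous. One caveat of the
    // switch to unsigned is that differences like end - begin now require
    // begin <= end, which the interval invariants are expected to guarantee.
    static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();

    inline bool has_offset(std::size_t offset) { return offset != invalid_offset; }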
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
for(auto& p : passes)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
#ifndef NDEBUG
trace("Validate ...");
auto invalid = prog.validate();
if(invalid != prog.end())
{
auto index = std::distance(prog.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
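This new file factors the pass loop, including the debug-build validation after each pass, out of program::compile so other drivers can run a pipeline the same way. A hedged usage sketch, assuming an existing program and a pass list obtained from a target (apply_target_passes is a made-up wrapper name):

    #include <migraphx/pass_manager.hpp>
    #include <iostream>
    #include <vector>

    void apply_target_passes(migraphx::program& prog, std::vector<migraphx::pass> passes)
    {
        // tracer{std::cout} logs each pass name and the program after it,
        // exactly as compile() does below when tracing is enabled
        migraphx::run_passes(prog, passes, migraphx::tracer{std::cout});
    }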
@@ -7,6 +7,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/time.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/pass_manager.hpp>
 #include <iostream>
 #include <sstream>
 #include <algorithm>
@@ -55,7 +56,7 @@ static void print_instruction(std::ostream& os,
 }

 template <class F>
-static void print_program(std::ostream& os, const program& p, F annonate)
+static void print_program(const program& p, F print_func)
 {
     std::unordered_map<instruction_ref, std::string> names;
     int count = 0;
@@ -76,11 +77,7 @@ static void print_program(const program& p, F print_func)
             (void)arg;
         }

-        print_instruction(os, ins, names);
-        annonate(ins, names);
-        os << std::endl;
+        print_func(ins, names);

         count++;
     }
@@ -89,8 +86,70 @@ static void print_program(const program& p, F print_func)
 program::program() : impl(std::make_unique<program_impl>()) {}
 program::program(program&&) noexcept = default;
-program& program::operator=(program&&) noexcept = default;
 program::~program() noexcept = default;

+// copy constructor
+program::program(const program& p) { assign(p); }
+
+// copy assignment operator
+program& program::operator=(program p)
+{
+    std::swap(p.impl, this->impl);
+    return *this;
+}
+
+void program::assign(const program& p)
+{
+    // clean the current program
+    if(!impl)
+    {
+        impl = std::make_unique<program_impl>();
+    }
+    else if(!impl->instructions.empty())
+    {
+        impl->instructions.clear();
+    }
+    impl->ctx = p.impl->ctx;
+
+    std::unordered_map<instruction_ref, instruction_ref> ins_map;
+    for(auto ins : iterator_for(p))
+    {
+        instruction_ref copy_ins{};
+        if(ins->name() == "@literal")
+        {
+            auto l   = ins->get_literal();
+            copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
+        }
+        else if(ins->name() == "@param")
+        {
+            auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
+            auto s      = ins->get_shape();
+            copy_ins    = impl->instructions.insert(impl->instructions.end(),
+                                                 {builtin::param{name}, std::move(s), {}});
+        }
+        else if(ins->name() == "@outline")
+        {
+            auto s   = ins->get_shape();
+            copy_ins =
+                impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
+        }
+        else
+        {
+            // retrieve its mapped input
+            auto inputs = ins->inputs();
+            // ensure all inputs have its corresponding copy instructions
+            assert(std::all_of(
+                inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
+            std::vector<instruction_ref> copy_inputs(inputs.size());
+            std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
+                return ins_map[i];
+            });
+            copy_ins = add_instruction(ins->get_operator(), copy_inputs);
+        }
+        ins_map[ins] = copy_ins;
+    }
+}
+
 instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
 {
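operator=(program) takes its parameter by value, which is the copy-and-swap idiom: the copy (or move) happens while constructing the parameter, and the assignment body only swaps impl pointers, so a single overload covers both copy and move assignment with the strong exception guarantee. A standalone sketch of the same idiom on a made-up pimpl class:

    #include <memory>
    #include <utility>

    struct widget_impl { int state = 0; };

    struct widget
    {
        widget() : impl(std::make_unique<widget_impl>()) {}
        widget(widget&&) noexcept = default;
        widget(const widget& w) : impl(std::make_unique<widget_impl>(*w.impl)) {}
        widget& operator=(widget w) // by value: covers copy and move assignment
        {
            std::swap(w.impl, this->impl); // the body never throws
            return *this;
        }
        ~widget() noexcept = default;

        private:
        std::unique_ptr<widget_impl> impl;
    };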
@@ -291,23 +350,7 @@ void program::compile(const target& t, tracer trace)
         trace = tracer{std::cout};

     trace(*this);
     trace();
-    for(auto&& p : t.get_passes(this->impl->ctx))
-    {
-        trace("Pass: ", p.name());
-        p.apply(*this);
-        trace(*this);
-#ifndef NDEBUG
-        trace("Validate ...");
-        auto invalid = this->validate();
-        if(invalid != impl->instructions.end())
-        {
-            auto index = std::distance(impl->instructions.begin(), invalid);
-            MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
-                           std::to_string(index) + ": " + invalid->name());
-        }
-        trace();
-#endif
-    }
+    run_passes(*this, t.get_passes(this->impl->ctx), trace);

     auto invalid = this->validate();
     if(invalid != impl->instructions.end())
     {
@@ -475,10 +518,12 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
     double calculate_overhead_time    = total_time - total_instruction_time;
     double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;

-    print_program(os, *this, [&](auto ins, auto&&) {
+    print_program(*this, [&](auto ins, const auto& names) {
+        print_instruction(std::cout, ins, names);
         double avg     = common_average(ins_vec[ins]);
         double percent = std::ceil(100.0 * avg / total_instruction_time);
         os << ": " << avg << "ms, " << percent << "%";
+        os << std::endl;
     });
     os << std::endl;
@@ -516,7 +561,7 @@ void program::debug_print(instruction_ref ins) const
         return;
     }
     std::stringstream ss;
-    print_program(ss, *this, [&](auto x, auto&& names) {
+    print_program(*this, [&](auto x, const auto& names) {
         if(x == ins)
         {
             print_instruction(std::cout, x, names);
@@ -531,6 +576,32 @@ void program::debug_print(const std::vector<instruction_ref>& inss) const
     std::cout << std::endl;
 }

+static std::string enclose_name(const std::string& name)
+{
+    return '"' + replace_string(name, "\"", "\\\"") + '"';
+}
+
+void program::print_graph(std::ostream& os) const
+{
+    os << "digraph {" << std::endl;
+    os << "\trankdir=LR;" << std::endl;
+    print_program(*this, [&](auto ins, const auto& names) {
+        os << "\t" << enclose_name(names.at(ins))
+           << "[label=" << enclose_name(to_string(ins->get_operator())) << "];";
+        os << std::endl;
+        if(!ins->inputs().empty())
+        {
+            for(auto&& arg : ins->inputs())
+            {
+                os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
+                os << "[label=" << enclose_name(to_string(ins->get_shape())) << "];";
+                os << std::endl;
+            }
+        }
+    });
+    os << "}" << std::endl;
+}
+
 void program::dry_run(std::unordered_map<std::string, argument> params) const
 {
     auto& ctx = this->impl->ctx;
@@ -539,14 +610,21 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const

 void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
 {
-    print_program(os, *this, [&](auto ins, auto&&) { a(ins); });
+    print_program(*this, [&](auto ins, const auto& names) {
+        print_instruction(os, ins, names);
+        a(ins);
+        os << std::endl;
+    });
 }

 bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }

 std::ostream& operator<<(std::ostream& os, const program& p)
 {
-    print_program(os, p, [](auto&&...) {});
+    print_program(p, [&](auto ins, const auto& names) {
+        print_instruction(os, ins, names);
+        os << std::endl;
+    });
     return os;
 }
...
#include <migraphx/propagate_constant.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "@literal")
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
if(s.scalar() and s.elements() != 1)
return true;
return false;
}
void propagate_constant::apply(program& p) const
{
for(auto i : iterator_for(p))
{
if(i->name() != "@literal")
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
if(skip_propogate(child))
{
self(child);
continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
}
}
})(i);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
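The pass walks forward from each literal and, wherever a downstream instruction evaluates entirely from literals (child->eval() returns a non-empty argument), replaces it with a single precomputed literal; skip_propogate avoids materializing broadcast shapes and oddly sized scalars. A hedged sketch of the observable effect (add_literal with a plain float mirrors the usage in the onnx parser above):

    #include <migraphx/program.hpp>
    #include <migraphx/propagate_constant.hpp>
    #include <migraphx/operators.hpp>

    // Before the pass: @0 = 1.0f, @1 = 2.0f, @2 = add(@0, @1)
    // After the pass:  the add is replaced by a literal holding 3.0f,
    // so nothing is computed at run time
    migraphx::program make_folded()
    {
        migraphx::program p;
        auto one = p.add_literal(1.0f);
        auto two = p.add_literal(2.0f);
        p.add_instruction(migraphx::op::add{}, one, two);
        migraphx::propagate_constant{}.apply(p);
        return p;
    }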
@@ -4,6 +4,7 @@
 #include <migraphx/operators.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/dfor.hpp>
+#include <migraphx/op/common.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -204,6 +205,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();

     // bias
     instruction_ref bwb{};
@@ -214,8 +216,8 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        bwb        = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wb);
-        brb        = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rb);
+        bwb        = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wb);
+        brb        = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rb);
     }

     instruction_ref hidden_out = prog.end();
@@ -514,6 +516,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     // initial states
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();

     // bias
     instruction_ref bwbz{};
@@ -528,16 +531,16 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
         auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        bwbz     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbz);
-        bwbr     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbr);
-        bwbh     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
+        bwbz     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbz);
+        bwbr     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbr);
+        bwbh     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbh);

         auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
         auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
         auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brbz     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbz);
-        brbr     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbr);
-        brbh     = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
+        brbz     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbz);
+        brbr     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbr);
+        brbh     = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbh);
     }

     for(long i = 0; i < seq_len; i++)
@@ -960,8 +963,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
     // initial cell state
     auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
-    auto ic_shape = sic->get_shape();
+    auto ic_lens = sic->get_shape().lens();

     // bias
     instruction_ref wbi_brcst{};
@@ -974,26 +977,27 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     instruction_ref rbc_brcst{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);

         auto wbi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rbi  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbi);
-        rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbi);
+        wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbi);
+        rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbi);

         auto wbo  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
         auto rbo  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbo);
-        rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbo);
+        wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbo);
+        rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbo);

         auto wbf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
         auto rbf  = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbf);
-        rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbf);
+        wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbf);
+        rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbf);

         auto wbc  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
         auto rbc  = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbc);
-        rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbc);
+        wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbc);
+        rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbc);
     }

     // peep hole
@@ -1004,13 +1008,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
         auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
-        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
+        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);

         auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
-        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
+        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);

         auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
-        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
+        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
     }

     for(long i = 0; i < seq_len; ++i)
@@ -1201,5 +1205,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
     }
 }

+namespace op {
+std::ostream& operator<<(std::ostream& os, rnn_direction v)
+{
+    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
+    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
+    return os;
+}
+} // namespace op
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -117,7 +117,7 @@ struct cpu_lrn
         int channels        = output_shape.lens()[1];
         int height          = output_shape.lens()[2];
         int width           = output_shape.lens()[3];
-        float alphaoverarea = op.alpha / op.size;
+        float alphaoverarea = op.alpha / float(op.size);
         int radius          = (op.size - 1) / 2;

         par_dfor(n_batch, height, width)([&](int b, int h, int w) {
@@ -165,15 +165,15 @@ struct cpu_convolution
                  output_shape.lens()[2],
                  output_shape.lens()[3])(
             [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
-                const int start_x  = i * op.stride[0] - op.padding[0];
-                const int start_y  = j * op.stride[1] - op.padding[1];
-                const int group_id = w / (wei_n / op.group);
+                const auto start_x  = i * op.stride[0] - op.padding[0];
+                const auto start_y  = j * op.stride[1] - op.padding[1];
+                const auto group_id = w / (wei_n / op.group);

                 double acc = 0;
                 dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
-                    const int in_x  = start_x + x;
-                    const int in_y  = start_y + y;
-                    const int in_ch = group_id * wei_c + k;
+                    const auto in_x  = start_x + x;
+                    const auto in_y  = start_y + y;
+                    const auto in_ch = group_id * wei_c + k;
                     if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
                     {
                         acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
@@ -209,10 +209,8 @@ struct cpu_im2col
         const std::size_t& stride_h = op.stride[0];
         const std::size_t& stride_w = op.stride[1];

-        int kdiv2_h;
-        int kdiv2_w;
-        kdiv2_h = kernel_h / 2;
-        kdiv2_w = kernel_w / 2;
+        auto kdiv2_h = kernel_h / 2;
+        auto kdiv2_w = kernel_w / 2;
         // calculate output sizes
         const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
         const std::size_t col_width  = (width - kernel_w + 2 * pad_w) / stride_w + 1;
@@ -230,8 +228,8 @@ struct cpu_im2col
             dfor(channels,
                  kernel_h,
                  kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) {
-                int idx = iinput + koffset - kdiv2_h;
-                int jdx = jinput + loffset - kdiv2_w;
+                auto idx = iinput + koffset - kdiv2_h;
+                auto jdx = jinput + loffset - kdiv2_w;
                 col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width))
                                   ? input(0, c, idx, jdx)
                                   : 0;
@@ -637,15 +635,38 @@ struct cpu_unary
 {
     Op op;
     std::string name() const { return op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs}.has(1);
+        auto s = inputs.at(0);
+        if(s.packed())
+        {
+            return s;
+        }
+        else
+        {
+            return {s.type(), s.lens()};
+        }
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
-                std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                if(input.get_shape().standard())
+                {
+                    std::transform(input.begin(), input.end(), output.begin(), op.fcn());
+                }
+                else
+                {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = op.fcn()(input(idx.begin(), idx.end()));
+                    });
+                }
             });
         });

        return result;
    }
 };
@@ -665,20 +686,20 @@ struct softmax2d
         auto nw = input.get_shape().lens()[3];
         dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
             value_type cmax = std::numeric_limits<value_type>::lowest();
-            for(int c = 0; c < nc; c++)
+            for(std::size_t c = 0; c < nc; c++)
             {
                 cmax = std::max(cmax, input(b, c, i, j));
             }
-            for(int c = 0; c < nc; c++)
+            for(std::size_t c = 0; c < nc; c++)
             {
                 output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
             }
             value_type sum = value_type(0);
-            for(int c = 0; c < nc; c++)
+            for(std::size_t c = 0; c < nc; c++)
             {
                 sum += output(b, c, i, j);
             }
-            for(int c = 0; c < nc; c++)
+            for(std::size_t c = 0; c < nc; c++)
             {
                 output(b, c, i, j) = output(b, c, i, j) / sum;
             }
@@ -815,13 +836,29 @@ template <typename Op>
 struct cpu_binary
 {
     Op op;
-    std::string name() const { return op.name(); }
+    std::string name() const { return "cpu::" + op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs}.has(2).same_type().same_dims();
+        auto s0 = inputs.at(0);
+        auto s1 = inputs.at(1);
+        if(s0 == s1 and s0.packed())
+        {
+            return s0;
+        }
+        else
+        {
+            return {s0.type(), s0.lens()};
+        }
+    }
+
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(input1.get_shape().packed() and input2.get_shape().packed())
+            auto s1 = input1.get_shape();
+            auto s2 = input2.get_shape();
+            if(s1 == s2 and s1.standard())
             {
                 std::transform(
                     input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
@@ -834,6 +871,7 @@ struct cpu_binary
             }
         });
         return result;
     }
 };
...
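Both cpu_unary and cpu_binary now compute an explicit output shape: a packed input keeps its layout, anything else gets a fresh standard (row-major, contiguous) shape with the same lens, and compute() only takes the flat std::transform fast path when the input layout really is standard, falling back to the shape_for_each multi-index traversal otherwise. A small sketch of the shared normalization rule (normalize is a made-up name):

    #include <migraphx/shape.hpp>

    // packed: the elements fill the buffer with no gaps (in any stride order);
    // standard: packed *and* in the default row-major stride order
    migraphx::shape normalize(const migraphx::shape& s)
    {
        if(s.packed())
            return s;                // keep the existing packed layout
        return {s.type(), s.lens()}; // rebuild as a standard shape
    }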
@@ -65,6 +65,7 @@ add_library(migraphx_gpu
     gather.cpp
     lrn.cpp
     schedule_model.cpp
+    adjust_allocation.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
...
@@ -7,8 +7,8 @@ namespace gpu {

 shape miopen_abs::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).not_broadcasted();
-    return inputs.at(1);
+    check_shapes{inputs, *this}.has(2).packed();
+    return inputs.at(0);
 }

 argument miopen_abs::compute(context& ctx,
...
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void adjust_allocation::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// skip instruction with no input
if(ins->inputs().empty())
continue;
if(ins->name() == "load")
continue;
auto alias_ins = instruction::get_output_alias(ins, true);
if(alias_ins->name() == "hip::allocate")
{
// shape allocated is different from actual shape
// of the instruction, reallocate and replace the previous one
if(alias_ins->get_shape() != ins->get_shape())
{
auto alloc_ins = p.insert_instruction(ins, hip_allocate{ins->get_shape()});
p.replace_instruction(alias_ins, alloc_ins);
}
}
}
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -16,7 +16,7 @@ argument gather(hipStream_t stream,
                 std::vector<migraphx::argument> args,
                 int axis)
 {
-    int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
+    auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
     visit_all(args.back(), args[0])([&](auto output, auto input) {
         std::size_t nelements = output_shape.elements();
         args[1].visit([&](auto indices) {
...
@@ -162,7 +162,10 @@ struct hip_triadd
         device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
         return args.at(3);
     }
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 struct hip_triadd_relu
@@ -178,7 +181,10 @@ struct hip_triadd_relu
         device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
         return args.at(3);
     }
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 struct hip_add_relu
@@ -194,7 +200,10 @@ struct hip_add_relu
         device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
         return args.at(2);
     }
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 struct find_add_relu
@@ -285,7 +294,10 @@ struct miopen_conv_bias
     void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
     shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 struct miopen_conv_bias_relu
@@ -332,7 +344,10 @@ struct miopen_conv_bias_relu
     }
     void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
     shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 template <class... Ms>
...
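The repeated output_alias widening from int to std::ptrdiff_t fits the operation interface: the value is an index into the inputs (here always the last one, the pre-allocated output buffer), and in migraphx a negative result is conventionally "no alias", so a signed pointer-sized type avoids both the implicit narrowing from shapes.size() - 1 and overflow on large sizes. The same one-liner, spelled out:

    #include <cstddef>
    #include <vector>

    // shapes.size() is unsigned; casting before the subtraction keeps the
    // arithmetic signed, so an empty list would yield -1 (the conventional
    // "no alias" value) instead of wrapping around to a huge index.
    std::ptrdiff_t last_input_alias(const std::vector<int>& shapes)
    {
        return static_cast<std::ptrdiff_t>(shapes.size()) - 1;
    }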
@@ -17,7 +17,10 @@ struct miopen_abs
     shape compute_shape(const std::vector<shape>& inputs) const;
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
-    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
 };

 } // namespace gpu
...