Commit 08950f35 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into int8_miopen_call

parents e862412f 1af75182
......@@ -72,7 +72,9 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const;
bool can_eval() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......
......@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
struct program
{
program();
// move constructor
program(program&&) noexcept;
program& operator=(program&&) noexcept;
// copy constructor
program(const program&);
// copy assignment operator
program& operator=(program);
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
......@@ -118,6 +126,9 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
void assign(const program& p);
private:
std::unique_ptr<program_impl> impl;
};
......
......@@ -162,7 +162,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
argument instruction::eval() const
bool instruction::can_eval() const
{
if(op.name() == "@literal")
{
return true;
}
else if(is_context_free(op))
{
return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
}
else
{
return false;
}
}
argument instruction::eval(bool check_eval) const
{
if(op.name() == "@literal")
{
......@@ -170,14 +187,13 @@ argument instruction::eval() const
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
if(check_eval and not this->can_eval())
return {};
args.push_back(a);
}
std::vector<argument> args;
std::transform(this->inputs().begin(),
this->inputs().end(),
std::back_inserter(args),
[](auto arg) { return arg->eval(false); });
return op.compute(result, args);
}
return {};
......
......@@ -86,9 +86,71 @@ static void print_program(const program& p, F print_func)
program::program() : impl(std::make_unique<program_impl>()) {}
program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default;
// copy constructor
program::program(const program& p) { assign(p); }
// copy assignment operator
program& program::operator=(program p)
{
std::swap(p.impl, this->impl);
return *this;
}
void program::assign(const program& p)
{
// clean the current program
if(!impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->instructions.empty())
{
impl->instructions.clear();
}
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......
......@@ -22,22 +22,32 @@ bool skip_propogate(instruction_ref ins)
void propagate_constant::apply(program& p) const
{
for(auto i : iterator_for(p))
{
if(i->name() != "@literal")
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
if(not skip_propogate(ins))
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
auto r = ins->eval();
if(skip_propogate(child))
{
self(child);
continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == ins->get_shape());
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(ins, l);
return;
self(p.replace_instruction(child, l));
}
}
std::unordered_set<instruction_ref> children(ins->inputs().begin(), ins->inputs().end());
for(auto child : children)
self(child);
})(std::prev(p.end()));
})(i);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -878,7 +878,7 @@ template <typename Op>
struct cpu_binary
{
Op op;
std::string name() const { return op.name(); }
std::string name() const { return "cpu::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
......
......@@ -2,6 +2,10 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y);
}
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, po);
p.add_instruction(migraphx::op::mul{}, sum, p1);
return p;
};
{
auto p1 = create_program_1();
migraphx::program p2{};
p2 = p1;
p2.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
migraphx::program p1;
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
migraphx::shape s3{migraphx::shape::float_type, {2, 6}};
auto para1 = p1.add_parameter("m1", s1);
auto para2 = p1.add_parameter("m2", s2);
auto para3 = p1.add_parameter("m3", s3);
p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p2 == p1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment