"vscode:/vscode.git/clone" did not exist on "7248a810d97ca8ceb999cc0a9e2bf58adc68f263"
Commit 08950f35 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into int8_miopen_call

parents e862412f 1af75182
...@@ -72,7 +72,9 @@ struct instruction ...@@ -72,7 +72,9 @@ struct instruction
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
argument eval() const; bool can_eval() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx); void finalize(context& ctx);
......
...@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins); ...@@ -30,8 +30,16 @@ const operation& get_operation(instruction_ref ins);
struct program struct program
{ {
program(); program();
// move constructor
program(program&&) noexcept; program(program&&) noexcept;
program& operator=(program&&) noexcept;
// copy constructor
program(const program&);
// copy assignment operator
program& operator=(program);
~program() noexcept; ~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>; using parameter_map = std::unordered_map<std::string, argument>;
...@@ -118,6 +126,9 @@ struct program ...@@ -118,6 +126,9 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
void assign(const program& p);
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
......
...@@ -162,7 +162,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) ...@@ -162,7 +162,24 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this); old->remove_output(*this);
} }
argument instruction::eval() const bool instruction::can_eval() const
{
if(op.name() == "@literal")
{
return true;
}
else if(is_context_free(op))
{
return std::all_of(
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
}
else
{
return false;
}
}
argument instruction::eval(bool check_eval) const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
{ {
...@@ -170,14 +187,13 @@ argument instruction::eval() const ...@@ -170,14 +187,13 @@ argument instruction::eval() const
} }
if(is_context_free(op)) if(is_context_free(op))
{ {
if(check_eval and not this->can_eval())
return {};
std::vector<argument> args; std::vector<argument> args;
for(auto&& arg : this->inputs()) std::transform(this->inputs().begin(),
{ this->inputs().end(),
argument a = arg->eval(); std::back_inserter(args),
if(a.empty()) [](auto arg) { return arg->eval(false); });
return {};
args.push_back(a);
}
return op.compute(result, args); return op.compute(result, args);
} }
return {}; return {};
......
...@@ -86,8 +86,70 @@ static void print_program(const program& p, F print_func) ...@@ -86,8 +86,70 @@ static void print_program(const program& p, F print_func)
program::program() : impl(std::make_unique<program_impl>()) {} program::program() : impl(std::make_unique<program_impl>()) {}
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default; program::~program() noexcept = default;
program::~program() noexcept = default;
// copy constructor
program::program(const program& p) { assign(p); }
// copy assignment operator
program& program::operator=(program p)
{
std::swap(p.impl, this->impl);
return *this;
}
void program::assign(const program& p)
{
// clean the current program
if(!impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->instructions.empty())
{
impl->instructions.clear();
}
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args) instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{ {
......
...@@ -22,22 +22,32 @@ bool skip_propogate(instruction_ref ins) ...@@ -22,22 +22,32 @@ bool skip_propogate(instruction_ref ins)
void propagate_constant::apply(program& p) const void propagate_constant::apply(program& p) const
{ {
fix([&](auto self, auto ins) { for(auto i : iterator_for(p))
if(not skip_propogate(ins)) {
{ if(i->name() != "@literal")
auto r = ins->eval(); continue;
if(not r.empty()) if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{ {
assert(r.get_shape() == ins->get_shape()); if(skip_propogate(child))
auto l = p.add_literal(r.get_shape(), r.data()); {
p.replace_instruction(ins, l); self(child);
return; continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
}
} }
} })(i);
std::unordered_set<instruction_ref> children(ins->inputs().begin(), ins->inputs().end()); }
for(auto child : children)
self(child);
})(std::prev(p.end()));
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -878,7 +878,7 @@ template <typename Op> ...@@ -878,7 +878,7 @@ template <typename Op>
struct cpu_binary struct cpu_binary
{ {
Op op; Op op;
std::string name() const { return op.name(); } std::string name() const { return "cpu::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); } shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -27,4 +31,78 @@ TEST_CASE(program_equality) ...@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y); EXPECT(x == y);
} }
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, po);
p.add_instruction(migraphx::op::mul{}, sum, p1);
return p;
};
{
auto p1 = create_program_1();
migraphx::program p2{};
p2 = p1;
p2.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
migraphx::program p1;
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
migraphx::shape s3{migraphx::shape::float_type, {2, 6}};
auto para1 = p1.add_parameter("m1", s1);
auto para2 = p1.add_parameter("m2", s2);
auto para3 = p1.add_parameter("m3", s3);
p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p2 == p1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment