Commit bbbf98a4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add copy constructor and copy assignment operator of program

parent 849f7d92
......@@ -30,8 +30,19 @@ const operation& get_operation(instruction_ref ins);
struct program
{
program();
// move constructor
program(program&&) noexcept;
// copy constructor
program(const program&) noexcept;
// move assignment operator
program& operator=(program&&) noexcept;
// copy assignment operator
program& operator=(const program&) noexcept;
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
......@@ -118,6 +129,9 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
void copy(const program& prog);
private:
std::unique_ptr<program_impl> impl;
};
......
......@@ -89,6 +89,71 @@ program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default;
// copy constructor
program::program(const program& p) noexcept
{
copy(p);
}
// copy assignment operator
program& program::operator=(const program& p) noexcept
{
if (this != &p)
{
copy(p);
}
return *this;
}
void program::copy(const program& p)
{
// clean the current program
if (!impl)
{
impl = std::make_unique<program_impl>();
}
else if (!impl->instructions.empty())
{
remove_instructions(begin(), end());
}
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for (auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if (ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if (ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(), {builtin::param{std::move(name)}, std::move(s), {}});
}
else if (ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&] (auto i) { return ins_map[i]; });
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......
......@@ -2,6 +2,8 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -27,4 +29,41 @@ TEST_CASE(program_equality)
EXPECT(x == y);
}
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, p1);
p.add_instruction(migraphx::op::mul{}, sum, po);
return p;
};
{
auto p1 = create_program_1();
auto p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
p2 = p1;
EXPECT(p1 == p2);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment