Unverified commit 7d972d2b authored by Scott Thornton, committed by GitHub

Merge pull request #105 from ROCmSoftwarePlatform/remove_concat

Remove concat
parents 1cdb49a6 4afdd0e9
@@ -6,6 +6,7 @@ add_library(migraph
dead_code_elimination.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
......
#include <iterator>
#include <migraph/eliminate_concat.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
void eliminate_concat::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
continue;
// If any inputs are literals then abort
if(std::any_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
continue;
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
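// Example: inputs of shape {1, 2, 8, 8}, {1, 3, 8, 8} and {1, 5, 8, 8}
// concatenated along axis 1 qualify, because the only dimension left of the
// axis is 1 and each input then occupies one contiguous byte range of the
// {1, 10, 8, 8} output. With a leading dimension of 2 (see the wont_work
// test below) the inputs would interleave in memory and the concat is left alone.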
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if(concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++)
{
auto last2 = (*ins2)->inputs().back();
if(last2->name() == concat_opt.allocate())
{
allocations.push_back(last2);
}
}
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
std::sort(
allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front
auto first = allocations.front();
auto super = p.move_instruction(last, first);
std::size_t offset = 0;
for(auto x : allocations)
{
migraph::op::load op{x->get_shape(), offset};
p.replace_instruction(x, op, {super});
offset += x->get_shape().bytes();
}
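// Replace the concat itself with an identity over the "super" allocation.
// identity aliases its first argument, so readers of the old concat output
// now read the super buffer directly; the remaining args keep the original
// producers (which now write through the load aliases) as dependencies.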
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args);
}
}
}
} // namespace migraph
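A minimal sketch (not part of the diff) of the rewrite this pass performs, using the shapes and the placeholder allocate/fred_op operators from the basic test further down:

// Before eliminate_concat:               After eliminate_concat:
//   a1 = allocate({1,2,8,8})               super = allocate({1,10,8,8})
//   p1 = fred_op(a1)                       l1 = load({1,2,8,8}, 0)(super)
//   a2 = allocate({1,3,8,8})               p1 = fred_op(l1)
//   p2 = fred_op(a2)                       l2 = load({1,3,8,8}, 512)(super)
//   a3 = allocate({1,5,8,8})               p2 = fred_op(l2)
//   p3 = fred_op(a3)                       l3 = load({1,5,8,8}, 1280)(super)
//   a4 = allocate({1,10,8,8})              p3 = fred_op(l3)
//   concat(axis=1)(p1, p2, p3, a4)         identity(super, p1, p2, p3)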
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// The name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
struct concat_optimization
{
// Constructors
concat_optimization() = default;
template <typename PrivateDetailTypeErasedT>
concat_optimization(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
concat_optimization& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
std::string allocate() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().allocate();
}
op::concat get_concat(const operation& op) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_concat(op);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::string allocate() const = 0;
virtual op::concat get_concat(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
std::string allocate() const override { return private_detail_te_value.allocate(); }
op::concat get_concat(const operation& op) const override
{
return private_detail_te_value.get_concat(op);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(concat_optimization& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const concat_optimization& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/concat_opt.hpp>
namespace migraph {
struct program;
struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
@@ -603,9 +603,14 @@ struct unary
}
};
struct identity
{
std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
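// The returned argument shares the input's data buffer rather than copying
// it, so identity is a zero-copy alias of its first input (memory coloring
// registers operand_alias["identity"] = 0 below for the same reason).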
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
};
struct abs : unary
......
@@ -193,6 +193,7 @@ void memory_coloring_impl::register_operand_alias()
operand_alias["transpose"] = 0;
operand_alias["flatten"] = 0;
operand_alias["broadcast"] = 0;
operand_alias["identity"] = 0;
operand_alias["reshape"] = 0;
operand_alias["pass"] = 0;
operand_alias["scalar"] = 0;
......
#ifndef MIGRAPH_GUARD_RTGLIB_CONCAT_GPU_OPT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONCAT_GPU_OPT_HPP
#include <migraph/gpu/concat.hpp>
namespace migraph {
namespace gpu {
struct concat_gpu_optimization
{
std::string name() const { return "gpu::concat"; }
std::string allocate() const { return "hip::allocate"; }
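// Recover the target-independent op::concat (and hence the concat axis) from
// the lowered gpu::hip_concat wrapper so the generic eliminate_concat pass
// can reason about it.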
migraph::op::concat get_concat(const migraph::operation& op) const
{
return migraph::any_cast<migraph::gpu::hip_concat>(op).op;
}
};
} // namespace gpu
} // namespace migraph
#endif
@@ -15,6 +15,8 @@
#include <migraph/eliminate_contiguous.hpp>
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/eliminate_concat.hpp>
#include <migraph/gpu/concat_gpu_opt.hpp>
namespace migraph {
namespace gpu {
@@ -38,6 +40,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
eliminate_contiguous{},
dead_code_elimination{},
fuse_ops{&ctx},
......
#include <migraph/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct concat
{
concat(std::size_t axis) { op.axis = axis; }
migraph::op::concat op;
std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
struct concat_test_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const { return "eliminate_concat::concat"; }
/// A unique name used to identify the allocate operator
std::string allocate() const { return "allocate"; }
/// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const
{
return migraph::any_cast<concat>(op).op;
}
};
struct eliminate_concat_target
{
std::size_t align = 32;
std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::eliminate_concat{concat_test_optimization{}},
migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(0);
return s;
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
struct fred_op
{
std::string name() const { return "fred_op"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(1);
return inputs.at(0);
}
migraph::argument compute(migraph::context&,
const migraph::shape&,
const std::vector<migraph::argument>& args) const
{
return args.at(0);
}
};
void basic()
{
auto create_test_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
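// The offsets below are byte offsets into a1: the {1, 2, 8, 8} float tensor
// occupies 2 * 8 * 8 * 4 = 512 bytes, so the {1, 3, 8, 8} slice starts at
// offset 512 and the {1, 5, 8, 8} slice at 512 + 3 * 8 * 8 * 4 = 1280.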
auto l1 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 1280},
{a1});
auto p3 = p.add_instruction(fred_op{}, l3);
p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
void wont_work()
{
auto create_test_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
int main()
{
basic();
wont_work();
}
@@ -697,6 +697,73 @@ struct test_concat2
}
};
struct test_concat_relu
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 0;
migraph::shape s0{migraph::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
auto r0 = p.add_instruction(migraph::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0);
return p;
}
};
void manual_identity()
{
migraph::program p;
std::vector<float> data0 = {0, 1, 2, 3};
migraph::shape s0{migraph::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0});
p.add_instruction(migraph::op::identity{}, l0);
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto result = migraph::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl;
}
void manual_test_concat_relu()
{
migraph::program p;
std::size_t axis = 0;
std::vector<float> data0 = {0, 1, 2, 3};
std::vector<float> data1 = {4, 5, 6, 7, 8, 9};
std::vector<float> data2 = {10, 11};
migraph::shape s0{migraph::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2});
auto r0 = p.add_instruction(migraph::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0);
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto result = migraph::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2
{
static migraph::instruction_ref
@@ -737,6 +804,7 @@ int main()
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
......
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// The name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
<%
interface('concat_optimization',
virtual('name', returns='std::string', const=True),
virtual('allocate', returns='std::string', const=True),
virtual('get_concat', returns='op::concat', op='const operation&', const=True)
)
%>
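// The interface() directive above is presumably expanded by the project's
// code generator into the type-erased concat_optimization class that appears
// earlier in this diff.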
#endif
} // namespace migraph
#endif