Commit 20b1d690 authored by Paul

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
@@ -11,6 +11,8 @@
#include <migraphx/context.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -34,10 +36,86 @@ struct target
* @return The context to be used during compilation and execution.
*/
context get_context() const;
/**
* @brief copy an argument to the current target.
*
* @param arg Input argument to be copied to the target
* @return Argument in the target.
*/
argument copy_to(const argument& arg) const;
/**
* @brief copy an argument from the current target.
*
* @param arg Input argument to be copied from the target
* @return Argument in the host.
*/
argument copy_from(const argument& arg) const;
/**
* @brief Allocate an argument based on the input shape
*
* @param s Shape of the argument to be allocated in the target
* @return Allocated argument in the target.
*/
argument allocate(const shape& s) const;
};
#else
template <class T>
auto target_allocate(rank<1>, T& x, const shape& s) -> decltype(x.allocate(s))
{
return x.allocate(s);
}
template <class T>
argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument target_allocate(T& x, const shape& s)
{
return target_allocate(rank<1>{}, x, s);
}
template <class T>
auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(arg))
{
return x.copy_to(arg);
}
template <class T>
argument copy_to_target(rank<0>, T&, const argument& arg)
{
return arg;
}
template <class T>
argument copy_to_target(T& x, const argument& arg)
{
return copy_to_target(rank<1>{}, x, arg);
}
template <class T>
auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_from(arg))
{
return x.copy_from(arg);
}
template <class T>
argument copy_from_target(rank<0>, T&, const argument& arg)
{
return arg;
}
template <class T>
argument copy_from_target(T& x, const argument& arg)
{
return copy_from_target(rank<1>{}, x, arg);
}
/*
* Type-erased interface for:
*
@@ -46,6 +124,9 @@ struct target
* std::string name() const;
* std::vector<pass> get_passes(context& ctx) const;
* context get_context() const;
* argument copy_to(const argument& input) const;
* argument copy_from(const argument& input) const;
* argument allocate(const shape& s) const;
* };
*
*/
@@ -125,6 +206,24 @@ struct target
return (*this).private_detail_te_get_handle().get_context();
}
argument copy_to(const argument& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().copy_to(input);
}
argument copy_from(const argument& input) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().copy_from(input);
}
argument allocate(const shape& s) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().allocate(s);
}
friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
@@ -141,6 +240,9 @@ struct target
virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0;
virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0;
};
template <typename PrivateDetailTypeErasedT>
@@ -181,6 +283,24 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); }
argument copy_to(const argument& input) const override
{
return copy_to_target(private_detail_te_value, input);
}
argument copy_from(const argument& input) const override
{
return copy_from_target(private_detail_te_value, input);
}
argument allocate(const shape& s) const override
{
return target_allocate(private_detail_te_value, s);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......
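The `#else` branch above uses rank-based overload dispatch: `target_allocate`, `copy_to_target`, and `copy_from_target` try the `rank<1>` overload first, which is only viable (through the trailing `decltype`) when the wrapped type actually provides `allocate`, `copy_to`, or `copy_from`; otherwise overload resolution falls back to the `rank<0>` default, which passes the argument through for the copies and throws for `allocate`. A minimal standalone sketch of the idiom, with illustrative names that are not part of the commit:

```cpp
#include <iostream>
#include <string>

// Minimal rank hierarchy: rank<1> converts to rank<0>, so the rank<1> overload
// is preferred whenever both overloads are viable.
template <int N>
struct rank : rank<N - 1> {};
template <>
struct rank<0> {};

struct with_greet { std::string greet() const { return "member greet()"; } };
struct without_greet {};

// Preferred overload: only participates when T has a member greet() (expression SFINAE).
template <class T>
auto call_greet(rank<1>, const T& x) -> decltype(x.greet()) { return x.greet(); }

// Fallback overload: chosen when the rank<1> overload drops out of the overload set.
template <class T>
std::string call_greet(rank<0>, const T&) { return "default"; }

template <class T>
std::string call_greet(const T& x) { return call_greet(rank<1>{}, x); }

int main()
{
    std::cout << call_greet(with_greet{}) << "\n";    // prints "member greet()"
    std::cout << call_greet(without_greet{}) << "\n"; // prints "default"
}
```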
@@ -132,7 +132,11 @@ struct tensor_view
return m_data + this->size();
}
-std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); }
+template <class U = T>
+std::vector<U> to_vector() const
+{
+return std::vector<U>(this->begin(), this->end());
+}
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{
......
@@ -168,6 +168,7 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr)
*out_error = error;
return error <= threshold;
......
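For reference, the threshold in `verify_range` scales the machine epsilon of the range's value type by `tolerance`. A quick sketch, assuming a `float`-valued range so that `range_value<R1>` is `float`, of what the default tolerance of 80 amounts to:

```cpp
#include <iostream>
#include <limits>

int main()
{
    const double tolerance = 80; // default tolerance used by verify_range
    const double threshold = std::numeric_limits<float>::epsilon() * tolerance;
    // For 32-bit floats, epsilon is ~1.19e-7, so the RMS error must stay below ~9.5e-6.
    std::cout << threshold << "\n";
}
```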
@@ -9,6 +9,7 @@ set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(migraphx_onnx onnx.cpp)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_set_soversion(migraphx_onnx ${PROJECT_VERSION})
rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
target_link_libraries(migraphx_onnx PUBLIC migraphx)
@@ -19,7 +20,7 @@ rocm_install_targets(
add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)
if(MIGRAPHX_ENABLE_GPU)
......
This diff is collapsed.
@@ -85,6 +85,9 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset = (offset + 3) / 4 * 4;
segment.offset = offset;
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
......
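The new alignment step rounds the running offset up to the next multiple of 4 so that MIOpen's int8 convolution never sees a misaligned buffer. A small sketch of the round-up arithmetic (not part of the commit):

```cpp
#include <cassert>
#include <cstddef>

// (offset + 3) / 4 * 4 rounds a byte offset up to the nearest multiple of 4.
std::size_t align_up_to_4(std::size_t offset) { return (offset + 3) / 4 * 4; }

int main()
{
    assert(align_up_to_4(0) == 0);
    assert(align_up_to_4(1) == 4);
    assert(align_up_to_4(13) == 16);
    assert(align_up_to_4(16) == 16); // already-aligned offsets are left unchanged
}
```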
@@ -107,7 +107,7 @@ struct memory_coloring_impl
return ins->name() == "check_context";
}
-static bool is_disjoin(live_range& range1, live_range& range2)
+static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) || (range2.size == 0))
return false;
......
@@ -2,7 +2,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
......
@@ -241,7 +241,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
-assert(std::all_of(first, last, [&](instruction& ins) { return ins.outputs().empty(); }));
+assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
}
......
@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
-if(ins->name() == "@literal")
-return true;
+if(ins->name() == "contiguous")
+return skip_propogate(ins->inputs().front());
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const
ins->outputs().end());
for(auto child : children)
{
-if(skip_propogate(child))
+if(child->name() == "@literal" or skip_propogate(child))
{
self(child);
continue;
......
@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
-if(MIGRAPHX_ENABLE_TF)
-target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
-target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
-else()
-target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
-endif()
+target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
......
@@ -6,11 +6,9 @@
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/stringutils.hpp>
-#ifdef ENABLE_TF
#include <migraphx/tf.hpp>
-#else
#include <migraphx/onnx.hpp>
-#endif
+#include <migraphx/type_name.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
@@ -104,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t = as.type_enum();
n = sizeof(as());
}
});
if(n == 0)
{
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
}
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
@@ -153,6 +156,7 @@ PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
@@ -161,16 +165,13 @@ PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
-#ifdef ENABLE_TF
m.def("parse_tf",
&migraphx::parse_tf,
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true);
-#else
m.def("parse_onnx", &migraphx::parse_onnx);
-#endif
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
@@ -182,10 +183,16 @@ PYBIND11_MODULE(migraphx, m)
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
-m.def("quantize", [](migraphx::program& p, std::vector<std::string>& ins_names) {
-migraphx::quantize(p, ins_names);
-});
-m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
+m.def("quantize_fp16",
+&migraphx::quantize_fp16,
+py::arg("prog"),
+py::arg("ins_names") = std::vector<std::string>{"all"});
+m.def("quantize_int8",
+&migraphx::quantize_int8,
+py::arg("prog"),
+py::arg("t"),
+py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
+py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
This diff is collapsed.
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
@@ -11,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-void fwd_conv_batchnorm_rewrite::apply(program& p) const
+void rewrite_batchnorm::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
@@ -25,46 +26,30 @@ void rewrite_batchnorm::apply(program& p) const
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
-auto conv_ins = ins->inputs()[0];
-if(conv_ins->name() != "convolution")
-continue;
-// Get convolution weights
-auto weights = conv_ins->inputs()[1]->eval();
-if(weights.empty())
-continue;
+auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
-// Get convolution op
-auto conv_op = conv_ins->get_operator();
-auto weights_lens = weights.get_shape().lens();
-auto conv_lens = conv_ins->get_shape().lens();
-argument new_weights{weights.get_shape()};
-argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
-visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
-[&](auto weights2,
-auto gamma2,
-auto bias2,
-auto mean2,
-auto variance2,
-auto new_weights2,
-auto new_bias2) {
-dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
-[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
-new_weights2(k, c, h, w) =
-gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
-});
-dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-new_bias2[c] =
-bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
+argument a{s};
+argument b{s};
+visit_all(gamma, bias, mean, variance, a, b)(
+[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
+dfor(a.get_shape().elements())(
+[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
+dfor(b.get_shape().elements())([&](std::size_t c) {
+b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
-// Replace convolution instruction with updated weights
-auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
-auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
-auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
-auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
-p.replace_instruction(ins, op::add{}, {c, b});
+auto broadcast = op::broadcast{1, ins->get_shape().lens()};
+auto a_ins = p.add_literal({a.get_shape(), a.data()});
+auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
+auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
+auto b_ins = p.add_literal({b.get_shape(), b.data()});
+auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
+auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
+p.replace_instruction(ins, add);
}
}
......
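The rewritten pass folds the batch-norm formula into a single per-channel scale and shift: `gamma * (x - mean) / sqrt(variance + epsilon) + bias` equals `a * x + b` with `a = gamma / sqrt(variance + epsilon)` and `b = bias - a * mean`, which is what the two `dfor` loops compute before emitting the `mul` and `add` instructions. A small standalone check of that algebra, with made-up values for illustration:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    const double gamma = 1.5, bias = 0.2, mean = 0.7, variance = 4.0, epsilon = 1e-5;
    const double a = gamma / std::sqrt(variance + epsilon);
    const double b = bias - gamma * mean / std::sqrt(variance + epsilon);
    for(double x : {-2.0, 0.0, 3.5})
    {
        // Original batch_norm_inference formula for one channel
        double bn = gamma * (x - mean) / std::sqrt(variance + epsilon) + bias;
        // Rewritten form: broadcast(a) * x + broadcast(b)
        assert(std::abs(bn - (a * x + b)) < 1e-12);
    }
}
```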
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if(ins->name() != "pooling")
continue;
if(ins->get_shape().lens().size() != 4)
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode != "average")
continue;
if(op.padding[0] != 0 and op.padding[1] != 0)
continue;
if(op.stride[0] != 1 and op.stride[1] != 1)
continue;
if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
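`rewrite_pooling` targets the global-average-pooling case, where the window spans the whole `H x W` plane: each output value is simply the mean of one feature map, so the pooling can be expressed as `reshape {N, C, H, W} -> {N*C, H*W}`, `reduce_mean` over axis 1, and a `reshape` back to `{N, C, 1, 1}`. A small sketch of that equivalence on made-up data, without the MIGraphX API:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t n = 1, c = 2, h = 2, w = 3; // N*C = 2 rows of H*W = 6 values each
    const std::vector<double> input = {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60};

    // Global average pooling: one value per (n, c) plane.
    std::vector<double> pooled(n * c, 0.0);
    for(std::size_t row = 0; row < n * c; row++)
    {
        for(std::size_t col = 0; col < h * w; col++)
            pooled[row] += input[row * h * w + col];
        pooled[row] /= static_cast<double>(h * w); // same as reduce_mean over the flattened axis
    }

    assert(std::abs(pooled[0] - 3.5) < 1e-12);  // mean of {1..6}
    assert(std::abs(pooled[1] - 35.0) < 1e-12); // mean of {10..60}
}
```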
This diff is collapsed.
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
return result;
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
{
assert(this->standard());
std::vector<std::size_t> indices(lens().size());
std::transform(strides().begin(),
strides().end(),
lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
return indices;
}
bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::transposed() const
......
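`shape::multi` maps a linear element index back to per-dimension indices with `(i / stride) % len`, which is well defined for the standard (packed, row-major) shapes the assert requires. A quick sketch with a hypothetical `{2, 3, 4}` shape, whose standard strides are `{12, 4, 1}`:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::vector<std::size_t> lens    = {2, 3, 4};
    const std::vector<std::size_t> strides = {12, 4, 1}; // standard strides for {2, 3, 4}
    const std::size_t i = 17;                            // linear index to convert

    std::vector<std::size_t> indices(lens.size());
    for(std::size_t d = 0; d < lens.size(); d++)
        indices[d] = (i / strides[d]) % lens[d];

    // 1*12 + 1*4 + 1*1 == 17, so the multi-index is {1, 1, 1}
    assert((indices == std::vector<std::size_t>{1, 1, 1}));
}
```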
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-struct find_add_lit_broadcast
-{
-auto lit_broadcast() const
-{
-return match::any_of(match::name("@literal"), match::name("broadcast"));
-}
-auto not_lit_broadcast() const
-{
-return match::none_of(match::name("@literal"), match::name("broadcast"));
-}
-auto add_lit_broadcast(std::string x, std::string y) const
-{
-return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
-not_lit_broadcast().bind(std::move(y))));
-}
-auto matcher() const
-{
-return match::name("add")(
-match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
+auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
+auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
+auto op_lit_broadcast(std::string op, std::string x, std::string y)
+{
+return match::name(std::move(op))(match::either_arg(0, 1)(
+lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
+}
+auto conv_const_weights()
+{
+return match::name("convolution")(match::used_once(),
+match::args(match::any(), match::is_constant().bind("w")));
+}
+struct find_mul_conv
+{
+auto matcher() const
+{
+return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+match::name("broadcast").bind("a")));
+}
+void apply(program& p, match::matcher_result r) const
+{
+auto ins = r.result;
+auto conv_ins = r.instructions["conv"];
+auto a_ins = r.instructions["a"];
+auto w_ins = r.instructions["w"];
+auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
+if(broadcast_op.axis != 1)
+return;
+auto new_a = p.insert_instruction(
+ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
+auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
+auto new_conv = p.insert_instruction(
+ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
+p.replace_instruction(ins, new_conv);
+}
+};
+// a * (x + b) => a * x + a * b
+struct find_mul_add
+{
+auto matcher() const
+{
+return match::name("mul")(match::either_arg(0, 1)(
+match::name("add")(
+match::either_arg(0, 1)(
+match::any().bind("x"),
+match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
+match::none_of(match::args(match::is_constant(), match::is_constant())),
+match::used_once()),
+match::is_constant().bind("a")));
+}
+void apply(program& p, match::matcher_result r) const
+{
+auto ins = r.result;
+auto a_ins = r.instructions["a"];
+auto b_ins = r.instructions["b"];
+auto x_ins = r.instructions["x"];
+assert(x_ins != b_ins);
+auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
+auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
+p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
+}
+};
+struct find_add_lit_broadcast
+{
+auto matcher() const
+{
+return match::name("add")(
+match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
+}
+void apply(program& p, match::matcher_result r) const
+{
+auto ins = r.result;
+auto x_ins = r.instructions["x"];
+auto a_ins = r.instructions["a"];
+auto b_ins = r.instructions["b"];
+auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
+p.replace_instruction(ins, op::add{}, x_ins, sumab);
+}
+};
+struct find_double_add_lit_broadcast
+{
+auto matcher() const
+{
+return match::name("add")(
+match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(program& p, match::matcher_result r) const
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
-if(a_ins->name() != b_ins->name())
-return;
instruction_ref sumab;
-if(a_ins->name() == "broadcast")
+if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
}
};
-void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
+struct find_inner_broadcast
{
auto matcher() const
{
return match::name("mul", "add")(
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if(xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 4; i++)
{
match::find_matches(p,
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_mul_conv{},
find_mul_add{});
dead_code_elimination{}.apply(p);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
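`find_mul_conv` relies on convolution being linear in its weights: multiplying a convolution's output channels by a per-channel constant is the same as multiplying those channels' weights by the constant before convolving, so `mul(conv(x, w), broadcast(a))` can become `conv(x, mul(broadcast(a), w))`, presumably so the scale can later be folded into the constant weights. A naive 1D check of that identity, illustrative only and using no MIGraphX types:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive 1D "valid" convolution of x with a single filter w.
std::vector<double> conv1d(const std::vector<double>& x, const std::vector<double>& w)
{
    std::vector<double> out(x.size() - w.size() + 1, 0.0);
    for(std::size_t i = 0; i < out.size(); i++)
        for(std::size_t j = 0; j < w.size(); j++)
            out[i] += x[i + j] * w[j];
    return out;
}

int main()
{
    const std::vector<double> x = {1, 2, 3, 4, 5};
    const std::vector<double> w = {0.5, -1, 2};
    const double a = 3.0;

    // mul(conv(x, w), a)
    auto scale_after = conv1d(x, w);
    for(auto& v : scale_after)
        v *= a;

    // conv(x, a * w)
    auto scaled_w = w;
    for(auto& v : scaled_w)
        v *= a;
    auto scale_before = conv1d(x, scaled_w);

    for(std::size_t i = 0; i < scale_after.size(); i++)
        assert(std::abs(scale_after[i] - scale_before[i]) < 1e-12);
}
```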
This diff is collapsed.
@@ -5,6 +5,7 @@ add_library(migraphx_cpu
gemm.cpp
)
set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu)
rocm_set_soversion(migraphx_cpu ${PROJECT_VERSION})
find_path(BLAZE_INCLUDE blaze/Blaze.h)
find_package(Threads)
......