Commit d22bab64 authored by wsttiger's avatar wsttiger
Browse files

Added op namespace on operators

parents ad0ab357 3d264140
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.5)
if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "The binary and source directroy cannot be the same")
endif()
project(migraphlib) project(migraphlib)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
......
...@@ -6,9 +6,12 @@ ARG PREFIX=/usr/local ...@@ -6,9 +6,12 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y curl apt-utils wget RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y curl apt-utils wget software-properties-common
RUN curl https://raw.githubusercontent.com/RadeonOpenCompute/ROCm-docker/master/add-rocm.sh | bash RUN curl https://raw.githubusercontent.com/RadeonOpenCompute/ROCm-docker/master/add-rocm.sh | bash
# Add ubuntu toolchain
RUN apt-get update && add-apt-repository ppa:ubuntu-toolchain-r/test -y
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
...@@ -19,6 +22,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -19,6 +22,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
cmake \ cmake \
curl \ curl \
doxygen \ doxygen \
g++-7 \
gdb \ gdb \
git \ git \
hsa-rocr-dev \ hsa-rocr-dev \
...@@ -26,14 +30,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -26,14 +30,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
lcov \ lcov \
libelf-dev \ libelf-dev \
libncurses5-dev \ libncurses5-dev \
libpthread-stubs0-dev \
libnuma-dev \ libnuma-dev \
libpthread-stubs0-dev \
python \ python \
python-dev \ python-dev \
python-pip \ python-pip \
rocminfo \
rocm-opencl \ rocm-opencl \
rocm-opencl-dev \ rocm-opencl-dev \
rocminfo \
software-properties-common \ software-properties-common \
wget && \ wget && \
apt-get clean && \ apt-get clean && \
......
...@@ -3,10 +3,11 @@ def rocmtestnode(variant, name, body) { ...@@ -3,10 +3,11 @@ def rocmtestnode(variant, name, body) {
def image = 'migraphlib' def image = 'migraphlib'
def cmake_build = { compiler, flags -> def cmake_build = { compiler, flags ->
def cmd = """ def cmd = """
ulimit -c unlimited
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined' ${flags} .. CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake ${flags} ..
CTEST_PARALLEL_LEVEL=32 make -j32 all doc check CTEST_PARALLEL_LEVEL=32 make -j32 all doc check
""" """
echo cmd echo cmd
...@@ -92,16 +93,26 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build -> ...@@ -92,16 +93,26 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build ->
} }
}, clang: rocmnode('rocmtest') { cmake_build -> }, clang: rocmnode('rocmtest') { cmake_build ->
stage('Clang Debug') { stage('Clang Debug') {
cmake_build('hcc', '-DCMAKE_BUILD_TYPE=debug') // TODO: Enanle integer
def sanitizers = "undefined"
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}'")
} }
stage('Clang Release') { stage('Clang Release') {
cmake_build('hcc', '-DCMAKE_BUILD_TYPE=release') cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release")
} }
}, gcc: rocmnode('rocmtest') { cmake_build -> }, gcc5: rocmnode('rocmtest') { cmake_build ->
stage('GCC Debug') { stage('GCC 5 Debug') {
cmake_build('g++-5', '-DCMAKE_BUILD_TYPE=debug') cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=debug")
} }
stage('GCC Release') { stage('GCC 5 Release') {
cmake_build('g++-5', '-DCMAKE_BUILD_TYPE=release') cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=release")
}
}, gcc7: rocmnode('rocmtest') { cmake_build ->
stage('GCC 7 Debug') {
def linker_flags = '-fuse-ld=gold'
def cmake_linker_flags = "-DCMAKE_EXE_LINKER_FLAGS='${linker_flags}' -DCMAKE_SHARED_LINKER_FLAGS='${linker_flags}'"
// TODO: Add bounds-strict
def sanitizers = "undefined,address"
cmake_build("g++-7", "-DCMAKE_BUILD_TYPE=debug ${cmake_linker_flags} -DCMAKE_CXX_FLAGS_DEBUG='-g -fno-omit-frame-pointer -fsanitize-address-use-after-scope -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}'")
} }
} }
/*
The MIT License (MIT)
Copyright (c) 2015-Present Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
...@@ -68,6 +68,19 @@ else() ...@@ -68,6 +68,19 @@ else()
-Wno-sign-compare -Wno-sign-compare
) )
# Flags for gcc 7
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "7.0")
list(APPEND CMAKE_COMPILER_WARNINGS
-Wduplicated-branches
-Wduplicated-cond
-Wno-noexcept-type
-Wodr
-Wshift-negative-value
-Wshift-overflow=2
)
endif()
endif()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
list(APPEND CMAKE_COMPILER_WARNINGS list(APPEND CMAKE_COMPILER_WARNINGS
-Weverything -Weverything
......
...@@ -90,3 +90,210 @@ ...@@ -90,3 +90,210 @@
<summary>Double negative is always positive</summary> <summary>Double negative is always positive</summary>
</message> </message>
</rule> </rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ (\||&) \w+ \)]]></pattern>
<message>
<id>BitwiseOperatorInConditional</id>
<severity>style</severity>
<summary>Bitwise operator found in if statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^)]+ \) { if \( [^)]+ \) ({[^{}]*(?1)*[^{}]*}) }]]></pattern>
<message>
<id>CollapsibleIfStatements</id>
<severity>style</severity>
<summary>These two if statements can be collapsed into one.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[catch \( [^())]+ \) { }]]></pattern>
<message>
<id>EmptyCatchStatement</id>
<severity>style</severity>
<summary>An empty catch statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[do { } while \(]]></pattern>
<message>
<id>EmptyDoWhileStatement</id>
<severity>style</severity>
<summary>Empty do-while.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[else { }]]></pattern>
<message>
<id>EmptyElseBlock</id>
<severity>style</severity>
<summary>Empty else statement can be safely removed.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( [^()]+ \) { }]]></pattern>
<message>
<id>EmptyForStatement</id>
<severity>style</severity>
<summary>Empty for statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { }]]></pattern>
<message>
<id>EmptyIfStatement</id>
<severity>style</severity>
<summary>Empty if statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[switch \( [^()]+ \) { }]]></pattern>
<message>
<id>EmptySwitchStatement</id>
<severity>style</severity>
<summary>Empty switch statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[while \( [^()]+ \) { }]]></pattern>
<message>
<id>EmptyWhileStatement</id>
<severity>style</severity>
<summary>Empty while statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[ for \( ; [^;]+ ; \)]]></pattern>
<message>
<id>ForLoopShouldBeWhileLoop</id>
<severity>style</severity>
<summary>For loop should be written as a while loop.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>goto</pattern>
<message>
<id>GotoStatement</id>
<severity>style</severity>
<summary>Goto considered harmful.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ != \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
<summary>It is cleaner to invert the logic.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( ! \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
<summary>It is cleaner to invert the logic.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[\w+ != \w+ \?]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
<summary>It is cleaner to invert the logic.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[! \w+ \?]]></pattern>
<message>
<id>InvertedLogic</id>
<severity>style</severity>
<summary>It is cleaner to invert the logic.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
<message>
<id>RedundantConditionalOperator</id>
<severity>style</severity>
<summary>Conditional operator is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { return (true|false) ; } else { return (true|false) ; }]]></pattern>
<message>
<id>RedundantIfStatement</id>
<severity>style</severity>
<summary>The if statement is redundant.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { [^{}]* (return|throw|break|continue) [^;]* ; } else {]]></pattern>
<message>
<id>UnnecessaryElseStatement</id>
<severity>style</severity>
<summary>Else statement is not necessary.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::copy instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::fill instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::generate instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::transform instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::transform instead.</summary>
</message>
</rule>
Matchers
========
Introduction
------------
The matchers provide a way compose several predicates together. Many of the matchers can be composed so that ``m(m1, m2)`` will first check that ``m`` matches and then it will check that ``m1`` and ``m2`` will match.
The most commonly-used matcher is the ``name`` matcher. It will match the instruction that have the operator that is equal to the name specified::
auto match_sum = name("sum");
This will find ``sum`` operators. We can also find ``sum`` operators which the output is ``standard_shape``:
auto match_sum = name("sum")(standard_shape());
Arguments
---------
We also want to match arguments to the instructions as well. One way, is to match each argument using the ``arg`` matcher::
auto match_sum = name("sum")(arg(0)(name("@literal"), arg(1)(name("@literal"))));
This will match a ``sum`` operator with the two arguments that are literals. Of course, instead of writing ``arg(0)`` and ``arg(1)`` everytime, the ``args`` matcher can be used::
auto match_sum = name("sum")(args(name("@literal"), name("@literal")));
Binding
-------
As we traverse through the instructions we may want reference some of the instructions we find along the way. We can do this by calling ``.bind``::
auto match_sum = name("sum")(args(
name("@literal").bind("one"),
name("@literal").bind("two")
)).bind("sum");
This will associate the instruction to a name that can be read from the ``matcher_result`` when it matches.
Finding matches
---------------
Finally, when you want to use the matchers to find instructions a callback object can be written which has the matcher and an ``apply`` function which will take the ``matcher_result`` when the match is found::
struct match_find_sum
{
auto matcher() const { return name("sum"); }
void apply(program& p, matcher_result r) const
{
// Do something with the result
}
};
find_matches(prog, match_find_sum{});
Creating matchers
-----------------
There are several ways to create matchers. The macros ``MIGRAPH_BASIC_MATCHER`` and ``MIGRAPH_PRED_MATCHER`` help with creating matchers. For example, we can create a matcher for shapes that are broadcasted::
MIGRAPH_PRED_MATCHER(broadcasted_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
If we want parameters to the predicate, then we will need to use the ``make_basic_pred_matcher`` to create the matcher. For example, here is how we would create a matcher to check the number of dimensions of the shape::
inline auto number_of_dims(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) {
return ins->get_shape().lens().size() == n;
});
}
Developer Guide
===============
.. toctree::
:maxdepth: 2
:caption: Contents:
dev/matchers
...@@ -7,15 +7,11 @@ Welcome to MIGraph's documentation! ...@@ -7,15 +7,11 @@ Welcome to MIGraph's documentation!
=================================== ===================================
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 3
:caption: Contents: :caption: Contents:
overview user_guide
reference/data developer_guide
reference/operators
reference/program
reference/targets
reference/pass
Indices and tables Indices and tables
......
User Guide
==========
.. toctree::
:maxdepth: 2
:caption: Contents:
overview
reference/data
reference/operators
reference/program
reference/targets
reference/pass
...@@ -11,6 +11,8 @@ add_library(migraph ...@@ -11,6 +11,8 @@ add_library(migraph
program.cpp program.cpp
shape.cpp shape.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
) )
rocm_clang_tidy_check(migraph) rocm_clang_tidy_check(migraph)
target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
...@@ -20,3 +22,10 @@ add_subdirectory(targets/cpu) ...@@ -20,3 +22,10 @@ add_subdirectory(targets/cpu)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPH_ENABLE_GPU)
add_subdirectory(targets/gpu) add_subdirectory(targets/gpu)
endif() endif()
#install (TARGETS migraph
# LIBRARY DESTINATION /opt/rocm/lib)
#install (DIRECTORY include/migraph DESTINATION /opt/rocm/include)
...@@ -13,7 +13,7 @@ void auto_contiguous::apply(program& p) const ...@@ -13,7 +13,7 @@ void auto_contiguous::apply(program& p) const
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard()) if(not s.standard())
{ {
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
} }
} }
......
...@@ -4,12 +4,17 @@ ...@@ -4,12 +4,17 @@
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/pass_config.hpp>
namespace migraph { namespace migraph {
void eliminate_allocation::apply(program& p) const void eliminate_allocation::apply(program& p) const
{ {
assert(alignment > 0); assert(alignment > 0);
if(!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
return;
std::size_t n = 0; std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -27,7 +32,7 @@ void eliminate_allocation::apply(program& p) const ...@@ -27,7 +32,7 @@ void eliminate_allocation::apply(program& p) const
auto ins = pp.first; auto ins = pp.first;
auto s = ins->get_shape(); auto s = ins->get_shape();
auto offset = pp.second; auto offset = pp.second;
p.replace_instruction(ins, load{s, offset}, mem); p.replace_instruction(ins, op::load{s, offset}, mem);
} }
} }
} // namespace migraph } // namespace migraph
...@@ -28,7 +28,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -28,7 +28,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
const auto& mean = ins->inputs()[3]->get_literal(); const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->inputs()[4]->get_literal(); const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon // Get epsilon
auto bn_op = any_cast<batch_norm_inference>(ins->get_operator()); auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon; auto epsilon = bn_op.epsilon;
// Get convolution weights // Get convolution weights
const auto& weights = conv_ins->inputs()[1]->get_literal(); const auto& weights = conv_ins->inputs()[1]->get_literal();
...@@ -59,8 +59,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -59,8 +59,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias); auto b = p.insert_instruction(ins, op::broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
} // namespace migraph } // namespace migraph
...@@ -61,6 +61,12 @@ constexpr void repeat_c_impl(F f, seq<Ns...>) ...@@ -61,6 +61,12 @@ constexpr void repeat_c_impl(F f, seq<Ns...>)
swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...}; swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
} }
template <class F, std::size_t... Ns>
constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
{
return f(std::integral_constant<std::size_t, Ns>{}...);
}
} // namespace detail } // namespace detail
template <std::size_t N, class F> template <std::size_t N, class F>
...@@ -69,6 +75,18 @@ constexpr void repeat_c(F f) ...@@ -69,6 +75,18 @@ constexpr void repeat_c(F f)
detail::repeat_c_impl(f, detail::gens<N>{}); detail::repeat_c_impl(f, detail::gens<N>{});
} }
template <std::size_t N, class F>
constexpr auto sequence_c(F&& f)
{
return detail::sequence_c_impl(f, detail::gens<N>{});
}
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
}
/// Implements a fix-point combinator /// Implements a fix-point combinator
template <class R, class F> template <class R, class F>
detail::fix_f<R, F> fix(F f) detail::fix_f<R, F> fix(F f)
...@@ -88,6 +106,24 @@ auto pack(Ts... xs) ...@@ -88,6 +106,24 @@ auto pack(Ts... xs)
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
template <class F, class T>
auto fold_impl(F&&, T&& x)
{
return x;
}
template <class F, class T, class U, class... Ts>
auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(std::forward<T>(x), std::forward<U>(y)), std::forward<Ts>(xs)...);
}
template <class F>
auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -71,13 +71,13 @@ struct instruction ...@@ -71,13 +71,13 @@ struct instruction
// internal // internal
void replace_argument(instruction_ref old, instruction_ref new_ins); void replace_argument(instruction_ref old, instruction_ref new_ins);
private:
operation op; operation op;
shape result; shape result;
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
}; };
} // namespace migraph } // namespace migraph
namespace std { namespace std {
......
...@@ -87,8 +87,8 @@ struct literal : raw_data<literal> ...@@ -87,8 +87,8 @@ struct literal : raw_data<literal>
m_shape.visit_type([&](auto as) { m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get())); auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
it++;
output(idx.begin(), idx.end()) = *it; output(idx.begin(), idx.end()) = *it;
it++;
}); });
}); });
} }
......
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <unordered_map>
namespace migraph {
namespace matchers {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
private:
instruction_ref last;
};
/// Convert a predicate function into a matcher
template <class P>
struct predicate_matcher
{
P p;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
}
};
/// Convert a function into a matcher
template <class F>
struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
};
/// Convert a function into a matcher
template <class F>
function_matcher<F> make_function_matcher(F f)
{
return {f};
}
/// Converts a matcher to bind the instruction to name
template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions.emplace(name, ins);
return result;
});
}
/// Convert a matcher to a bindable matcher
template <class M>
struct bindable_matcher
{
M m;
auto bind(std::string name) { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
};
/// Create a bindable matcher
template <class M>
bindable_matcher<M> make_bindable_matcher(M m)
{
return {m};
}
/// Create a bindable matcher from a function
template <class F>
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
return {{f}};
}
/// Create a bindable matcher from a predicate function
template <class F>
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
return {{f}};
}
using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
};
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
{
M m;
template <class... Ts>
auto operator()(Ts... ms) const
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
});
}
auto bind(std::string name) { return bind_match(m, name); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
};
/// Create a basic matcher from a matcher
template <class M>
basic_matcher<M> make_basic_matcher(M m)
{
return {m};
}
/// Create a basic matcher from a function
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
return {{f}};
}
/// Create a basic matcher from a predicate function
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
return {{p}};
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref result;
};
/// Match a single instruction
template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{
assert(ins != p.end());
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
result.instructions = ctx.instructions;
return result;
}
/// Find matches in a program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
bool match = false;
each_args(
[&](auto&& m) {
// cppcheck-suppress knownConditionTrueFalse
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
return;
m.apply(p, r);
match = true;
},
ms...);
}
}
template <class... Ts>
auto all_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) != ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
template <class... Ts>
auto none_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) == ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
template <class... Ts>
auto any_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x or y.match(ctx, ins) != ctx.not_found();
})(false, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
inline auto name(std::string name)
{
return make_basic_pred_matcher(
[ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
}
inline auto nargs(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
}
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
});
}
// Workaround for bugs in clang
template <std::size_t...>
struct args_impl_ints
{
};
template <std::size_t... Ns, class... Ms>
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
return matchers::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
}
template <class... Ms>
auto args(Ms... ms)
{
return sequence_c<sizeof...(Ms)>([=](auto... is) {
// It needs to be written as `decltype(is)::value` for gcc 5
return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
});
}
} // namespace matchers
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct memory_coloring
{
std::string allocation_op{};
std::string name() const { return "memory coloring"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <utility> #include <utility>
namespace migraph { namespace migraph {
namespace op {
struct not_computable struct not_computable
{ {
...@@ -282,6 +283,151 @@ struct contiguous ...@@ -282,6 +283,151 @@ struct contiguous
} }
}; };
struct slice
{
std::vector<int64_t> axes;
std::vector<int64_t> starts;
std::vector<int64_t> ends;
std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const
{
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(!axes.empty())
{
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
}
}
else
{
for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
}
}
return offset;
}
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if(starts.size() != axes.size() || axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes");
}
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
}
return shape{t, new_lens, old_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
}
};
struct squeeze
{
std::vector<int64_t> axes;
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
{
MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
if(axes.empty())
{
std::copy_if(old_lens.begin(),
old_lens.end(),
std::back_inserter(new_lens),
[](auto len) { return len != 1; });
}
else
{
for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
}
}
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
};
struct unsqueeze
{
std::vector<int64_t> axes;
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++)
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
}
else
{
new_lens[i] = old_lens[p++];
}
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
};
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
...@@ -571,6 +717,7 @@ struct outline ...@@ -571,6 +717,7 @@ struct outline
} }
}; };
} // namespace op
} // namespace migraph } // namespace migraph
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment