Commit eb0d8fee authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into driver

parents 65ef35cd 0d796941
...@@ -3,6 +3,8 @@ CheckOptions: ...@@ -3,6 +3,8 @@ CheckOptions:
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product' value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp - key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED' value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED'
- key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion
value: 0
- key: modernize-loop-convert.MinConfidence - key: modernize-loop-convert.MinConfidence
value: risky value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
......
...@@ -22,7 +22,7 @@ find_package(ROCM REQUIRED) ...@@ -22,7 +22,7 @@ find_package(ROCM REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 0.1) rocm_setup_version(VERSION 0.2)
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
...@@ -39,6 +39,8 @@ else() ...@@ -39,6 +39,8 @@ else()
set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "") set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "")
endif() endif()
set(MIGRAPHX_ENABLE_TF Off CACHE BOOL "")
add_compile_options(-std=c++14) add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
...@@ -54,6 +56,7 @@ rocm_enable_clang_tidy( ...@@ -54,6 +56,7 @@ rocm_enable_clang_tidy(
-clang-diagnostic-extern-c-compat -clang-diagnostic-extern-c-compat
-clang-diagnostic-disabled-macro-expansion -clang-diagnostic-disabled-macro-expansion
-clang-diagnostic-unused-command-line-argument -clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-pro-bounds-array-to-pointer-decay -cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index -cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic -cppcoreguidelines-pro-bounds-pointer-arithmetic
...@@ -81,6 +84,7 @@ rocm_enable_clang_tidy( ...@@ -81,6 +84,7 @@ rocm_enable_clang_tidy(
-modernize-pass-by-value -modernize-pass-by-value
-modernize-use-default-member-init -modernize-use-default-member-init
-modernize-use-transparent-functors -modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
-readability-braces-around-statements -readability-braces-around-statements
-readability-else-after-return -readability-else-after-return
-readability-named-parameter -readability-named-parameter
...@@ -145,6 +149,7 @@ rocm_create_package( ...@@ -145,6 +149,7 @@ rocm_create_package(
DESCRIPTION "AMD's graph optimizer" DESCRIPTION "AMD's graph optimizer"
MAINTAINER "Paul Fultz II <paul.fultz@amd.com>" MAINTAINER "Paul Fultz II <paul.fultz@amd.com>"
LDCONFIG LDCONFIG
PTH
DEPENDS miopen-hip rocblas hip_hcc half DEPENDS miopen-hip rocblas hip_hcc half
) )
......
...@@ -50,7 +50,7 @@ RUN pip install cget ...@@ -50,7 +50,7 @@ RUN pip install cget
RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
# Install hcc # Install hcc
RUN rclone -b roc-2.0.x -c 757fb492517b80e7c86338af5fc1a43d63cb25a9 https://github.com/RadeonOpenCompute/hcc.git /hcc RUN rclone -b roc-2.3.x -c fd93baed7dcc4fe8019b5fdc90213bfe7c298245 https://github.com/RadeonOpenCompute/hcc.git /hcc
RUN cget -p $PREFIX install hcc,/hcc RUN cget -p $PREFIX install hcc,/hcc
# Use hcc # Use hcc
...@@ -73,3 +73,8 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib ...@@ -73,3 +73,8 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib
# Install doc requirements # Install doc requirements
ADD doc/requirements.txt /doc-requirements.txt ADD doc/requirements.txt /doc-requirements.txt
RUN pip install -r /doc-requirements.txt RUN pip install -r /doc-requirements.txt
# Set up the UBSan environment to print stack traces
RUN ln -s /usr/bin/llvm-symbolizer-5.0 /usr/local/bin/llvm-symbolizer
ENV UBSAN_OPTIONS=print_stacktrace=1
ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1
...@@ -3,6 +3,7 @@ def rocmtestnode(variant, name, body) { ...@@ -3,6 +3,7 @@ def rocmtestnode(variant, name, body) {
def image = 'migraphxlib' def image = 'migraphxlib'
def cmake_build = { compiler, flags -> def cmake_build = { compiler, flags ->
def cmd = """ def cmd = """
env
ulimit -c unlimited ulimit -c unlimited
rm -rf build rm -rf build
mkdir build mkdir build
...@@ -20,8 +21,8 @@ def rocmtestnode(variant, name, body) { ...@@ -20,8 +21,8 @@ def rocmtestnode(variant, name, body) {
} }
} }
node(name) { node(name) {
withEnv(['HSA_ENABLE_SDMA=0', 'MIOPEN_DEBUG_GCN_ASM_KERNELS=0']) {
stage("checkout ${variant}") { stage("checkout ${variant}") {
env.HSA_ENABLE_SDMA=0
checkout scm checkout scm
} }
stage("image ${variant}") { stage("image ${variant}") {
...@@ -38,6 +39,7 @@ def rocmtestnode(variant, name, body) { ...@@ -38,6 +39,7 @@ def rocmtestnode(variant, name, body) {
} }
} }
} }
}
} }
@NonCPS @NonCPS
def rocmtest(m) { def rocmtest(m) {
...@@ -107,6 +109,10 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build -> ...@@ -107,6 +109,10 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build ->
stage('Clang Release') { stage('Clang Release') {
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release") cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release")
} }
stage('Clang Release Python 3') {
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release -DPYTHON_EXECUTABLE=/usr/local/bin/python3")
}
}, gcc5: rocmnode('rocmtest') { cmake_build -> }, gcc5: rocmnode('rocmtest') { cmake_build ->
stage('GCC 5 Debug') { stage('GCC 5 Debug') {
cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=debug") cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=debug")
......
...@@ -12,14 +12,18 @@ AMD's graph optimization engine. ...@@ -12,14 +12,18 @@ AMD's graph optimization engine.
## Installing the dependencies ## Installing the dependencies
The dependencies can be installed with the `install_deps.cmake`, script: `cmake -P install_deps.cmake`. Dependencies can be installed using the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild).
This will install by default to `/usr/local` but it can be installed in another location with `--prefix` argument:
To install rbuild:
``` ```
cmake -P install_deps.cmake --prefix /some/local/dir pip install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
``` ```
To build the dependencies along with MIGraphX:
```
rbuild build -d depend --cxx=/opt/rocm/bin/hcc
```
This builds dependencies in the subdirectory named depend and then builds MIGraphX using these dependencies.
## Building MIGraphX from source ## Building MIGraphX from source
......
...@@ -90,6 +90,7 @@ else() ...@@ -90,6 +90,7 @@ else()
-Wno-double-promotion -Wno-double-promotion
-Wno-exit-time-destructors -Wno-exit-time-destructors
-Wno-extra-semi -Wno-extra-semi
-Wno-extra-semi-stmt
-Wno-float-conversion -Wno-float-conversion
-Wno-gnu-anonymous-struct -Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments -Wno-gnu-zero-variadic-macro-arguments
...@@ -104,7 +105,7 @@ else() ...@@ -104,7 +105,7 @@ else()
else() else()
list(APPEND CMAKE_COMPILER_WARNINGS list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers -Wno-missing-field-initializers
-Wno-deprecated-declarations # -Wno-deprecated-declarations
) )
endif() endif()
add_definitions(${CMAKE_COMPILER_WARNINGS}) add_definitions(${CMAKE_COMPILER_WARNINGS})
......
...@@ -89,49 +89,76 @@ ...@@ -89,49 +89,76 @@
<summary>Use manage pointer for resource management</summary> <summary>Use manage pointer for resource management</summary>
</message> </message>
</rule> </rule>
<rule>
<tokenlist>raw</tokenlist>
<pattern><![CDATA[hipLaunchKernelGGL \( (?!\( \w+ < \w+ > \))]]></pattern>
<message>
<id>UseDeviceLaunch</id>
<severity>style</severity>
<summary>Use device::launch instead</summary>
</message>
</rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern>! !</pattern> <pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { [^{}]* (return|throw|break|continue) [^;]* ; } else {]]></pattern>
<message> <message>
<id>doubleNegative</id> <id>UnnecessaryElseStatement</id>
<severity>style</severity> <severity>style</severity>
<summary>Double negative is always positive</summary> <summary>Else statement is not necessary.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ (\||&) \w+ \)]]></pattern> <pattern><![CDATA[\? (true|false) : (true|false)]]></pattern>
<message> <message>
<id>BitwiseOperatorInConditional</id> <id>RedundantConditionalOperator</id>
<severity>style</severity> <severity>style</severity>
<summary>Bitwise operator found in if statement.</summary> <summary>Conditional operator is redundant.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^)]+ \) { if \( [^)]+ \) ({[^{}]*(?1)*[^{}]*}) }]]></pattern> <pattern><![CDATA[switch (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message> <message>
<id>CollapsibleIfStatements</id> <id>EmptySwitchStatement</id>
<severity>style</severity> <severity>style</severity>
<summary>These two if statements can be collapsed into one.</summary> <summary>Empty switch statement.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[catch \( [^())]+ \) { }]]></pattern> <pattern><![CDATA[(?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w) ; \1 = [^;]+ ; return \1 ;]]></pattern>
<message> <message>
<id>EmptyCatchStatement</id> <id>RedundantLocalVariable</id>
<severity>style</severity> <severity>style</severity>
<summary>An empty catch statement.</summary> <summary>Variable is returned immediately after its declaration, can be simplified to just return expression.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[do { } while \(]]></pattern> <pattern><![CDATA[for \( ; [^;]+ ; \)]]></pattern>
<message> <message>
<id>EmptyDoWhileStatement</id> <id>ForLoopShouldBeWhileLoop</id>
<severity>style</severity> <severity>style</severity>
<summary>Empty do-while.</summary> <summary>For loop should be written as a while loop.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[while (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message>
<id>EmptyWhileStatement</id>
<severity>style</severity>
<summary>Empty while statement.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ (\||&) \w+ \)]]></pattern>
<message>
<id>BitwiseOperatorInConditional</id>
<severity>style</severity>
<summary>Bitwise operator found in if statement.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
...@@ -145,7 +172,7 @@ ...@@ -145,7 +172,7 @@
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( [^()]+ \) { }]]></pattern> <pattern><![CDATA[for (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message> <message>
<id>EmptyForStatement</id> <id>EmptyForStatement</id>
<severity>style</severity> <severity>style</severity>
...@@ -154,7 +181,7 @@ ...@@ -154,7 +181,7 @@
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { }]]></pattern> <pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message> <message>
<id>EmptyIfStatement</id> <id>EmptyIfStatement</id>
<severity>style</severity> <severity>style</severity>
...@@ -163,43 +190,52 @@ ...@@ -163,43 +190,52 @@
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[switch \( [^()]+ \) { }]]></pattern> <pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { return (true|false) ; } else { return (true|false) ; }]]></pattern>
<message> <message>
<id>EmptySwitchStatement</id> <id>RedundantIfStatement</id>
<severity>style</severity> <severity>style</severity>
<summary>Empty switch statement.</summary> <summary>The if statement is redundant.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[while \( [^()]+ \) { }]]></pattern> <pattern><![CDATA[! !]]></pattern>
<message> <message>
<id>EmptyWhileStatement</id> <id>DoubleNegative</id>
<severity>style</severity> <severity>style</severity>
<summary>Empty while statement.</summary> <summary>Double negative is always positive.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[ for \( ; [^;]+ ; \)]]></pattern> <pattern><![CDATA[~ ~]]></pattern>
<message> <message>
<id>ForLoopShouldBeWhileLoop</id> <id>DoubleNegative</id>
<severity>style</severity> <severity>style</severity>
<summary>For loop should be written as a while loop.</summary> <summary>Double negative is always positive.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern>goto</pattern> <pattern><![CDATA[! \( !]]></pattern>
<message> <message>
<id>GotoStatement</id> <id>DoubleNegative</id>
<severity>style</severity> <severity>style</severity>
<summary>Goto considered harmful.</summary> <summary>Double negative is always positive.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[~ \( ~]]></pattern>
<message>
<id>DoubleNegative</id>
<severity>style</severity>
<summary>Double negative is always positive.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( \w+ != \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern> <pattern><![CDATA[if \( \w+ != \w+ \) ({[^{}]*(?-1)*[^{}]*}) else { (?!if)]]></pattern>
<message> <message>
<id>InvertedLogic</id> <id>InvertedLogic</id>
<severity>style</severity> <severity>style</severity>
...@@ -208,7 +244,7 @@ ...@@ -208,7 +244,7 @@
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( ! \w+ \) ({[^{}]*(?1)*[^{}]*}) else { (?!if)]]></pattern> <pattern><![CDATA[if \( ! \w+ \) ({[^{}]*(?-1)*[^{}]*}) else { (?!if)]]></pattern>
<message> <message>
<id>InvertedLogic</id> <id>InvertedLogic</id>
<severity>style</severity> <severity>style</severity>
...@@ -235,34 +271,43 @@ ...@@ -235,34 +271,43 @@
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[\? (true|false) : (true|false)]]></pattern> <pattern><![CDATA[catch (\([^()]*(?-1)*[^()]*\)) { }]]></pattern>
<message> <message>
<id>RedundantConditionalOperator</id> <id>EmptyCatchStatement</id>
<severity>style</severity> <severity>style</severity>
<summary>Conditional operator is redundant.</summary> <summary>An empty catch statement.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { return (true|false) ; } else { return (true|false) ; }]]></pattern> <pattern><![CDATA[if (\([^()]*(?-1)*[^()]*\)) { assert (\([^()]*(?-1)*[^()]*\)) ; }]]></pattern>
<message> <message>
<id>RedundantIfStatement</id> <id>ConditionalAssert</id>
<severity>style</severity> <severity>style</severity>
<summary>The if statement is redundant.</summary> <summary>The if condition should be included in assert.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[if \( [^()]+ \) { [^{}]* (return|throw|break|continue) [^;]* ; } else {]]></pattern> <pattern><![CDATA[if \( (\w) . empty \( \) \) { for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* \w : \1 \) ({[^{}]*(?-1)*[^{}]*}) }]]></pattern>
<message> <message>
<id>UnnecessaryElseStatement</id> <id>UnnecessaryEmptyCondition</id>
<severity>style</severity> <severity>style</severity>
<summary>Else statement is not necessary.</summary> <summary>Unnecessary check for empty before for range loop.</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>normal</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]></pattern> <pattern><![CDATA[if \( ! (\w) . empty \( \) \) { for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* \w : \1 \) ({[^{}]*(?-1)*[^{}]*}) }]]></pattern>
<message>
<id>UnnecessaryEmptyCondition</id>
<severity>style</severity>
<summary>Unnecessary check for empty before for range loop.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]></pattern>
<message> <message>
<id>useStlAlgorithm</id> <id>useStlAlgorithm</id>
<severity>style</severity> <severity>style</severity>
...@@ -270,8 +315,8 @@ ...@@ -270,8 +315,8 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]></pattern> <pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]></pattern>
<message> <message>
<id>useStlAlgorithm</id> <id>useStlAlgorithm</id>
<severity>style</severity> <severity>style</severity>
...@@ -279,8 +324,8 @@ ...@@ -279,8 +324,8 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \) ; }]]></pattern> <pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \) ; }]]></pattern>
<message> <message>
<id>useStlAlgorithm</id> <id>useStlAlgorithm</id>
<severity>style</severity> <severity>style</severity>
...@@ -288,8 +333,8 @@ ...@@ -288,8 +333,8 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]></pattern> <pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]></pattern>
<message> <message>
<id>useStlAlgorithm</id> <id>useStlAlgorithm</id>
<severity>style</severity> <severity>style</severity>
...@@ -297,11 +342,65 @@ ...@@ -297,11 +342,65 @@
</message> </message>
</rule> </rule>
<rule> <rule>
<tokenlist>normal</tokenlist> <tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( \w+ (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]></pattern> <pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]></pattern>
<message> <message>
<id>useStlAlgorithm</id> <id>useStlAlgorithm</id>
<severity>style</severity> <severity>style</severity>
<summary>Considering using std::transform instead.</summary> <summary>Considering using std::transform instead.</summary>
</message> </message>
</rule> </rule>
<rule>
<!-- Flags a range-for loop that increments a counter and then stores that
     counter inside an if: a hand-written index search. -->
<tokenlist>simple</tokenlist>
<!-- The two counter back-references sit inside a non-capturing group so the
     alternation only picks between them. Left bare, the | operator has the
     lowest precedence and splits the entire pattern into two alternatives. -->
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?<idx1>\w+) \+\+|\+\+ (?<idx2>\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { \w+ = (?:\g{idx1}|\g{idx2}) ; (?:break ; )?(?:return [^;]*; )?} }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<!-- Flags a range-for loop that stores a counter inside an if and increments
     that same counter afterwards: a hand-written index search. -->
<tokenlist>simple</tokenlist>
<!-- idx captures the whole counter identifier (\w+), matching the sibling
     return-based rule; a single \w only matched one-letter counter names. -->
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { \w+ = (?<idx>\w+) ; (?:break ; )?(?:return [^;]*; )?} (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<!-- Flags a range-for loop that increments a counter and then returns that
     counter from inside an if: a hand-written index search. -->
<tokenlist>simple</tokenlist>
<!-- The two counter back-references sit inside a non-capturing group so the
     alternation only picks between them. Left bare, the | operator has the
     lowest precedence and splits the entire pattern into two alternatives. -->
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?<idx1>\w+) \+\+|\+\+ (?<idx2>\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { return (?:\g{idx1}|\g{idx2}) ; } }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>simple</tokenlist>
<pattern><![CDATA[for \( (?:(?:\w+|<|>|::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { return (?<idx>\w+) ; } (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]></pattern>
<message>
<id>useStlAlgorithm</id>
<severity>style</severity>
<summary>Considering using std::find or std::find_if instead.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern><![CDATA[do { } while \(]]></pattern>
<message>
<id>EmptyDoWhileStatement</id>
<severity>style</severity>
<summary>Empty do-while.</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>goto</pattern>
<message>
<id>GotoStatement</id>
<severity>style</severity>
<summary>Goto considered harmful.</summary>
</message>
</rule>
pfultz2/rocm-recipes pfultz2/rocm-recipes
danmar/cppcheck@681cb7dd909d1bfe41796b7616e43265177b9464 -DHAVE_RULES=1 danmar/cppcheck@8aa68ee297c2d9ebadf5bcfd00c66ea8d9291e35 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@fc22ef991ce7cb15821c8ccb4f03cdfc3e7e43cf ROCm-Developer-Tools/HIP@e21df3058728ad8e73708bc99b82e0bdd3509c97
python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -5,18 +5,23 @@ include(ROCMPackageConfigHelpers) ...@@ -5,18 +5,23 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx add_library(migraphx
auto_contiguous.cpp auto_contiguous.cpp
common_subexpression_elimination.cpp common_subexpression_elimination.cpp
constant_propagate.cpp propagate_constant.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp eliminate_concat.cpp
eliminate_identity.cpp
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp instruction.cpp
program.cpp program.cpp
quantization.cpp
shape.cpp shape.cpp
schedule.cpp
pass_manager.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
...@@ -36,6 +41,8 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU ...@@ -36,6 +41,8 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
set(PACKAGE_DEPENDS) set(PACKAGE_DEPENDS)
add_subdirectory(onnx) add_subdirectory(onnx)
add_subdirectory(tf)
add_subdirectory(py) add_subdirectory(py)
add_subdirectory(targets/cpu) add_subdirectory(targets/cpu)
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
......
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
namespace migraphx { namespace migraphx {
......
#include <migraphx/constant_propagate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Rewrite rule for an "add" whose two operands are both literals: the add is
// replaced by a single literal computed from them.
struct match_const_add
{
    // Pattern: an instruction named "add" with two "@literal" arguments.
    auto matcher() const
    {
        return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
    }

    // Combines the two literal operands with `+` through transform(), adds
    // the result to the program as a new literal and substitutes it for the
    // matched instruction.
    void apply(program& p, const match::matcher_result& r) const
    {
        auto add_ins = r.result;
        auto lhs     = add_ins->inputs().at(0)->get_literal();
        auto rhs     = add_ins->inputs().at(1)->get_literal();
        auto folded  = p.add_literal(transform(lhs, rhs, [](auto a, auto b) { return a + b; }));
        p.replace_instruction(add_ins, folded);
    }
};
// Pass entry point: runs the match_const_add rule over the whole program.
void constant_propagate::apply(program& p) const
{
    match::find_matches(p, match_const_add{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,16 +42,17 @@ void dead_code_elimination::apply(program& p) const ...@@ -41,16 +42,17 @@ void dead_code_elimination::apply(program& p) const
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin or undefined // Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and not(i->name().front() == '@') and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
not(i->name() == "undefined")) i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(p, i, last) > 0); assert(bidistance(p, i, last) > 0);
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf)); assert(p.has_instruction(leaf));
if(leaf->outputs().empty()) if(leaf->outputs().empty())
{ {
auto args = leaf->inputs(); std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(p, last, leaf) < 0); assert(bidistance(p, last, leaf) < 0);
assert(leaf != ins); assert(leaf != ins);
......
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/load.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
......
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -10,19 +9,60 @@ ...@@ -10,19 +9,60 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// NOTE(review): this hunk is rendered side by side; on the fused lines the
// left column is the old try_compute_shape(op, args) and the right column is
// its replacement.
// The replacement reports whether `ins` still yields a usable shape when fed
// `inputs`: it accepts a standard result or a result equal to the current
// shape of `ins`, rejects a changed result when `ins` has no outputs, and
// otherwise re-checks every consumer of `ins` with the new shape substituted
// for it. Any exception raised along the way is reported as false.
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args) static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{ {
try try
{ {
compute_shape(op, args); shape new_shape = ins->get_operator().compute_shape(inputs);
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
return true;
}
// if no changes for the shape, the contiguous can also be removed
if(new_shape == ins->get_shape())
{
return true;
}
auto outputs = ins->outputs();
// If the current instruction has no output, it means it is the last
// instruction and generates a non-standard output shape, and the last
// output shape is different from the case with the contiguous operator
if(outputs.empty())
{
return false;
}
for(auto output : outputs)
{
auto args = output->inputs();
std::vector<shape> input_shapes(args.size());
std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
return (arg == ins) ? new_shape : arg->get_shape();
});
if(!try_compute_shape(output, input_shapes))
{
return false;
}
}
} }
catch(...)
{ {
return false; return false;
} }
return true; return true;
} }
// Convenience overload: converts the argument instructions to their shapes
// with to_shapes() and defers to the shape-based try_compute_shape.
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
    return try_compute_shape(ins, to_shapes(args));
}
void eliminate_contiguous::apply(program& p) const void eliminate_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -38,7 +78,7 @@ void eliminate_contiguous::apply(program& p) const ...@@ -38,7 +78,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args; auto new_args = args;
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins->get_operator(), new_args)) if(try_compute_shape(ins, new_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Removes "identity" instructions from the program.
//
// The instruction list is walked once. At every step the instruction *before*
// the current one is examined; when it is an identity it is replaced by its
// first input and parked behind the original last instruction. Everything
// that ends up behind `last` is erased in a single remove_instructions() call
// at the end.
//
// An empty program is left untouched.
void eliminate_identity::apply(program& p) const
{
    // Nothing to do for an empty program. This guard is also required for
    // correctness: std::prev(p.end()) is undefined when begin() == end().
    if(p.begin() == p.end())
        return;
    // Marks the last instruction that must survive; everything after it is
    // removed once the loop finishes.
    auto last = std::prev(p.end());
    for(auto ins : iterator_for(p))
    {
        // Skip the first instruction, since we always process the previous
        // instruction
        if(ins == p.begin())
            continue;
        const auto i = std::prev(ins);
        if(i->name() == "identity")
        {
            // Substitute the identity by its first input, then move it to
            // the end of the program so that it falls into the range erased
            // below.
            // NOTE(review): presumably replace_instruction redirects the
            // users of `i` to that input - confirm against program's API.
            p.replace_instruction(i, i->inputs().front());
            p.move_instruction(i, p.end());
        }
        if(ins == last)
        {
            // The original final instruction cannot be handled as a
            // "previous" instruction, so it is treated here.
            if(ins->name() == "identity")
            {
                const instruction_ref& identity_input = ins->inputs().front();
                // Only when this identity is the sole output of its input.
                if(identity_input->outputs().size() == 1)
                {
                    // NOTE(review): assumes move_instruction(src, dst)
                    // repositions src relative to dst - confirm exact
                    // placement against program's API.
                    p.move_instruction(identity_input, i);
                    // since this is the last instruction, removing it only
                    // requires changing "last" and calling remove below
                    last = std::prev(last);
                }
            }
            // Stop here so the instructions parked behind `last` are never
            // visited.
            break;
        }
    }
    // Erase everything that was parked after `last`.
    p.remove_instructions(std::next(last), p.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Scans the program for convolution, im2col and pooling instructions whose
// first input is a "pad" instruction and hands each such pair to update_op(),
// tagged with the matching operator type.
void eliminate_pad::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        const std::string& kind = ins->name();
        const bool is_convolution = (kind == "convolution");
        const bool is_im2col      = (kind == "im2col");
        const bool is_pooling     = (kind == "pooling");
        // Only these three operators are candidates; the inputs of any other
        // instruction are not inspected at all.
        if(not(is_convolution or is_im2col or is_pooling))
            continue;

        auto producer = ins->inputs().front();
        if(producer->name() != "pad")
            continue;

        // Exactly one of the three flags is set at this point.
        if(is_convolution)
            update_op(op::convolution{}, producer, ins, p);
        else if(is_im2col)
            update_op(op::im2col{}, producer, ins, p);
        else
            update_op(op::pooling{}, producer, ins, p);
    }
}
// Folds a preceding pad instruction into the padding attribute of `ins`.
//
// T      - operator type stored in `ins` (passed as a tag value); it must
//          expose `padding` and `padding_mode` members.
// input  - the pad instruction feeding `ins`.
// ins    - the instruction whose operator receives the padding.
// p      - program that owns both instructions.
//
// The rewrite is skipped (the program is left unchanged) when
//   * the pad is not symmetric,
//   * the pad describes fewer than four entries,
//   * either of the two leading entries is non-zero,
//   * a spatial pad amount is negative, or
//   * the operator does not use the default padding mode.
// Otherwise `ins` is replaced by the same operator with the pad amounts
// added to its padding, reading directly from the pad's own input.
template <class T>
void eliminate_pad::update_op(T,
                              const instruction_ref& input,
                              const instruction_ref& ins,
                              program& p) const
{
    auto pad_op = any_cast<op::pad>(input->get_operator());
    if(!pad_op.symmetric())
        return;

    const std::vector<int64_t>& pads = pad_op.pads;
    // Entries 2 and 3 are read below; make sure they exist.
    if(pads.size() < 4)
        return;
    // NOTE(review): assumes entries 0 and 1 belong to the leading (batch and
    // channel) dimensions - confirm against op::pad. The two-element padding
    // of T cannot express them, so folding would silently drop them.
    if(pads[0] != 0 or pads[1] != 0)
        return;
    // A negative amount would wrap around when converted to size_t.
    if(pads[2] < 0 or pads[3] < 0)
        return;

    T op = any_cast<T>(ins->get_operator());
    if(op.padding_mode != op::padding_mode_t::default_)
        return;

    // Accumulate rather than overwrite, so padding the operator already
    // applies on its own is kept.
    op.padding[0] += static_cast<size_t>(pads[2]);
    op.padding[1] += static_cast<size_t>(pads[3]);

    // Same inputs as before, except that the first one now bypasses the pad.
    std::vector<instruction_ref> new_inputs{ins->inputs()};
    new_inputs.front() = input->inputs().front();
    p.replace_instruction(ins, op, new_inputs);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,6 +21,14 @@ bool disabled(const char* name) ...@@ -21,6 +21,14 @@ bool disabled(const char* name)
return contains({"0", "disable", "disabled", "no", "false"}, e.front()); return contains({"0", "disable", "disabled", "no", "false"}, e.front());
} }
// Reads the environment variable `name` through env() and returns its first
// entry parsed as an unsigned integer with std::stoul. Returns 0 when env()
// yields nothing. Exceptions thrown by std::stoul for text that cannot be
// parsed propagate to the caller.
std::size_t value_of(const char* name)
{
    const auto values = env(name);
    return values.empty() ? 0 : std::stoul(values.front());
}
std::vector<std::string> env(const char* name) std::vector<std::string> env(const char* name)
{ {
auto p = std::getenv(name); auto p = std::getenv(name);
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,32 +17,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -14,32 +17,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
if(ins->name() != "batch_norm_inference") if(ins->name() != "batch_norm_inference")
continue; continue;
if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) { // Get scale, bias, mean, variance from inputs
return arg->name() == "@literal"; auto gamma = ins->inputs()[1]->eval();
})) auto bias = ins->inputs()[2]->eval();
auto mean = ins->inputs()[3]->eval();
auto variance = ins->inputs()[4]->eval();
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue; continue;
auto conv_ins = ins->inputs()[0]; auto conv_ins = ins->inputs()[0];
if(conv_ins->name() != "convolution") if(conv_ins->name() != "convolution")
continue; continue;
if(conv_ins->inputs()[1]->name() != "@literal") // Get convolution weights
auto weights = conv_ins->inputs()[1]->eval();
if(weights.empty())
continue; continue;
// Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->inputs()[1]->get_literal();
const auto& bias = ins->inputs()[2]->get_literal();
const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon // Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator()); auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon; auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->inputs()[1]->get_literal();
// Get convolution op // Get convolution op
auto conv_op = conv_ins->get_operator(); auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens(); auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens(); auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()}; argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()}; argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)( visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2, [&](auto weights2,
auto gamma2, auto gamma2,
...@@ -51,18 +52,18 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -51,18 +52,18 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])( dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) = new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w); gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
}); });
dfor(new_bias.get_shape().elements())([&](std::size_t c) { dfor(new_bias.get_shape().elements())([&](std::size_t c) {
new_bias2(c) = new_bias2[c] =
bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon)); bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
}); });
}); });
// Replace convolution instruction with updated weights // Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias); auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
p.replace_instruction(ins, op::add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment