diff --git a/.clang-tidy b/.clang-tidy old mode 100644 new mode 100755 index 0011e948f452e394a7670e7acdf148f7ffe13cc8..597e84527e03ab108e072ab7cac5b4976c387e12 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,10 +1,10 @@ CheckOptions: + - key: bugprone-reserved-identifier.AllowedIdentifiers + value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' - key: bugprone-unused-return-value.CheckedFunctions value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product' - key: cppcoreguidelines-macro-usage.AllowedRegexp - value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED' - - key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion - value: 0 + value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_' - key: modernize-loop-convert.MinConfidence value: risky - key: modernize-loop-convert.NamingStyle @@ -16,7 +16,7 @@ CheckOptions: - key: readability-function-size.BranchThreshold value: '15' - key: readability-function-size.LineThreshold - value: '300' + value: '350' - key: readability-function-size.NestingThreshold value: '5' - key: readability-function-size.ParameterThreshold @@ -109,7 +109,7 @@ CheckOptions: value: CamelCase - key: readability-identifier-naming.TypeAliasCase value: lower_case - # - key: readability-identifier-naming.MacroDefinitionCase - # value: UPPER_CASE - # - key: readability-identifier-naming.MacroDefinitionPrefix - # value: MIGRAPHX_ + - key: readability-identifier-naming.MacroDefinitionCase + value: UPPER_CASE + - key: readability-identifier-naming.MacroDefinitionPrefix + value: MIGRAPHX_ diff --git a/.githooks/pre-commit b/.githooks/pre-commit index f51683236b74038b8ec050d9b2343c41339b2905..a9a2a0fc666dcb815996177e7d5dd414323a9ccd 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -4,7 +4,7 @@ # are installed, and if so, uses the installed version to format # the staged changes. -base=clang-format-5.0 +base=clang-format-10 format="" yapf_base=yapf yapf_format="" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100755 index 0000000000000000000000000000000000000000..dea1f303969d0ca230cd0d6260fd94899f889a85 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,258 @@ +name: migraphx + +on: [push, pull_request] + +jobs: + cancel: + runs-on: ubuntu-latest + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + tidy: + runs-on: ubuntu-18.04 + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + + - uses: actions/checkout@v2 + + # In this step, this action saves a list of existing images, + # the cache is created without them in the post run. + # It also restores the cache if it exists. + - uses: satackey/action-docker-layer-caching@v0.0.11 + # Ignore the failure of a step and avoid terminating the job. + continue-on-error: true + + - name: Prepare timestamp + id: cache_timestamp + shell: cmake -P {0} + run: | + string(TIMESTAMP current_date "%Y-%m-%d-%H;%M;%S" UTC) + message("::set-output name=timestamp::${current_date}") + + - name: Cache files for tidy + uses: pat-s/always-upload-cache@v2.1.3 + with: + path: tidy-cache + key: tidy-cache-${{ steps.cache_timestamp.outputs.timestamp }} + restore-keys: | + tidy-cache-${{ steps.cache_timestamp.outputs.timestamp }} + tidy-cache- + + - name: Build the Docker image + run: docker build . --file hip-clang.docker --tag migraphx + + - name: Clang tidy + shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" + run: | + mkdir build + cd build + CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \ + -DMIGRAPHX_ENABLE_GPU=On \ + -DMIGRAPHX_ENABLE_CPU=On \ + -DROCM_ENABLE_GH_ANNOTATIONS=On \ + -DCLANG_TIDY_DEPEND_ON_TARGET=Off \ + -DCLANG_TIDY_CACHE=/data/tidy-cache \ + .. + make -j2 -k onnx-proto tf-proto tidy + + cppcheck: + runs-on: ubuntu-18.04 + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + - uses: actions/checkout@v2 + + # In this step, this action saves a list of existing images, + # the cache is created without them in the post run. + # It also restores the cache if it exists. + - uses: satackey/action-docker-layer-caching@v0.0.11 + # Ignore the failure of a step and avoid terminating the job. + continue-on-error: true + + - name: Prepare timestamp + id: cache_timestamp + shell: cmake -P {0} + run: | + string(TIMESTAMP current_date "%Y-%m-%d-%H;%M;%S" UTC) + message("::set-output name=timestamp::${current_date}") + + - name: Cache files for cppcheck + uses: pat-s/always-upload-cache@v2.1.3 + with: + path: cppcheck-cache + key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }} + restore-keys: | + cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ steps.cache_timestamp.outputs.timestamp }} + cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}- + + - name: Build the Docker image + run: docker build . --file hip-clang.docker --tag migraphx + + - name: Cppcheck + shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" + run: | + mkdir build + cd build + CXX=/opt/rocm/llvm/bin/clang++ CC=/opt/rocm/llvm/bin/clang cmake \ + -DCPPCHECK_BUILD_DIR=/data/cppcheck-cache \ + -DROCM_ENABLE_GH_ANNOTATIONS=On \ + .. + make -j2 cppcheck + + format: + runs-on: ubuntu-18.04 + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + - uses: actions/checkout@v2 + + # In this step, this action saves a list of existing images, + # the cache is created without them in the post run. + # It also restores the cache if it exists. + - uses: satackey/action-docker-layer-caching@v0.0.11 + # Ignore the failure of a step and avoid terminating the job. + continue-on-error: true + + - name: Build the Docker image + run: docker build . --file hip-clang.docker --tag migraphx + + - name: Check formatting + shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" + run: | + set -e + find . -iname '*.h' \ + -o -iname '*.hpp' \ + -o -iname '*.cpp' \ + -o -iname '*.h.in' \ + -o -iname '*.hpp.in' \ + -o -iname '*.cpp.in' \ + -o -iname '*.cl' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c 'clang-format-10 -style=file {} | diff - {}' + find . -iname '*.py' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c 'yapf {} | diff - {}' + + pyflakes: + runs-on: ubuntu-18.04 + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install pyflakes + run: pip install pyflakes==2.4.0 mypy==0.931 + + - name: Run pyflakes + run: | + pyflakes --version + pyflakes examples/ tools/ src/ test/ doc/ + mypy --version + mypy tools/api.py + + linux: + + runs-on: ${{ matrix.os }} + + env: + CCACHE_COMPRESSLEVEL: 10 + CCACHE_DIR: ${{github.workspace}}/ccache + CCACHE_NOHASHDIR: true + CCACHE_BASEDIR: ${{github.workspace}} + CCACHE_MAXSIZE: 1 + + strategy: + matrix: + os: + - ubuntu-18.04 + - ubuntu-20.04 + configuration: + - debug + - release + - codecov + + steps: + - name: Free space + run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.6 + - name: Cache dependencies + # Ignore the failure of a step and avoid terminating the job. + continue-on-error: true + uses: actions/cache@v2 + with: + # This path is specific to Ubuntu + path: ${{ github.workspace }}/cget + # Look to see if there is a cache hit for the corresponding requirements file + key: + ${{ matrix.os }}-cget-4-${{ hashFiles('requirements.txt', 'dev-requirements.txt') }} + ${{ matrix.os }}-cget-4- + + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz + rbuild prepare -d cget -s gh + - name: Prepare timestamp + id: cache_timestamp + shell: cmake -P {0} + run: | + string(TIMESTAMP current_date "%Y-%m-%d-%H;%M;%S" UTC) + message("::set-output name=timestamp::${current_date}") + + - name: Cache files for ccache + # Ignore the failure of a step and avoid terminating the job. + continue-on-error: true + uses: pat-s/always-upload-cache@v2.1.3 + with: + path: ccache + key: ${{ matrix.os }}-${{ matrix.configuration }}-ccache-${{ steps.cache_timestamp.outputs.timestamp }} + restore-keys: | + ${{ matrix.os }}-${{ matrix.configuration }}-ccache-${{ steps.cache_timestamp.outputs.timestamp }} + ${{ matrix.os }}-${{ matrix.configuration }}-ccache- + + - name: Build and test + env: + CMAKE_PREFIX_PATH: ${{ github.workspace }}/cget + CCACHE_LOGFILE: /tmp/ccache.log + CXXFLAGS: -Werror -pthread --param ggc-min-expand=5 --param ggc-min-heapsize=8192 + run: | + echo "leak:dnnl::impl::malloc" > suppressions.txt + export LSAN_OPTIONS="suppressions=$(pwd)/suppressions.txt" + rbuild build -d cget -s gh -T check \ + -DCMAKE_BUILD_TYPE=${{matrix.configuration}} \ + -DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \ + -DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined" \ + -DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \ + -DCMAKE_EXE_LINKER_FLAGS='-fuse-ld=gold' \ + -DCMAKE_SHARED_LINKER_FLAGS='-fuse-ld=gold' + ${{ github.workspace }}/cget/bin/ccache -s + + - name: Upload code coverage + if: "matrix.configuration == 'codecov'" + env: + CODECOV_TOKEN: "8545af1c-f90b-4345-92a5-0d075503ca56" + run: | + sudo apt-get install -y lcov + cd build + lcov --directory . --capture --output-file $(pwd)/coverage.info + lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info + lcov --list $(pwd)/coverage.info + curl -s https://codecov.io/bash | bash + echo "Uploaded" + + diff --git a/.gitignore b/.gitignore index 1377554ebea6f98a2c748183bc5a96852af12ac2..b617ef4de8365f1bc83f2f9844fcbf18006969e8 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,63 @@ -*.swp +#==============================================================================# +# File extensions to be ignored anywhere in the tree. +#==============================================================================# + +# Temp files created by most text editors +*~ + +# Merge files created by git +*.orig + +# Byte compiled python modules +*.pyc +*.pyd + +# Vim swap files +.*.sw? +.sw? + +# Visual Studio +.vs +/.vscode/* + +# Sublime Text settings +*.sublime-workspace +*.sublime-project + +# Eclipse Project settings +*.*project +.settings + +# OS X specific files +.DS_store + +#==============================================================================# +# Explicit files to ignore (only matches one). +#==============================================================================# + +# Various tags +/tags +/TAGS +/GPATH +/GRTAGS +/GSYMS +/GTAGS +/ID +.gitusers +/compile_commands.json + +/CMakeSettings.json + + +#==============================================================================# +# Directories to ignore (do not add trailing '/'s, they skip symlinks). +#==============================================================================# +# Nested build directory +/build* + +# Downloaded models +test/onnx/models + +# VS2017 and VSCode config files. +.vscode +.vs diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100644 new mode 100755 index 038ccd9b5cdf1e7ec1ad5b156083e7832ba01849..53bc376cbdca16a7e885e294d58b6a60f39345fa --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,12 @@ if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE ) set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." ) endif() +# Setup valid strings for build type +if (NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING "Configs") +endif() +set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS ${CMAKE_CONFIGURATION_TYPES}) + # Default installation path if(WIN32) set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "") @@ -20,34 +26,52 @@ endif() project(migraphx) find_package(ROCM REQUIRED) +find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half) +if (NOT HALF_INCLUDE_DIR) + message(FATAL_ERROR "Could not find half.hpp - Please check that the install path of half.hpp has been added to CMAKE_PREFIX_PATH") +endif() + +include(CheckTypeSize) +set(CMAKE_REQUIRED_INCLUDES ${HALF_INCLUDE_DIR}) +set(CMAKE_EXTRA_INCLUDE_FILES half.hpp) +check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX) +set(CMAKE_REQUIRED_INCLUDES) +set(CMAKE_EXTRA_INCLUDE_FILES) + +find_package(nlohmann_json 3.8.0 REQUIRED) + include(ROCMSetupVersion) -rocm_setup_version(VERSION 0.5) +rocm_setup_version(VERSION 2.3) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) option( BUILD_SHARED_LIBS "Build as a shared library" ON ) -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "5.4") - message(FATAL_ERROR "MIGraph requires at least gcc 5.4") - endif() -endif() - include(CheckCXXCompilerFlag) check_cxx_compiler_flag("--cuda-host-only -x hip" HAS_HIP) if(HAS_HIP) - message(STATUS "Enable miopen backend") + message(STATUS "Enable gpu backend") set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "") else() set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "") endif() -add_compile_options(-std=c++14) +# Disable cpu backend by default +set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "") + +set(CMAKE_CXX_STANDARD_DEFAULT "") +add_compile_options(-std=c++17) + +if(${CMAKE_VERSION} VERSION_LESS "3.12.0") + set(CONFIGURE_DEPENDS) +else() + set(CONFIGURE_DEPENDS CONFIGURE_DEPENDS) +endif() list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) include(EnableCompilerWarnings) include(ROCMClangTidy) -if(CMAKE_CXX_COMPILER MATCHES ".*hcc") +if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") set(MIGRAPHX_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) # Enable tidy on hip elseif(MIGRAPHX_ENABLE_GPU) @@ -55,15 +79,49 @@ elseif(MIGRAPHX_ENABLE_GPU) endif() rocm_enable_clang_tidy( CHECKS - * - -android-cloexec-fopen - -clang-analyzer-alpha.core.CastToStruct + boost-* + bugprone-* + cert-* + clang-analyzer-* + clang-diagnostic-* + cppcoreguidelines-* + google-* + hicpp-multiway-paths-covered + hicpp-signed-bitwise + llvm-namespace-comment + misc-* + modernize-* + performance-* + readability-* + -bugprone-easily-swappable-parameters + -bugprone-implicit-widening-of-multiplication-result + -bugprone-macro-parentheses + -bugprone-signed-char-misuse + # Disable the aliased reserved identifiers + -cert-dcl37-c + -cert-dcl51-cpp + -cert-err33-c + -cert-str34-c + # Disable all alpha checks by default + -clang-analyzer-alpha* + # Enable some alpha checks + clang-analyzer-alpha.core.CallAndMessageUnInitRefArg + clang-analyzer-alpha.core.Conversion + clang-analyzer-alpha.core.IdenticalExpr + clang-analyzer-alpha.core.PointerArithm + clang-analyzer-alpha.core.PointerSub + clang-analyzer-alpha.core.TestAfterDivZero + clang-analyzer-alpha.cplusplus.InvalidIterator + clang-analyzer-alpha.cplusplus.IteratorRange + clang-analyzer-alpha.cplusplus.MismatchedIterator + clang-analyzer-alpha.cplusplus.MisusedMovedObject -clang-analyzer-optin.performance.Padding -clang-diagnostic-deprecated-declarations -clang-diagnostic-extern-c-compat -clang-diagnostic-disabled-macro-expansion -clang-diagnostic-unused-command-line-argument -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables -cppcoreguidelines-pro-bounds-array-to-pointer-decay -cppcoreguidelines-pro-bounds-constant-array-index -cppcoreguidelines-pro-bounds-pointer-arithmetic @@ -72,22 +130,12 @@ rocm_enable_clang_tidy( -cppcoreguidelines-pro-type-union-access -cppcoreguidelines-pro-type-vararg -cppcoreguidelines-special-member-functions - -fuchsia-* - -google-readability-braces-around-statements - -google-readability-todo + -cppcoreguidelines-virtual-class-destructor + -google-readability-* -google-runtime-int -google-runtime-references - -hicpp-braces-around-statements - -hicpp-explicit-conversions - -hicpp-member-init - -hicpp-no-array-decay - -hicpp-special-member-functions - -hicpp-uppercase-literal-suffix - -hicpp-use-override - # This check is broken - -llvm-header-guard - -llvm-include-order -misc-macro-parentheses + -misc-no-recursion -modernize-concat-nested-namespaces -modernize-pass-by-value -modernize-use-default-member-init @@ -97,12 +145,18 @@ rocm_enable_clang_tidy( -modernize-use-transparent-functors -performance-type-promotion-in-math-fn -readability-braces-around-statements + -readability-convert-member-functions-to-static -readability-else-after-return + -readability-function-cognitive-complexity + -readability-identifier-length -readability-named-parameter - -readability-uppercase-literal-suffix, + -readability-redundant-string-init + -readability-suspicious-call-argument + -readability-uppercase-literal-suffix -*-avoid-c-arrays -*-explicit-constructor -*-magic-numbers + -*-narrowing-conversions -*-non-private-member-variables-in-classes -*-use-auto -*-use-emplace @@ -114,12 +168,14 @@ rocm_enable_clang_tidy( -UNDEBUG -DMIGRAPHX_USE_CLANG_TIDY "-Dmain\\\\(...\\\\)=main\\\\(__VA_ARGS__\\\\) // NOLINT" - # CLANG_ARGS - # -analyzer-config optin.cplusplus.UninitializedObject:Pedantic=true - # -analyzer-config widen-loops=true - # -analyzer-config unroll-loops=true - # -analyzer-config cfg-lifetime=true - # -analyzer-config cfg-scopes=true + CLANG_ARGS + -analyzer-max-loop 10 + -analyzer-inline-max-stack-depth 10 + -analyzer-config optin.cplusplus.UninitializedObject:Pedantic=true + -analyzer-config widen-loops=true + -analyzer-config unroll-loops=true + -analyzer-config cfg-lifetime=true + -analyzer-config cfg-scopes=true ) include(ROCMCppCheck) rocm_enable_cppcheck( @@ -128,7 +184,7 @@ rocm_enable_cppcheck( style performance portability - SUPPRESS + SUPPRESS ConfigurationNotChecked unmatchedSuppression unusedFunction @@ -136,28 +192,39 @@ rocm_enable_cppcheck( passedByValue unusedStructMember functionStatic - functionConst:*program.* + functionConst shadowFunction shadowVar shadowVariable unsafeClassDivZero definePrefix:*test/include/test.hpp + ctuOneDefinitionRuleViolation:*test/* + useSmartPointer:*src/api/api.cpp + useSmartPointer:*make_shared_array.hpp + constParameter:*src/targets/gpu/*.cpp + constParameter:*src/targets/gpu/*.hpp + # Suppress mlir_conv.cpp since this file will be deleted + *:*src/targets/gpu/mlir_conv.cpp FORCE INCONCLUSIVE RULE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules SOURCES + examples/ src/ test/ INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/src/include ${CMAKE_CURRENT_SOURCE_DIR}/src/targets/cpu/include - ${CMAKE_CURRENT_SOURCE_DIR}/src/targets/miopen/include + ${CMAKE_CURRENT_SOURCE_DIR}/src/targets/gpu/include + ${CMAKE_CURRENT_SOURCE_DIR}/src/targets/gpu/device/include + ${CMAKE_CURRENT_SOURCE_DIR}/src/targets/gpu/kernels/include ${CMAKE_CURRENT_SOURCE_DIR}/test/include DEFINE CPPCHECK=1 __device__= __host__= + __global__= ) enable_testing() @@ -169,14 +236,21 @@ rocm_create_package( MAINTAINER "Paul Fultz II " LDCONFIG PTH - DEPENDS miopen-hip rocblas hip_hcc half + DEPENDS miopen-hip rocblas hip-rocclr hip-base half ) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) - add_subdirectory(src) add_subdirectory(doc) add_subdirectory(test) add_subdirectory(tools) + +set(DEST_DIR ${CMAKE_BINARY_DIR}) +file(GLOB backend_files ${CMAKE_SOURCE_DIR}/src/py/backend/*.py) +file(MAKE_DIRECTORY ${DEST_DIR}/lib/onnx_migraphx) +foreach(py_file ${backend_files}) + configure_file(${py_file} ${DEST_DIR}/lib/onnx_migraphx/. COPYONLY) +endforeach(py_file) +configure_file(${CMAKE_SOURCE_DIR}/test/py/onnx_backend_test.py ${DEST_DIR}/onnx_backend_test.py COPYONLY) diff --git a/Dockerfile b/Dockerfile old mode 100644 new mode 100755 index 59410167f9ced7997b28a402e96e44c6f62c87fc..a974f089c7d92025e2709be50247e8682f9d569c --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:xenial-20180417 +FROM ubuntu:20.04 ARG PREFIX=/usr/local @@ -6,81 +6,94 @@ ARG PREFIX=/usr/local RUN dpkg --add-architecture i386 # Add rocm repository -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y curl apt-utils wget software-properties-common -RUN curl https://raw.githubusercontent.com/RadeonOpenCompute/ROCm-docker/master/add-rocm.sh | bash - -# Add ubuntu toolchain -RUN apt-get update && add-apt-repository ppa:ubuntu-toolchain-r/test -y +RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.0.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ build-essential \ - clang-5.0 \ - clang-format-5.0 \ - clang-tidy-5.0 \ + clang-format-10 \ cmake \ - comgr \ curl \ doxygen \ g++-7 \ gdb \ git \ - hsa-rocr-dev \ - hsakmt-roct-dev \ lcov \ - libelf-dev \ - libncurses5-dev \ - libnuma-dev \ - libpthread-stubs0-dev \ - libssl-dev \ - python \ - python-dev \ - python-pip \ - rocm-device-libs \ - rocm-opencl \ - rocm-opencl-dev \ + locales \ + pkg-config \ + python3 \ + python3-dev \ + python3-pip \ software-properties-common \ wget \ + rocm-device-libs \ + hip-base \ + libnuma-dev \ + miopen-hip \ + rocblas \ zlib1g-dev && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -# Install cget -RUN pip install cget - -# Install rclone -RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz +# Workaround broken rocm packages +RUN ln -s /opt/rocm-* /opt/rocm +RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf +RUN echo "/opt/rocm/llvm/lib" > /etc/ld.so.conf.d/rocm-llvm.conf +RUN ldconfig -# Install yapf -RUN pip install yapf==0.28.0 - -# Install hcc -RUN rclone -b roc-2.6.x -c 0f4c96b7851af2663a7f3ac16ecfb76c7c78a5bf https://github.com/RadeonOpenCompute/hcc.git /hcc -RUN cget -p $PREFIX install hcc,/hcc - -# Use hcc -RUN cget -p $PREFIX init --cxx $PREFIX/bin/hcc +RUN locale-gen en_US.UTF-8 +RUN update-locale LANG=en_US.UTF-8 -# Workaround hip's broken cmake -RUN ln -s $PREFIX /opt/rocm/hip -RUN ln -s $PREFIX /opt/rocm/hcc +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 # Install dependencies ADD dev-requirements.txt /dev-requirements.txt ADD requirements.txt /requirements.txt -RUN cget -p $PREFIX install -f /dev-requirements.txt -DMIOPEN_CACHE_DIR="" +ADD rbuild.ini /rbuild.ini -ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db -ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db +COPY ./tools/install_prereqs.sh / +RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh +RUN test -f /usr/local/hash || exit 1 -ENV LD_LIBRARY_PATH=$PREFIX/lib +# Install yapf +RUN pip3 install yapf==0.28.0 # Install doc requirements ADD doc/requirements.txt /doc-requirements.txt -RUN pip install -r /doc-requirements.txt +RUN pip3 install -r /doc-requirements.txt + +# Download real models to run onnx unit tests +ENV ONNX_HOME=$HOME +COPY ./tools/download_models.sh / +RUN /download_models.sh && rm /download_models.sh + +# Install latest ccache version +RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake +RUN cget -p $PREFIX install ccache@v4.1 + +# Install newer cmake for onnx runtime +RUN cget -p /opt/cmake install kitware/cmake@v3.13.4 + +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime +ARG ONNXRUNTIME_BRANCH=master +ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc +RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \ + cd onnxruntime && \ + git checkout ${ONNXRUNTIME_COMMIT} && \ + /bin/sh dockerfiles/scripts/install_common_deps.sh + +ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh + +RUN PATH=/opt/cmake/bin:$PATH cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@02078ce236ad90e3aec04c0c770ef5bfc99e49c2 + +ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db +ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db +ENV LD_LIBRARY_PATH=$PREFIX/lib # Setup ubsan environment to printstacktrace -RUN ln -s /usr/bin/llvm-symbolizer-5.0 /usr/local/bin/llvm-symbolizer ENV UBSAN_OPTIONS=print_stacktrace=1 ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1 +RUN ln -s /opt/rocm/llvm/bin/llvm-symbolizer /usr/bin/llvm-symbolizer + diff --git a/Jenkinsfile b/Jenkinsfile old mode 100644 new mode 100755 index d901f885e0ca52e6b514c8c174f46e63bc5270c9..cbf194aa07c4ae936f1e9f1a6d9042f1d13284e7 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,19 +1,31 @@ -def rocmtestnode(variant, name, body) { +// def rocmtestnode(variant, name, body, args, pre) { +def rocmtestnode(Map conf) { + def variant = conf.get("variant") + def name = conf.get("node") + def body = conf.get("body") + def docker_args = conf.get("docker_args", "") + def docker_build_args = conf.get("docker_build_args", "") + def pre = conf.get("pre", {}) + def ccache = "/var/jenkins/.cache/ccache" def image = 'migraphxlib' + env.CCACHE_COMPRESSLEVEL = 7 + env.CCACHE_DIR = ccache def cmake_build = { compiler, flags -> def cmd = """ env ulimit -c unlimited + echo "leak:dnnl::impl::malloc" > suppressions.txt + export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt" rm -rf build mkdir build cd build - CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake ${flags} .. - CTEST_PARALLEL_LEVEL=32 make -j32 generate all doc package check + CXX=${compiler} CXXFLAGS='-Werror' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} .. + make -j\$(nproc) generate all doc package check VERBOSE=1 """ echo cmd sh cmd - if (compiler == "hcc") { + if (compiler != "hcc") { // Only archive from master or develop if (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master") { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true @@ -25,26 +37,28 @@ def rocmtestnode(variant, name, body) { stage("checkout ${variant}") { checkout scm } - stage("image ${variant}") { - try { - docker.build("${image}", '.') - } catch(Exception ex) { - docker.build("${image}", '--no-cache .') + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'AMDMIGraphX') { + pre() + stage("image ${variant}") { + try { + docker.build("${image}", "${docker_build_args} .") + } catch(Exception ex) { + docker.build("${image}", "${docker_build_args} --no-cache .") + } } - } - withDockerContainer(image: image, args: '--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE') { - timeout(time: 1, unit: 'HOURS') { - body(cmake_build) + withDockerContainer(image: image, args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=/var/jenkins/:/var/jenkins ${docker_args}") { + timeout(time: 2, unit: 'HOURS') { + body(cmake_build) + } } } } } } -@NonCPS def rocmtest(m) { def builders = [:] - for(e in m) { + m.each { e -> def label = e.key; def action = e.value; builders[label] = { @@ -54,95 +68,72 @@ def rocmtest(m) { parallel builders } -@NonCPS -def rocmnode(name, body) { - def node_name = 'rocmtest || rocm' - if(name == 'fiji') { - node_name = 'rocmtest && fiji'; - } else if(name == 'vega') { - node_name = 'rocmtest && vega'; - } else { - node_name = name - } - return { label -> - rocmtestnode(label, node_name, body) +def rocmnodename(name) { + def rocmtest_name = "(rocmtest || migraphx)" + def node_name = "${rocmtest_name}" + if(name == "fiji") { + node_name = "${rocmtest_name} && fiji"; + } else if(name == "vega") { + node_name = "${rocmtest_name} && vega"; + } else if(name == "navi21") { + node_name = "${rocmtest_name} && navi21"; + } else if(name == "nogpu") { + return rocmtest_name; } + return node_name } -@NonCPS -def rocmnode(body) { - rocmnode('rocmtest', body) +def rocmnode(name, body) { + return { label -> + rocmtestnode(variant: label, node: rocmnodename(name), body: body) + } } -// Static checks -rocmtest tidy: rocmnode('rocmtest') { cmake_build -> - stage('Clang Tidy') { - sh ''' - rm -rf build - mkdir build - cd build - CXX=hcc cmake .. - make -j8 -k analyze - ''' - } -}, format: rocmnode('rocmtest') { cmake_build -> - stage('Format') { - sh ''' - find . -iname \'*.h\' \ - -o -iname \'*.hpp\' \ - -o -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ - | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-5.0 -style=file {} | diff - {}\' - find . -iname \'*.py\' \ - | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'yapf {} | diff - {}\' - ''' - } -}, clang_debug: rocmnode('vega') { cmake_build -> - stage('Clang Debug') { - // TODO: Enable integer +rocmtest clang_debug: rocmnode('vega') { cmake_build -> + stage('Hip Clang Debug') { def sanitizers = "undefined" - def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" - cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") + def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" + cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") } }, clang_release: rocmnode('vega') { cmake_build -> - stage('Clang Release') { - cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release") - } -}, clang_release_py3: rocmnode('vega') { cmake_build -> - stage('Clang Release Python 3') { - cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release -DPYTHON_EXECUTABLE=/usr/local/bin/python3") - } -}, gcc5: rocmnode('rocmtest') { cmake_build -> - stage('GCC 5 Debug') { - cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=debug") + stage('Hip Clang Release') { + cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release") + stash includes: 'build/*.deb', name: 'migraphx-package' } - stage('GCC 5 Release') { - cmake_build("g++-5", "-DCMAKE_BUILD_TYPE=release") +}, mlir_debug: rocmnode('vega') { cmake_build -> + stage('MLIR Debug') { + def sanitizers = "undefined" + def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" + cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") } -}, gcc7: rocmnode('rocmtest') { cmake_build -> - stage('GCC 7 Debug') { - def linker_flags = '-fuse-ld=gold' - def cmake_linker_flags = "-DCMAKE_EXE_LINKER_FLAGS='${linker_flags}' -DCMAKE_SHARED_LINKER_FLAGS='${linker_flags}'" - // TODO: Add bounds-strict +}, clang_asan: rocmnode('nogpu') { cmake_build -> + stage('Clang ASAN') { def sanitizers = "undefined,address" - def debug_flags = "-g -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer -fsanitize-address-use-after-scope -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" - cmake_build("g++-7", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off ${cmake_linker_flags} -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") + def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" + cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") + } +}//, clang_release_navi: rocmnode('navi21') { cmake_build -> +// stage('HIP Clang Release Navi') { +// cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release") +// } +//} +def onnxnode(name, body) { + return { label -> + rocmtestnode(variant: label, node: rocmnodename(name), docker_args: '-u root', body: body, pre: { + sh 'rm -rf ./build/*.deb' + unstash 'migraphx-package' + }) } - stage('Codecov') { - env.CODECOV_TOKEN="8545af1c-f90b-4345-92a5-0d075503ca56" +} + +rocmtest onnx: onnxnode('rocmtest') { cmake_build -> + stage("Onnx runtime") { sh ''' - cd build - lcov --directory . --capture --output-file coverage.info - lcov --remove coverage.info '/usr/*' --output-file coverage.info - lcov --list coverage.info - curl -s https://codecov.io/bash | bash - echo "Uploaded" + apt install half + ls -lR + dpkg -i ./build/*.deb + cd /onnxruntime && ./build_and_test_onnxrt.sh ''' } } diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 8d5a3fa839744bbc3a65720d63432761df90ea68..560dea363665652dbc65f6f280d1f985d9e01e43 --- a/README.md +++ b/README.md @@ -1,82 +1,190 @@ # AMD MIGraphX -AMD's graph optimization engine. +AMD MIGraphX is AMD's graph inference engine that accelerates machine learning model inference. AMD MIGraphX can be used by +installing binaries directly or building from source code. + +In the following, instructions of how to build and install MIGraphX are described with Ubuntu as the OS +(Instructions of installation on other Linux OSes will come later). Note that all the following instructions assume +ROCm has been installed successfully. ROCm installation instructions are explained in the [ROCm installation +guide](https://rocmdocs.amd.com/en/latest/Installation_Guide/Installation-Guide.html). + +## Installing from binaries +With ROCm installed correctly, MIGraphX binaries can be installed on Ubuntu with the following command: +``` +sudo apt update && sudo apt install -y migraphx +``` +then the header files and libs are installed under `/opt/rocm-`, where `` is the ROCm version. + +## Building from source + +There are three ways to build the MIGraphX sources. +* [Use the ROCm build tool](#use-the-rocm-build-tool-rbuild) + + This approach uses [rbuild](https://github.com/RadeonOpenCompute/rbuild) to install the prerequisites and +build the libs with just one command. + +* [Use cmake](#use-cmake-to-build-migraphx) + + This approach uses a script to install the prerequisites, then use cmake to build the source. + +* [Use docker](#use-docker) + + This approach builds a docker image with all prerequisites installed, then build the MIGraphX sources inside a docker container. + +In the following, we will first list the prerequisites required to build MIGraphX source code, then describe +each of the three approaches. + +### List of prerequisites +The following is a list of prerequisites required to build MIGraphX source. -## Prerequisites * [ROCm cmake modules](https://github.com/RadeonOpenCompute/rocm-cmake) **required** * [MIOpen](https://github.com/ROCmSoftwarePlatform/MIOpen) for running on the GPU +* [rocBLAS](https://github.com/ROCmSoftwarePlatform/rocBLAS) for running on the GPU * [HIP](https://github.com/ROCm-Developer-Tools/HIP) for running on the GPU -* [Protobuf](https://github.com/google/protobuf) for reading [onxx](https://github.com/onnx/onnx) files +* [Protobuf](https://github.com/google/protobuf) for reading [onnx](https://github.com/onnx/onnx) files * [Half](http://half.sourceforge.net/) - IEEE 754-based half-precision floating point library * [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings +* [JSON](https://github.com/nlohmann/json) - for model serialization to json string format +* [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format + +#### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild). -## Installing the dependencies +In this approach, we use the [rbuild](https://github.com/RadeonOpenCompute/rbuild) build tool to +build MIGraphX. The specific steps are as follows: + +1) Install rocm-cmake, pip3, rocblas, and miopen-hip with the command + +``` +sudo apt update && sudo apt install -y rocm-cmake python3-pip rocblas miopen-hip +``` -Dependencies can be installed using the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild). +2) Install [rbuild](https://github.com/RadeonOpenCompute/rbuild) (sudo may be required here.) -To install rbuild: ``` -pip install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz +pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz ``` -To build dependencies along with MIGraphX +3) Build MIGraphX source code + ``` -rbuild build -d depend --cxx=/opt/rocm/bin/hcc +rbuild build -d depend -B build --cxx=/opt/rocm/llvm/bin/clang++ ``` -This builds dependencies in the subdirectory named depend and then builds MIGraphX using these dependencies. -## Building MIGraphX from source +then all the prerequisites are in the folder `depend`, and MIGraphX is built in the `build` directory. + +Note that for ROCm3.7 and later releases, Ubuntu 18.04 or later releases are needed. +Upgrade to Ubuntu 18.04 is available at [Upgrade Ubuntu to 18.04](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/wiki/Upgrade-to-Ubuntu-18.04-for-ROCM3.7-or-later-releases) + +Also note that you may meet the error of `rbuild: command not found`. It is because rbuild is installed +at `$HOME/.local/bin`, which is not in `PATH`. You can either export PATH as `export PATH=$HOME/.local/bin:$PATH` +to add the folder to `PATH` or add the option `--prefix /usr/local` in the pip3 command when installing rbuild. -## Configuring with cmake +#### Use cmake to build MIGraphX -First create a build directory: +If using this approach, we need to install the prerequisites, configure the cmake, and then build the source. +##### Installing the prerequisites + +For convenience, the prerequisites can be built automatically with rbuild as: ``` -mkdir build; -cd build; +rbuild build -d depend --cxx=/opt/rocm/llvm/bin/clang++ ``` -Next configure cmake. The hcc compiler is required to build the MIOpen backend: +then all the prerequisites are in the folder `depend`, and they can be used in the `cmake` configuration +as `-DCMAKE_PREFIX_PATH=depend`. + +If you have sudo access, as an alternative to the rbuild command, you can install the prerequisites just +like in the dockerfile by calling `./tools/install_prereqs.sh`. + +(Note that this script is for Ubuntu. By default, all prerequisites are installed at the default location `/usr/local` +and are accessible by all users. For the default location, `sudo` is required to run the script. +You can also specify a location at which the prerequisites are installed with `./tools/install_prereqs.sh $your_loc`.) + +##### Building MIGraphX source and install libs + +With the above prerequisites installed, we can build source as: + +1) Go to the project folder and create a `build` directory: ``` -CXX=/opt/rocm/bin/hcc cmake .. +mkdir build +cd build ``` -If the dependencies from `install_deps.cmake` was installed to another directory, the `CMAKE_PREFIX_PATH` needs to be set to what `--prefix` was set to from `install_deps.cmake`: +2) Configure the cmake. If the prerequisites are installed at the default location `/usr/local`, the command is: + +``` +CXX=/opt/rocm/llvm/bin/clang++ cmake .. +``` +Otherwise, you need to set `-DCMAKE_PREFIX_PATH=$your_loc` to configure the cmake. +3) Build MIGraphX source code ``` -CXX=/opt/rocm/bin/hcc cmake -DCMAKE_PREFIX_PATH=/some/dir .. +make -j$(nproc) ``` +Correctness can be verified as: -#### Changing the cmake configuration +``` +make -j$(nproc) check +``` -The configuration can be changed after running cmake by using `ccmake`: +MIGraphX libs can be installed as: -` ccmake .. ` **OR** `cmake-gui`: ` cmake-gui ..` +``` +make install +``` -## Building the library +#### Use docker -The library can be built, from the `build` directory using the 'Release' configuration: +The easiest way to setup the development environment is to use docker. With the dockerfile, you can build a docker image as: -` cmake --build . --config Release ` **OR** ` make ` + docker build -t migraphx . -And can be installed by using the 'install' target: +Then to enter the developement environment use `docker run`: -` cmake --build . --config Release --target install ` **OR** ` make install ` + docker run --device='/dev/kfd' --device='/dev/dri' -v=`pwd`:/code/AMDMIGraphX -w /code/AMDMIGraphX --group-add video -it migraphx -This will install the library to the `CMAKE_INSTALL_PREFIX` path that was set. +In the docker container, all the required prerequisites are already installed, so users can just go to the folder +`/code/AMDMIGraphX` and follow the steps in the above [Build MIGraphX source and install +libs](#building-migraphx-source-and-install-libs) +section to build MIGraphX source. -## Running the tests +### Using MIGraphX Python Module +To use MIGraphX's Python module, please either set `PYTHONPATH` or use `.deb` package as explained below: -The tests can be run by using the 'check' target: +- Setting `PYTHONPATH` : +``` +export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH +``` +- Creating and installing the package: + +To create deb package: +``` +make package +``` +This will provide the path of .deb package. -` cmake --build . --config Release --target check ` **OR** ` make check ` +To install: +``` +dpkg -i +``` -## Building the documentation + +### Calling MIGraphX APIs +To use MIGraphX's C/C++ API in your cmake project, we need to set `CMAKE_PREFIX_PATH` to the MIGraphX +installation location and then do +``` +find_package(migraphx) +target_link_libraries(myApp migraphx::c) +``` +Where `myApp` is the cmake target in your project. + + +### Building the documentation HTML and PDF documentation can be built using: @@ -97,7 +205,7 @@ Depending on your setup `sudo` may be required for the pip install. All the code is formatted using clang-format. To format a file, use: ``` -clang-format-5.0 -style=file -i +clang-format-10 -style=file -i ``` Also, githooks can be installed to format the code per-commit: @@ -105,13 +213,3 @@ Also, githooks can be installed to format the code per-commit: ``` ./.githooks/install ``` - -## Using docker - -The easiest way to setup the development environment is to use docker. You can build the top-level docker file: - - docker build -t migraphx . - -Then to enter the developement environment use `docker run`: - - docker run --device='/dev/kfd' --device='/dev/dri' -v=`pwd`:/data -w /data --group-add video -it migraphx diff --git a/cmake/CheckCXXLinkerFlag.cmake b/cmake/CheckCXXLinkerFlag.cmake new file mode 100644 index 0000000000000000000000000000000000000000..9efdcb8cd0f5f6bd2e5ede6dd538ad94ddefe29a --- /dev/null +++ b/cmake/CheckCXXLinkerFlag.cmake @@ -0,0 +1,34 @@ + +set(check_cxx_linker_flag_patterns + FAIL_REGEX "[Uu]nrecogni[sz]ed .*option" # GNU, NAG + FAIL_REGEX "switch .* is no longer supported" # GNU + FAIL_REGEX "unknown .*option" # Clang + FAIL_REGEX "optimization flag .* not supported" # Clang + FAIL_REGEX "unknown argument ignored" # Clang (cl) + FAIL_REGEX "ignoring unknown option" # MSVC, Intel + FAIL_REGEX "warning D9002" # MSVC, any lang + FAIL_REGEX "option.*not supported" # Intel + FAIL_REGEX "invalid argument .*option" # Intel + FAIL_REGEX "ignoring option .*argument required" # Intel + FAIL_REGEX "ignoring option .*argument is of wrong type" # Intel + FAIL_REGEX "[Uu]nknown option" # HP + FAIL_REGEX "[Ww]arning: [Oo]ption" # SunPro + FAIL_REGEX "command option .* is not recognized" # XL + FAIL_REGEX "command option .* contains an incorrect subargument" # XL + FAIL_REGEX "Option .* is not recognized. Option will be ignored." # XL + FAIL_REGEX "not supported in this configuration. ignored" # AIX + FAIL_REGEX "File with unknown suffix passed to linker" # PGI + FAIL_REGEX "[Uu]nknown switch" # PGI + FAIL_REGEX "WARNING: unknown flag:" # Open64 + FAIL_REGEX "Incorrect command line option:" # Borland + FAIL_REGEX "Warning: illegal option" # SunStudio 12 + FAIL_REGEX "[Ww]arning: Invalid suboption" # Fujitsu + FAIL_REGEX "An invalid option .* appears on the command line" # Cray + ) + +function(check_cxx_linker_flag _flag _var) + set (_source "int main() { return 0; }") + include (CheckCXXSourceCompiles) + check_cxx_source_compiles("${_source}" _result ${check_cxx_linker_flag_patterns}) + set(${_var} "${_result}" PARENT_SCOPE) +endfunction() diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake deleted file mode 100644 index 86bae05cb821b95133a62335ba4bec285efbb5f8..0000000000000000000000000000000000000000 --- a/cmake/DoxygenDoc.cmake +++ /dev/null @@ -1,359 +0,0 @@ -################################################################################ -# -# MIT License -# -# Copyright (c) 2017 Advanced Micro Devices, Inc. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# -################################################################################ -include(CMakeParseArguments) -include(MainDoc) - -find_program(DOXYGEN_EXECUTABLE NAMES doxygen - PATH_SUFFIXES bin - DOC "Doxygen documentation generator" -) -mark_as_advanced(DOXYGEN_EXECUTABLE) - -find_path(DOT_EXECUTABLE NAMES dot - PATH_SUFFIXES bin - DOC "Graphviz" -) -mark_as_advanced(DOT_EXECUTABLE) - -set(DOXYGEN_ARGS -ABBREVIATE_BRIEF -ALIASES -ALLEXTERNALS -ALLOW_UNICODE_NAMES -ALPHABETICAL_INDEX -ALWAYS_DETAILED_SEC -AUTOLINK_SUPPORT -BINARY_TOC -BRIEF_MEMBER_DESC -BUILTIN_STL_SUPPORT -CALLER_GRAPH -CALL_GRAPH -CASE_SENSE_NAMES -CHM_FILE -CHM_INDEX_ENCODING -CITE_BIB_FILES -CLANG_ASSISTED_PARSING -CLANG_OPTIONS -CLASS_DIAGRAMS -CLASS_GRAPH -COLLABORATION_GRAPH -COLS_IN_ALPHA_INDEX -COMPACT_LATEX -COMPACT_RTF -CPP_CLI_SUPPORT -CREATE_SUBDIRS -DIAFILE_DIRS -DIA_PATH -DIRECTORY_GRAPH -DISABLE_INDEX -DISTRIBUTE_GROUP_DOC -DOCBOOK_OUTPUT -DOCBOOK_PROGRAMLISTING -DOCSET_BUNDLE_ID -DOCSET_FEEDNAME -DOCSET_PUBLISHER_ID -DOCSET_PUBLISHER_NAME -DOTFILE_DIRS -DOT_CLEANUP -DOT_FONTNAME -DOT_FONTPATH -DOT_FONTSIZE -DOT_GRAPH_MAX_NODES -DOT_IMAGE_FORMAT -DOT_MULTI_TARGETS -DOT_NUM_THREADS -# DOT_PATH -DOT_TRANSPARENT -DOXYFILE_ENCODING -ECLIPSE_DOC_ID -ENABLED_SECTIONS -ENABLE_PREPROCESSING -ENUM_VALUES_PER_LINE -EXAMPLE_PATH -EXAMPLE_PATTERNS -EXAMPLE_RECURSIVE -EXCLUDE -EXCLUDE_PATTERNS -EXCLUDE_SYMBOLS -EXCLUDE_SYMLINKS -EXPAND_AS_DEFINED -EXPAND_ONLY_PREDEF -EXTENSION_MAPPING -EXTERNAL_GROUPS -EXTERNAL_PAGES -EXTERNAL_SEARCH -EXTERNAL_SEARCH_ID -EXTRACT_ALL -EXTRACT_ANON_NSPACES -EXTRACT_LOCAL_CLASSES -EXTRACT_LOCAL_METHODS -EXTRACT_PACKAGE -EXTRACT_PRIVATE -EXTRACT_STATIC -EXTRA_PACKAGES -EXTRA_SEARCH_MAPPINGS -EXT_LINKS_IN_WINDOW -FILE_PATTERNS -FILE_VERSION_FILTER -FILTER_PATTERNS -FILTER_SOURCE_FILES -FILTER_SOURCE_PATTERNS -FORCE_LOCAL_INCLUDES -FORMULA_FONTSIZE -FORMULA_TRANSPARENT -FULL_PATH_NAMES -GENERATE_AUTOGEN_DEF -GENERATE_BUGLIST -GENERATE_CHI -GENERATE_DEPRECATEDLIST -GENERATE_DOCBOOK -GENERATE_DOCSET -GENERATE_ECLIPSEHELP -GENERATE_HTML -GENERATE_HTMLHELP -GENERATE_LATEX -GENERATE_LEGEND -GENERATE_MAN -GENERATE_PERLMOD -GENERATE_QHP -GENERATE_RTF -GENERATE_TAGFILE -GENERATE_TESTLIST -GENERATE_TODOLIST -GENERATE_TREEVIEW -GENERATE_XML -GRAPHICAL_HIERARCHY -GROUP_GRAPHS -GROUP_NESTED_COMPOUNDS -# HAVE_DOT -HHC_LOCATION -HIDE_COMPOUND_REFERENCE -HIDE_FRIEND_COMPOUNDS -HIDE_IN_BODY_DOCS -HIDE_SCOPE_NAMES -HIDE_UNDOC_CLASSES -HIDE_UNDOC_MEMBERS -HIDE_UNDOC_RELATIONS -HTML_COLORSTYLE_GAMMA -HTML_COLORSTYLE_HUE -HTML_COLORSTYLE_SAT -HTML_DYNAMIC_SECTIONS -HTML_EXTRA_FILES -HTML_EXTRA_STYLESHEET -HTML_FILE_EXTENSION -HTML_FOOTER -HTML_HEADER -HTML_INDEX_NUM_ENTRIES -HTML_OUTPUT -HTML_STYLESHEET -HTML_TIMESTAMP -IDL_PROPERTY_SUPPORT -IGNORE_PREFIX -IMAGE_PATH -INCLUDED_BY_GRAPH -INCLUDE_FILE_PATTERNS -INCLUDE_GRAPH -INCLUDE_PATH -INHERIT_DOCS -INLINE_GROUPED_CLASSES -INLINE_INFO -INLINE_INHERITED_MEMB -INLINE_SIMPLE_STRUCTS -INLINE_SOURCES -INPUT -INPUT_ENCODING -INPUT_FILTER -INTERACTIVE_SVG -INTERNAL_DOCS -JAVADOC_AUTOBRIEF -LATEX_BATCHMODE -LATEX_BIB_STYLE -LATEX_CMD_NAME -LATEX_EXTRA_FILES -LATEX_EXTRA_STYLESHEET -LATEX_FOOTER -LATEX_HEADER -LATEX_HIDE_INDICES -LATEX_OUTPUT -LATEX_SOURCE_CODE -LATEX_TIMESTAMP -LAYOUT_FILE -LOOKUP_CACHE_SIZE -MACRO_EXPANSION -MAKEINDEX_CMD_NAME -MAN_EXTENSION -MAN_LINKS -MAN_OUTPUT -MAN_SUBDIR -MARKDOWN_SUPPORT -MATHJAX_CODEFILE -MATHJAX_EXTENSIONS -MATHJAX_FORMAT -MATHJAX_RELPATH -MAX_DOT_GRAPH_DEPTH -MAX_INITIALIZER_LINES -MSCFILE_DIRS -MSCGEN_PATH -MULTILINE_CPP_IS_BRIEF -OPTIMIZE_FOR_FORTRAN -OPTIMIZE_OUTPUT_FOR_C -OPTIMIZE_OUTPUT_JAVA -OPTIMIZE_OUTPUT_VHDL -OUTPUT_DIRECTORY -OUTPUT_LANGUAGE -PAPER_TYPE -PDF_HYPERLINKS -PERLMOD_LATEX -PERLMOD_MAKEVAR_PREFIX -PERLMOD_PRETTY -PERL_PATH -PLANTUML_CFG_FILE -PLANTUML_INCLUDE_PATH -PLANTUML_JAR_PATH -PREDEFINED -PROJECT_BRIEF -PROJECT_LOGO -PROJECT_NAME -PROJECT_NUMBER -QCH_FILE -QHG_LOCATION -QHP_CUST_FILTER_ATTRS -QHP_CUST_FILTER_NAME -QHP_NAMESPACE -QHP_SECT_FILTER_ATTRS -QHP_VIRTUAL_FOLDER -QT_AUTOBRIEF -QUIET -RECURSIVE -REFERENCED_BY_RELATION -REFERENCES_LINK_SOURCE -REFERENCES_RELATION -REPEAT_BRIEF -RTF_EXTENSIONS_FILE -RTF_HYPERLINKS -RTF_OUTPUT -RTF_SOURCE_CODE -RTF_STYLESHEET_FILE -SEARCHDATA_FILE -SEARCHENGINE -SEARCHENGINE_URL -SEARCH_INCLUDES -SEPARATE_MEMBER_PAGES -SERVER_BASED_SEARCH -SHORT_NAMES -SHOW_FILES -SHOW_GROUPED_MEMB_INC -SHOW_INCLUDE_FILES -SHOW_NAMESPACES -SHOW_USED_FILES -SIP_SUPPORT -SKIP_FUNCTION_MACROS -SORT_BRIEF_DOCS -SORT_BY_SCOPE_NAME -SORT_GROUP_NAMES -SORT_MEMBERS_CTORS_1ST -SORT_MEMBER_DOCS -SOURCE_BROWSER -SOURCE_TOOLTIPS -STRICT_PROTO_MATCHING -STRIP_CODE_COMMENTS -STRIP_FROM_INC_PATH -STRIP_FROM_PATH -SUBGROUPING -TAB_SIZE -TAGFILES -TCL_SUBST -TEMPLATE_RELATIONS -TOC_EXPAND -TOC_INCLUDE_HEADINGS -TREEVIEW_WIDTH -TYPEDEF_HIDES_STRUCT -UML_LIMIT_NUM_FIELDS -UML_LOOK -USE_HTAGS -USE_MATHJAX -USE_MDFILE_AS_MAINPAGE -USE_PDFLATEX -VERBATIM_HEADERS -WARNINGS -WARN_AS_ERROR -WARN_FORMAT -WARN_IF_DOC_ERROR -WARN_IF_UNDOCUMENTED -WARN_LOGFILE -WARN_NO_PARAMDOC -XML_OUTPUT -XML_PROGRAMLISTING -) - -set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") - -function(add_doxygen_doc) - set(options) - set(oneValueArgs) - set(multiValueArgs DEPENDS ${DOXYGEN_ARGS}) - - cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n") - - if(NOT PARSE_STRIP_FROM_PATH) - set(PARSE_STRIP_FROM_PATH ${CMAKE_SOURCE_DIR}) - endif() - - foreach(ARG ${DOXYGEN_ARGS}) - if(PARSE_${ARG}) - string(REPLACE ";" " " ARG_VALUE "${PARSE_${ARG}}") - file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n") - endif() - endforeach() - - if(PARSE_OUTPUT_DIRECTORY) - if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY}) - file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY}) - endif() - endif() - - if(DOT_EXECUTABLE) - file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n") - file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n") - else() - file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n") - endif() - - add_custom_target(doxygen - ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE} - WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} - COMMENT "Building documentation with doxygen" - ) - if(PARSE_OUTPUT_DIRECTORY) - clean_doc_output(${PARSE_OUTPUT_DIRECTORY}) - endif() - mark_as_doc(doxygen) - if(PARSE_DEPENDS) - add_dependencies(doxygen ${PARSE_DEPENDS}) - endif() -endfunction() diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake new file mode 100755 index 0000000000000000000000000000000000000000..f3b133dcbd81a5a8061ad56ab86817a86c2d339f --- /dev/null +++ b/cmake/Embed.cmake @@ -0,0 +1,101 @@ +find_program(EMBED_LD ld) +find_program(EMBED_OBJCOPY objcopy) + +function(generate_embed_source EMBED_NAME) + set(options) + set(oneValueArgs SRC HEADER) + set(multiValueArgs OBJECTS SYMBOLS) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(EXTERNS) + set(INIT_KERNELS) + + list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN) + list(LENGTH PARSE_OBJECTS OBJECTS_LEN) + if(NOT ${SYMBOLS_LEN} EQUAL ${OBJECTS_LEN}) + message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${OBJECTS_LEN}") + endif() + math(EXPR LEN "${SYMBOLS_LEN} - 1") + + foreach(idx RANGE ${LEN}) + list(GET PARSE_SYMBOLS ${idx} SYMBOL) + list(GET PARSE_OBJECTS ${idx} OBJECT) + set(START_SYMBOL "_binary_${SYMBOL}_start") + set(END_SYMBOL "_binary_${SYMBOL}_end") + string(APPEND EXTERNS " + extern const char ${START_SYMBOL}[]; + extern const char ${END_SYMBOL}[]; + ") + + # TODO: Should use NAME_WLE + get_filename_component(BASE_NAME "${OBJECT}" NAME) + string(REGEX REPLACE ".[A-Za-z0-9_]$" "" BASE_NAME ${BASE_NAME}) + + string(APPEND INIT_KERNELS " + { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} }, + ") + endforeach() + + file(WRITE "${PARSE_HEADER}" " +#include +#include +#include +const std::unordered_map>& ${EMBED_NAME}(); +") + + file(WRITE "${PARSE_SRC}" " +#include <${EMBED_NAME}.hpp> +${EXTERNS} +const std::unordered_map>& ${EMBED_NAME}() +{ + static const std::unordered_map> result = {${INIT_KERNELS}}; + return result; +} +") +endfunction() + +function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE) + set(WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + # Glob is used to compute the relative path + file(GLOB FILES RELATIVE ${WORKING_DIRECTORY} ${FILE}) + foreach(REL_FILE ${FILES}) + string(MAKE_C_IDENTIFIER "${REL_FILE}" SYMBOL) + get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") + set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") + set(${OUTPUT_SYMBOL} ${SYMBOL} PARENT_SCOPE) + set(${OUTPUT_FILE} "${OUT_FILE}" PARENT_SCOPE) + add_custom_command( + OUTPUT "${OUT_FILE}" + COMMAND ${EMBED_LD} -r -o "${OUT_FILE}" -z noexecstack --format=binary "${REL_FILE}" + COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUT_FILE}" + WORKING_DIRECTORY ${WORKING_DIRECTORY} + DEPENDS ${FILE} + VERBATIM + ) + endforeach() +endfunction() + +function(add_embed_library EMBED_NAME) + file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed) + file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) + set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) + set(SRC_FILE "${EMBED_DIR}/${EMBED_NAME}.cpp") + set(HEADER_FILE "${EMBED_DIR}/include/${EMBED_NAME}.hpp") + set(WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + set(OUTPUT_FILES) + set(SYMBOLS) + message(STATUS "Embedding files") + foreach(FILE ${ARGN}) + embed_file(OUTPUT_FILE OUTPUT_SYMBOL ${FILE}) + list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) + list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) + endforeach() + message(STATUS "Generating embedding library ${EMBED_NAME}") + generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS}) + add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}") + target_include_directories(${EMBED_NAME} PUBLIC "${EMBED_DIR}/include") + target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier) + set_target_properties(${EMBED_NAME} PROPERTIES POSITION_INDEPENDENT_CODE On) +endfunction() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake old mode 100644 new mode 100755 index 4d43cae0a33a41a8aea8fa6cc1c3def5e23c13a4..61c0145ac3cb8183d8a6b5ec08c41fc1c51af14f --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -96,18 +96,24 @@ else() -Wno-gnu-zero-variadic-macro-arguments -Wno-missing-prototypes -Wno-nested-anon-types + -Wno-option-ignored -Wno-padded -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unused-command-line-argument -Wno-weak-vtables + -Wno-c99-extensions + # -Wno-c++2a-designator ) else() list(APPEND CMAKE_COMPILER_WARNINGS -Wno-missing-field-initializers + -Wno-maybe-uninitialized # -Wno-deprecated-declarations ) endif() - add_definitions(${CMAKE_COMPILER_WARNINGS}) + foreach(COMPILER_WARNING ${CMAKE_COMPILER_WARNINGS}) + add_compile_options($<$:${COMPILER_WARNING}>) + endforeach() endforeach() endif () diff --git a/cmake/MainDoc.cmake b/cmake/MainDoc.cmake deleted file mode 100644 index 8a6cbe98b2b41c05cf6dcd21470835b1d042704d..0000000000000000000000000000000000000000 --- a/cmake/MainDoc.cmake +++ /dev/null @@ -1,37 +0,0 @@ -################################################################################ -# -# MIT License -# -# Copyright (c) 2017 Advanced Micro Devices, Inc. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# -################################################################################ - -if(NOT TARGET doc) - add_custom_target(doc) -endif() - -function(mark_as_doc) - add_dependencies(doc ${ARGN}) -endfunction() - -function(clean_doc_output DIR) - set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${DIR}) -endfunction() diff --git a/cmake/PythonModules.cmake b/cmake/PythonModules.cmake new file mode 100755 index 0000000000000000000000000000000000000000..2a64801e3890430002858dc9d0de80e915d74470 --- /dev/null +++ b/cmake/PythonModules.cmake @@ -0,0 +1,74 @@ +if(COMMAND find_python) + return() +endif() + +macro(py_exec) + execute_process(${ARGN} RESULT_VARIABLE RESULT) + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR "Process failed: ${ARGN}") + endif() +endmacro() + +set(PYBIND11_NOPYTHON On) +find_package(pybind11 REQUIRED) +macro(find_python version) + find_program(PYTHON_CONFIG_${version} python${version}-config) + if(EXISTS ${PYTHON_CONFIG_${version}}) + py_exec(COMMAND ${PYTHON_CONFIG_${version}} --includes OUTPUT_VARIABLE _python_include_args) + separate_arguments(_python_includes UNIX_COMMAND "${_python_include_args}") + string(REPLACE "-I" "" _python_includes "${_python_includes}") + add_library(python${version}::headers INTERFACE IMPORTED GLOBAL) + set_target_properties(python${version}::headers PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${_python_includes}" + ) + py_exec(COMMAND ${PYTHON_CONFIG_${version}} --prefix OUTPUT_VARIABLE _python_prefix) + string(STRIP "${_python_prefix}" _python_prefix) + set(PYTHON_${version}_EXECUTABLE "${_python_prefix}/bin/python${version}" CACHE PATH "") + endif() +endmacro() +function(py_extension name version) + set(_python_module_extension ".so") + if(version VERSION_GREATER_EQUAL 3.0) + py_exec(COMMAND ${PYTHON_CONFIG_${version}} --extension-suffix OUTPUT_VARIABLE _python_module_extension) + string(STRIP "${_python_module_extension}" _python_module_extension) + endif() + set_target_properties(${name} PROPERTIES PREFIX "" SUFFIX "${_python_module_extension}") +endfunction() +function(py_add_module NAME) + set(options) + set(oneValueArgs PYTHON_VERSION PYTHON_MODULE) + set(multiValueArgs) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(PYTHON_VERSION ${PARSE_PYTHON_VERSION}) + + add_library(${NAME} MODULE ${PARSE_UNPARSED_ARGUMENTS}) + pybind11_strip(${NAME}) + py_extension(${NAME} ${PYTHON_VERSION}) + target_link_libraries(${NAME} PRIVATE pybind11::module pybind11::lto python${PYTHON_VERSION}::headers) + set_target_properties(${NAME} PROPERTIES + OUTPUT_NAME ${PARSE_PYTHON_MODULE} + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + +endfunction() +set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9) +set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "") +foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS}) + list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION}) +endforeach() + +set(_PYTHON_VERSIONS) +foreach(PYTHON_VERSION ${PYTHON_SEARCH_VERSIONS}) + find_python(${PYTHON_VERSION}) + if(TARGET python${PYTHON_VERSION}::headers) + message(STATUS "Python ${PYTHON_VERSION} found.") + list(APPEND _PYTHON_VERSIONS ${PYTHON_VERSION}) + else() + message(STATUS "Python ${PYTHON_VERSION} not found.") + endif() +endforeach() + +# Make the variable global +set(PYTHON_VERSIONS "${_PYTHON_VERSIONS}" CACHE INTERNAL "" FORCE) diff --git a/cmake/RegisterOp.cmake b/cmake/RegisterOp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..1619b3c87885384adcc3dd831925ea47bec00cea --- /dev/null +++ b/cmake/RegisterOp.cmake @@ -0,0 +1,38 @@ + +function(register_op TARGET_NAME) + set(options) + set(oneValueArgs HEADER) + set(multiValueArgs OPERATORS INCLUDES) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + string(MAKE_C_IDENTIFIER "${PARSE_HEADER}" BASE_NAME) + file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ops) + set(FILE_NAME ${CMAKE_CURRENT_BINARY_DIR}/ops/${BASE_NAME}.cpp) + file(WRITE "${FILE_NAME}" "") + foreach(INCLUDE ${PARSE_INCLUDES}) + file(APPEND "${FILE_NAME}" " +#include <${INCLUDE}> +") + endforeach() + file(APPEND "${FILE_NAME}" " +#include +#include <${PARSE_HEADER}> +") + + + file(APPEND "${FILE_NAME}" " +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +") + foreach(OPERATOR ${PARSE_OPERATORS}) + file(APPEND "${FILE_NAME}" " +MIGRAPHX_REGISTER_OP(${OPERATOR}) +") + endforeach() + file(APPEND "${FILE_NAME}" " +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +") + target_sources(${TARGET_NAME} PRIVATE ${FILE_NAME}) +endfunction() diff --git a/cmake/SphinxDoc.cmake b/cmake/SphinxDoc.cmake deleted file mode 100644 index 11aa533b48545985cb79163d56d14dbd71992637..0000000000000000000000000000000000000000 --- a/cmake/SphinxDoc.cmake +++ /dev/null @@ -1,90 +0,0 @@ -################################################################################ -# -# MIT License -# -# Copyright (c) 2017 Advanced Micro Devices, Inc. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# -################################################################################ -include(CMakeParseArguments) -include(MainDoc) -include(DoxygenDoc) - -find_program(SPHINX_EXECUTABLE NAMES sphinx-build - HINTS - $ENV{SPHINX_DIR} - PATH_SUFFIXES bin - DOC "Sphinx documentation generator" -) - -mark_as_advanced(SPHINX_EXECUTABLE) - -set(BINARY_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/sphinx/_build") - -# Sphinx cache with pickled ReST documents -set(SPHINX_CACHE_DIR "${CMAKE_CURRENT_BINARY_DIR}/sphinx/_doctrees") - -# HTML output directory -set(SPHINX_DEFAULT_HTML_DIR "${CMAKE_CURRENT_BINARY_DIR}/sphinx/html") -function(add_sphinx_doc SRC_DIR) - set(options) - set(oneValueArgs BUILDER OUTPUT_DIR) - set(multiValueArgs DEPENDS VARS TEMPLATE_VARS) - - cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - string(TOUPPER ${PARSE_BUILDER} BUILDER) - - set(ADDITIONAL_ARGS) - foreach(VAR ${PARSE_VARS}) - list(APPEND ADDITIONAL_ARGS -D ${VAR}) - endforeach() - foreach(VAR ${PARSE_TEMPLATE_VARS}) - list(APPEND ADDITIONAL_ARGS -A ${VAR}) - endforeach() - - if(PARSE_OUTPUT_DIR) - get_filename_component(OUTPUT_DIR ${PARSE_OUTPUT_DIR} ABSOLUTE) - set(SPHINX_${BUILDER}_DIR ${OUTPUT_DIR} CACHE PATH "Path to ${PARSE_BUILDER} output") - else() - set(SPHINX_${BUILDER}_DIR "${CMAKE_CURRENT_BINARY_DIR}/sphinx/${PARSE_BUILDER}" CACHE PATH "Path to ${PARSE_BUILDER} output") - endif() - - add_custom_target(sphinx-${BUILDER} - ${SPHINX_EXECUTABLE} - -b ${PARSE_BUILDER} - -d "${SPHINX_CACHE_DIR}" - ${ADDITIONAL_ARGS} - "${SRC_DIR}" - "${SPHINX_${BUILDER}_DIR}" - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMENT "Building ${PARSE_BUILDER} documentation with Sphinx" - ) - clean_doc_output(${SPHINX_${BUILDER}_DIR}) - clean_doc_output(${SPHINX_CACHE_DIR}) - clean_doc_output(${BINARY_BUILD_DIR}) - mark_as_doc(sphinx-${BUILDER}) - if(PARSE_DEPENDS) - add_dependencies(sphinx-${BUILDER} ${PARSE_DEPENDS}) - endif() - -endfunction() - - diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake new file mode 100644 index 0000000000000000000000000000000000000000..ede876fb55f8d31e786445f7cda3cefb46f2c582 --- /dev/null +++ b/cmake/TargetFlags.cmake @@ -0,0 +1,85 @@ + +function(eval_and_strip_genex OUTPUT_VAR INPUT) + string(REPLACE "$" "1" INPUT "${INPUT}") + string(REPLACE "$" "1" INPUT "${INPUT}") + string(REPLACE "SHELL:" "" INPUT "${INPUT}") + string(REPLACE "$" "0" INPUT "${INPUT}") + string(REGEX REPLACE "\\$" "0" INPUT "${INPUT}") + string(REGEX REPLACE "\\$]*-NOTFOUND>" "0" INPUT "${INPUT}") + string(REGEX REPLACE "\\$]*>" "1" INPUT "${INPUT}") + string(REPLACE "$" "1" INPUT "${INPUT}") + string(REPLACE "$" "0" INPUT "${INPUT}") + string(REGEX REPLACE "\\$<0:[^<>]*>" "" INPUT "${INPUT}") + string(REGEX REPLACE "\\$<1:([^<>]*)>" "\\1" INPUT "${INPUT}") + string(GENEX_STRIP "${INPUT}" INPUT) + set(${OUTPUT_VAR} "${INPUT}" PARENT_SCOPE) +endfunction() + +function(get_target_property2 VAR TARGET PROPERTY) + get_target_property(_pflags ${TARGET} ${PROPERTY}) + if(_pflags) + eval_and_strip_genex(_pflags "${_pflags}") + set(${VAR} ${_pflags} PARENT_SCOPE) + else() + set(${VAR} "" PARENT_SCOPE) + endif() +endfunction() + +function(flags_requires_arg OUTPUT_VAR FLAG) + set(_args -x -isystem) + if(FLAG IN_LIST _args) + set(${OUTPUT_VAR} 1 PARENT_SCOPE) + else() + set(${OUTPUT_VAR} 0 PARENT_SCOPE) + endif() +endfunction() + +macro(append_flags FLAGS TARGET PROPERTY PREFIX) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + set(_requires_arg 0) + foreach(FLAG ${_pflags}) + string(STRIP "${FLAG}" FLAG) + if(FLAG) + if(TARGET ${FLAG} AND NOT _requires_arg) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + else() + string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") + endif() + flags_requires_arg(_requires_arg "${FLAG}") + endif() + endforeach() +endmacro() + +macro(append_link_flags FLAGS TARGET PROPERTY) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + set(_requires_arg 0) + foreach(FLAG ${_pflags}) + string(STRIP "${FLAG}" FLAG) + if(FLAG) + if(TARGET ${FLAG} AND NOT _requires_arg) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + elseif(FLAG MATCHES "^-.*") + string(APPEND ${FLAGS} " ${FLAG}") + elseif(EXISTS ${FLAG}) + string(APPEND ${FLAGS} " ${FLAG}") + else() + string(APPEND ${FLAGS} " -l${FLAG}") + endif() + flags_requires_arg(_requires_arg "${FLAG}") + endif() + endforeach() +endmacro() + +function(target_flags FLAGS TARGET) + set(_flags) + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "") + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D") + append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "") + append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "") + # message("_flags: ${_flags}") + set(${FLAGS} ${_flags} PARENT_SCOPE) +endfunction() diff --git a/codecov.yml b/codecov.yml index 6b19a9604e2022911c6d4d506d7411d8859794f0..03abe2daeb23ab4a17b76ece694a7f05aff21a5f 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,4 @@ ignore: - "test/" - + - "src/driver" + - "build/" diff --git a/cppcheck.rules b/cppcheck.rules old mode 100644 new mode 100755 index 1d5b4c607b7862ba097e5a8de48d0d48e23be7e6..fcccc58024701997cada53ba4eea4a4a8d8594d7 --- a/cppcheck.rules +++ b/cppcheck.rules @@ -1,5 +1,6 @@ + normal [;{}] [*] \w+? (\+\+|\-\-) ; UnusedDeref @@ -8,6 +9,7 @@ + normal if \( ([!] )*?(strlen) \( \w+? \) ([>] [0] )*?\) { StrlenEmptyString @@ -16,6 +18,7 @@ + normal [;{}] [*] \w+? (\+\+|\-\-) ; UnusedDeref @@ -42,6 +45,7 @@ + normal mutable \w+ MutableVariable @@ -50,6 +54,7 @@ + normal (memcpy|strcpy|strncpy|strcat|strncat) \( useStlAlgorithms @@ -58,6 +63,7 @@ + normal memset \( useStlAlgorithms @@ -66,6 +72,7 @@ + normal memcmp \( useStlAlgorithms @@ -74,6 +81,7 @@ + normal memchr \( useStlAlgorithms @@ -82,7 +90,8 @@ - \\W(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject) \( + normal + \\W(fclose|free|hipFree|hipHostFree|hipFreeArray|hipMemFree|hipStreamDestroy|hipEventDestroy|hipArrayDestroy|hipCtxDestroy|hipDestroyTextureObject|hipDestroySurfaceObject|miirDestroyHandle) \( useManagePointer style @@ -90,6 +99,33 @@ + normal + + + useSmartPointer + style + Use make_shared or make_unique instead of new + + + + - raw + normal \))]]> UseDeviceLaunch @@ -116,6 +152,24 @@ Else statement is not necessary. + + normal + |::) )*(?:\w+|>)(?: &|\*)*) (\w+) ; \2 = static_cast < \1 > (\([^()]*(?-1)*[^()]*\)) ;]]> + + RedundantCast + style + Static cast is redundant. + + + + normal + |::) )*(?:\w+|>)(?: &|\*)* > (\([^()]*(?-1)*[^()]*\)) ;]]> + + RedundantCast + style + Static cast is redundant. + + normal @@ -315,7 +369,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ \[ \1 \] ; }]]> useStlAlgorithm @@ -324,7 +378,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = \w+ ; }]]> useStlAlgorithm @@ -333,7 +387,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \) ; }]]> useStlAlgorithm @@ -342,7 +396,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] \) ; }]]> useStlAlgorithm @@ -351,7 +405,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) = \w+ ; \1 < \w+ ; (\1 \+\+|\+\+ \1|\1 \-\-|\-\- \1) \) { \w+ \[ \1 \] = (?:\w+ :: )*\w+ \( \w+ \[ \1 \] , \w+ \[ \1 \] \) ; }]]> useStlAlgorithm @@ -360,7 +414,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?\w+) \+\+|\+\+ (?\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { \w+ = \g{idx1}|\g{idx2} ; (?:break ; )?(?:return [^;]*; )?} }]]> useStlAlgorithm @@ -369,7 +423,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { \w+ = (?\w) ; (?:break ; )?(?:return [^;]*; )?} (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]> useStlAlgorithm @@ -378,7 +432,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { (?:(?\w+) \+\+|\+\+ (?\w+)) ; if (\([^()]*(?-1)*[^()]*\)) { return \g{idx1}|\g{idx2} ; } }]]> useStlAlgorithm @@ -387,7 +441,7 @@ - simple + normal |::) )*(?:\w+|>)(?: &|\*)* (\w+) : (?:[^()]*(\([^()]*(?-1)*[^()]*\)))*[^)]*\) { if (\([^()]*(?-1)*[^()]*\)) { return (?\w+) ; } (?:(\g{idx}) \+\+|\+\+ (\g{idx})) ; }]]> useStlAlgorithm diff --git a/dev-requirements.txt b/dev-requirements.txt old mode 100644 new mode 100755 index d75ba3a41875acac659af18afe676bbe89a65022..ee43bebe4d61af0e97b8a30d39d54c8b5e1bf970 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,7 @@ pfultz2/rocm-recipes -danmar/cppcheck@ef714225bb31e9a76ac2484796763572386955ae -DHAVE_RULES=1 -ROCm-Developer-Tools/HIP@2490e42baa7d90458f0632fd9fbead2d395f41b9 -python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 +facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake +ccache@v4.1 +pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 +danmar/cppcheck@2.8 -DHAVE_RULES=1 +RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build -f requirements.txt diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt old mode 100644 new mode 100755 index 4e5af96a73d29ccadaa6b654f1d356b33993a1d0..4622723a72baa3499a410428e38f813c35c8f493 --- a/doc/CMakeLists.txt +++ b/doc/CMakeLists.txt @@ -1,15 +1,24 @@ +project(migraphx-doc) +find_package(ROCM REQUIRED) -include(DoxygenDoc) +include(ROCMDoxygenDoc) -set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen/) -add_doxygen_doc( +set(DOXYGEN_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/doxygen) +rocm_add_doxygen_doc( OUTPUT_DIRECTORY ${DOXYGEN_OUTPUT} INPUT - ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_SOURCE_DIR}/src INCLUDE_PATH - ${CMAKE_CURRENT_SOURCE_DIR}/../src/include - ${CMAKE_CURRENT_SOURCE_DIR}/../src/targets/cpu/include - ${CMAKE_CURRENT_SOURCE_DIR}/../src/targets/gpu/include + ${CMAKE_SOURCE_DIR}/src/include + ${CMAKE_SOURCE_DIR}/src/targets/cpu/include + ${CMAKE_SOURCE_DIR}/src/targets/gpu/include + STRIP_FROM_INC_PATH + ${CMAKE_SOURCE_DIR}/src/include + ${CMAKE_SOURCE_DIR}/src/targets/cpu/include + ${CMAKE_SOURCE_DIR}/src/targets/gpu/include + EXCLUDE_PATTERNS + ${CMAKE_SOURCE_DIR}/src/targets/gpu/kernels + ${CMAKE_SOURCE_DIR}/src/targets/gpu/device SEARCH_INCLUDES YES MACRO_EXPANSION YES RECURSIVE YES @@ -29,26 +38,23 @@ add_doxygen_doc( EXTRACT_ALL YES ENUM_VALUES_PER_LINE 1 FULL_PATH_NAMES YES + WARN_LOGFILE "${DOXYGEN_OUTPUT}/DoxygenWarningLog.txt" PREDEFINED DOXYGEN ) -add_custom_target(remove_inline_ns - sed -i "s/MIGRAPHX_INLINE_NS:://g" *.xml - WORKING_DIRECTORY ${DOXYGEN_OUTPUT}/xml) -add_dependencies(remove_inline_ns doxygen) -include(SphinxDoc) -add_sphinx_doc(src +include(ROCMSphinxDoc) +rocm_add_sphinx_doc(src BUILDER html - OUTPUT_DIR html + OUTPUT_DIR html VARS breathe_projects.proj=${DOXYGEN_OUTPUT}/xml breathe_default_project=proj - DEPENDS doxygen remove_inline_ns + DEPENDS doxygen ) find_package(LATEX) if(LATEX_FOUND) - add_sphinx_doc(src + rocm_add_sphinx_doc(src BUILDER latex OUTPUT_DIR pdf VARS @@ -57,6 +63,6 @@ if(LATEX_FOUND) DEPENDS doxygen ) else() - message("Latex builder not found. Latex builder is required only for building the PDF documentation for MIGraph and is not necessary for building the library, or any other components. To build PDF documentation run make in ${CMAKE_CURRENT_SOURCE_DIR}/pdf, once a latex builder is installed.") + message("Latex builder not found. Latex builder is required only for building the PDF documentation for MIGraphX and is not necessary for building the library, or any other components. To build PDF documentation run make in ${CMAKE_CURRENT_SOURCE_DIR}/pdf, once a latex builder is installed.") endif() diff --git a/doc/requirements.txt b/doc/requirements.txt index ad8772f3f448e6b7bdfb843f598a9f15b34a613c..725376a06182dfe29bb27366f8938d6954793130 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,3 +1,5 @@ -sphinx==1.6.2 -breathe==4.9.1 +docutils==0.17.1 +sphinx==4.2.0 +breathe==4.31.0 +sphinx_rtd_theme==1.0.0 # git+https://github.com/arximboldi/breathe@fix-node-parent diff --git a/doc/src/conf.py b/doc/src/conf.py old mode 100644 new mode 100755 index 2d9b6a4315e216679be4f8a6f07db62738a5ef67..ec6a48d3eb0a3cf7d03cd2b2bf940c2a64c9c890 --- a/doc/src/conf.py +++ b/doc/src/conf.py @@ -18,6 +18,8 @@ # # import os # import sys +from datetime import date +import re # sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ @@ -29,7 +31,11 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode'] +extensions = [ + 'breathe', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx_rtd_theme', + 'sphinx.ext.autosectionlabel' +] +autosectionlabel_prefix_document = True # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -45,7 +51,7 @@ master_doc = 'index' # General information about the project. project = u'MIGraphX' -copyright = u'2018, AMD' +copyright = u'2018-{}, AMD'.format(date.today().year) author = u'AMD' # The version info for the project you're documenting, acts as replacement for @@ -53,9 +59,12 @@ author = u'AMD' # built documents. # # The short X.Y version. -version = u'0.1' +with open('../../CMakeLists.txt') as file: + version = next((re.findall('[0-9.]+', line)[0] + for line in file.readlines() + if 'rocm_setup_version' in line)) # The full version, including alpha/beta/rc tags. -release = u'0.1' +release = version # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -82,7 +91,7 @@ todo_include_todos = False # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the diff --git a/doc/src/contributor_guide.rst b/doc/src/contributor_guide.rst new file mode 100755 index 0000000000000000000000000000000000000000..c863b8800335ffc0f6860ab694980a4802a85c30 --- /dev/null +++ b/doc/src/contributor_guide.rst @@ -0,0 +1,16 @@ +Contributor Guide +=============== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + dev_intro + dev/data + dev/operators + dev/program + dev/targets + dev/quantization + dev/pass + dev/matchers + dev/tools diff --git a/doc/src/developer_guide.rst b/doc/src/cpp_user_guide.rst old mode 100644 new mode 100755 similarity index 52% rename from doc/src/developer_guide.rst rename to doc/src/cpp_user_guide.rst index 91b1dc36f39b636f57ce812e9f3ba3a31d5e2fc8..9e917f7dde9ee6ebae127a0e0a6f36ea3ffafae7 --- a/doc/src/developer_guide.rst +++ b/doc/src/cpp_user_guide.rst @@ -1,8 +1,8 @@ -Developer Guide -=============== +C++ User Guide +============== .. toctree:: :maxdepth: 2 :caption: Contents: - dev/matchers + reference/cpp diff --git a/doc/src/dev/data.rst b/doc/src/dev/data.rst new file mode 100755 index 0000000000000000000000000000000000000000..b0181d80a02739ae8b94f53bfb2a3e2c76116e2f --- /dev/null +++ b/doc/src/dev/data.rst @@ -0,0 +1,30 @@ +Data types +========== + +shape +----- + +.. doxygenstruct:: migraphx::internal::shape + +literal +------- + +.. doxygenstruct:: migraphx::internal::literal + +argument +-------- + +.. doxygenstruct:: migraphx::internal::argument + +raw_data +-------- + +.. doxygenstruct:: migraphx::internal::raw_data + +.. doxygenfunction:: migraphx::internal::visit_all + + +tensor_view +----------- + +.. doxygenstruct:: migraphx::internal::tensor_view diff --git a/doc/src/dev/operators.rst b/doc/src/dev/operators.rst new file mode 100755 index 0000000000000000000000000000000000000000..915eebef99068bffbced2e5e6350e9e10a9e9f09 --- /dev/null +++ b/doc/src/dev/operators.rst @@ -0,0 +1,16 @@ +Operators +========= + +operation +--------- + +.. doxygenstruct:: migraphx::internal::operation + +.. doxygenfunction:: migraphx::internal::is_context_free + +.. doxygenfunction:: migraphx::internal::has_finalize + +operators +--------- + +.. doxygennamespace:: migraphx::internal::op diff --git a/doc/src/dev/pass.rst b/doc/src/dev/pass.rst new file mode 100755 index 0000000000000000000000000000000000000000..7a6074c87d7b8c81d5764e8a374bdd152f17df10 --- /dev/null +++ b/doc/src/dev/pass.rst @@ -0,0 +1,67 @@ +Passes +====== + +pass +---- + +.. doxygenstruct:: migraphx::internal::pass + +dead_code_elimination +--------------------- + +.. doxygenstruct:: migraphx::internal::dead_code_elimination + +eliminate_common_subexpression +------------------------------ + +.. doxygenstruct:: migraphx::internal::eliminate_common_subexpression + +eliminate_concat +---------------- + +.. doxygenstruct:: migraphx::internal::eliminate_concat + +eliminate_contiguous +-------------------- + +.. doxygenstruct:: migraphx::internal::eliminate_contiguous + +eliminate_identity +------------------ + +.. doxygenstruct:: migraphx::internal::eliminate_identity + +eliminate_pad +------------- + +.. doxygenstruct:: migraphx::internal::eliminate_pad + +propagate_constant +------------------ + +.. doxygenstruct:: migraphx::internal::propagate_constant + +rewrite_batchnorm +----------------- + +.. doxygenstruct:: migraphx::internal::rewrite_batchnorm + +rewrite_rnn +----------- + +.. doxygenstruct:: migraphx::internal::rewrite_rnn + +schedule +-------- + +.. doxygenstruct:: migraphx::internal::schedule + +simplify_algebra +---------------- + +.. doxygenstruct:: migraphx::internal::simplify_algebra + +simplify_reshapes +----------------- + +.. doxygenstruct:: migraphx::internal::simplify_reshapes diff --git a/doc/src/dev/program.rst b/doc/src/dev/program.rst new file mode 100755 index 0000000000000000000000000000000000000000..ceefbe63c66c2590049ed33078b991649ae8fd6d --- /dev/null +++ b/doc/src/dev/program.rst @@ -0,0 +1,39 @@ +Program +======= + +instruction +----------- + +.. doxygenstruct:: migraphx::internal::instruction + +instruction_ref +--------------- + +.. cpp:type:: migraphx::internal::instruction_ref + + References an instruction in the program. + +program +------- + +.. doxygenstruct:: migraphx::internal::program + +parse_onnx +---------- + +.. doxygenfunction:: migraphx::internal::parse_onnx + +parse_tf +-------- + +.. doxygenfunction:: migraphx::internal::parse_tf + +onnx_options +------------ + +.. doxygenstruct:: migraphx::internal::onnx_options + +tf_options +---------- + +.. doxygenstruct:: migraphx::internal::tf_options diff --git a/doc/src/dev/quantization.rst b/doc/src/dev/quantization.rst new file mode 100755 index 0000000000000000000000000000000000000000..aecbd63188f30bbc4b5ec39d8676dea26a45f3e7 --- /dev/null +++ b/doc/src/dev/quantization.rst @@ -0,0 +1,13 @@ +Quantization +============ + +quantize_fp16 +------------- + +.. doxygenfunction:: migraphx::internal::quantize_fp16 + +quantize_int8 +------------- + +.. doxygenfunction:: migraphx::internal::quantize_int8 + diff --git a/doc/src/dev/roctx1.jpg b/doc/src/dev/roctx1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3fb210f6d8d47aa4cf13ed600dd94cc3d950502 Binary files /dev/null and b/doc/src/dev/roctx1.jpg differ diff --git a/doc/src/dev/roctx2.jpg b/doc/src/dev/roctx2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ace0b13ab0b1db87de0701328de5ea273b435e9 Binary files /dev/null and b/doc/src/dev/roctx2.jpg differ diff --git a/doc/src/dev/targets.rst b/doc/src/dev/targets.rst new file mode 100755 index 0000000000000000000000000000000000000000..323507c3186229c81730cb86c729209486e82b71 --- /dev/null +++ b/doc/src/dev/targets.rst @@ -0,0 +1,18 @@ +Targets +======= + +target +------ + +.. doxygenstruct:: migraphx::internal::target + +gpu::target +----------- + +.. doxygenstruct:: migraphx::internal::gpu::target + +cpu::target +----------- + +.. doxygenstruct:: migraphx::internal::cpu::target + diff --git a/doc/src/dev/tools.rst b/doc/src/dev/tools.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b328e14865d6db8aa5e18c36277a0c76e3213c8 --- /dev/null +++ b/doc/src/dev/tools.rst @@ -0,0 +1,65 @@ +Tools +===== + +roctx.py +-------- +MIGraphX driver provides `roctx` command which can be used with `rocprof` binary to get marker timing information for each MIGraphX operator. +In order to help user to process timing information, rocTX helper script is provided at `tools/roctx.py`. +The `roctx.py` helper script provides two main functionality: `run` and `parse`. Available knobs and usage are given below: + +:: + + Usage: roctx.py [-h] [--json-path json_path] [--out out] + [--study-name study-name] [--repeat repeat] [--parse] + [--run run] [--debug] + +.. option:: --run + +Runs `migraphx-driver roctx` command and given `migraphx-driver` knobs, and then parses the results, providing GPU kernel timing information. +MIGraphX knobs can be given via a string to `--run` knob. Please see the examples below. + +.. option:: --parse + +Given `--json-path`, parses JSON file and provides GPU kernel timing information. + +.. option:: --out + +Output folder + +.. option:: --study-name + +Optional. Allows user to name a study for easier interpretation. Defaults to timestamp. + +.. option:: --repeat + +Number of iterations. Set to **2** by default. + +.. option:: --debug + +Provides additional debug information related to data. Only use for debugging purposes. + +**Examples:** + +**Running inference with rocTX for a given ONNX file:** +:: + python roctx.py --run '--onnx --gpu fcn-resnet50-11.onnx' --out output_folder --repeat 5 + +After a run, similar to output given below is expected at terminal. The output will provide `SUM`, `MIN`, `MAX` and `COUNT` information for each kernel executed for a given model. +Average total time is also provided. There are three files provided for reference: + +1. `OUTPUT CSV FILE` provides a summary of the run, providing utilized MIGraphX knobs and related kernel timing information +2. `KERNEL TIMING DETAILS` provides the hotspot kernel timing information +3. This will provide all output data related to all iterations executed during a run. + +An example output: + +.. image:: ./roctx1.jpg + +Hotspot kerel timing information: + +.. image:: ./roctx2.jpg + +**Parsing an already existing JSON file:** +:: + + python roctx.py --parse --json-path ../trace.json \ No newline at end of file diff --git a/doc/src/dev_intro.rst b/doc/src/dev_intro.rst new file mode 100644 index 0000000000000000000000000000000000000000..2b78c303cee71e632284c444d019a87fd88aedca --- /dev/null +++ b/doc/src/dev_intro.rst @@ -0,0 +1,152 @@ +MIGraphX Fundamentals +====================== + +MIGraphX provides an optimized execution engine for deep learning neural networks. +We will cover some simple operations in the MIGraphX framework here. +For a quick start guide to using MIGraphX, look in the examples directory: ``https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/develop/examples/migraphx``. + + +Location of the Examples +------------------------- + +The ``ref_dev_examples.cpp`` can be found in the test directory (``/test``). +The executable file ``test_ref_dev_examples`` based on this file will be created in the ``bin/`` of the build directory after running ``make -j$(nproc) test_ref_dev_examples``. +The executable will also be created when running ``make -j$(nproc) check``, alongside with all the other tests. +Directions for building MIGraphX from source can be found in the main README file: ``https://github.com/ROCmSoftwarePlatform/AMDMIGraphX#readme``. + + +Adding Two Literals +-------------------- + +A program is a collection of modules, which are collections of instructions to be executed when calling `eval `. +Each instruction has an associated `operation ` which represents the computation to be performed by the instruction. + +We start with a snippet of the simple ``add_two_literals()`` function:: + + // create the program and get a pointer to the main module + migraphx::program p; + auto* mm = p.get_main_module(); + + // add two literals to the program + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + + // make the add operation between the two literals and add it to the program + mm->add_instruction(migraphx::make_op("add"), one, two); + + // compile the program on the reference device + p.compile(migraphx::ref::target{}); + + // evaulate the program and retreive the result + auto result = p.eval({}).back(); + std::cout << "add_two_literals: 1 + 2 = " << result << "\n"; + +We start by creating a simple ``migraphx::program`` object and then getting a pointer to the main module of it. +The program is a collection of ``modules`` that start executing from the main module, so instructions are added to the modules rather than directly onto the program object. +We then use the `add_literal ` function to add an instruction that stores the literal number ``1`` while returning an `instruction_ref `. +The returned `instruction_ref ` can be used in another instruction as an input. +We use the same `add_literal ` function to add a ``2`` to the program. +After creating the literals, we then create the instruction to add the numbers together. +This is done by using the `add_instruction ` function with the ``"add"`` `operation ` created by `make_op ` along with the previous `add_literal` `instruction_ref ` for the input arguments of the instruction. +Finally, we can run this `program ` by compiling it for the reference target (CPU) and then running it with `eval ` +The result is then retreived and printed to the console. + +We can compile the program for the GPU as well, but the file will have to be moved to the ``test/gpu/`` directory and the correct target must be included:: + + #include + + +Using Parameters +----------------- + +The previous program will always produce the same value of adding ``1`` and ``2``. +In the next program we want to pass an input to a program and compute a value based on the input. +We can modify the program to take an input parameter ``x``, as seen in the ``add_parameter()`` function:: + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {1}}; + + // add a "x" parameter with the shape s + auto x = mm->add_parameter("x", s); + auto two = mm->add_literal(2); + + // add the "add" instruction between the "x" parameter and "two" to the module + mm->add_instruction(migraphx::make_op("add"), x, two); + p.compile(migraphx::ref::target{}); + +This adds a parameter of type ``int32``, and compiles it for the CPU. +To run the program, we need to pass the parameter as a ``parameter_map`` when we call `eval `. +We create the ``parameter_map`` by setting the ``x`` key to an `argument ` object with an ``int`` data type:: + + // create a parameter_map object for passing a value to the "x" parameter + std::vector data = {4}; + migraphx::parameter_map params; + params["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(params).back(); + std::cout << "add_parameters: 4 + 2 = " << result << "\n"; + EXPECT(result.at() == 6); + + +Handling Tensor Data +--------------------- + +In the previous examples we have only been dealing with scalars, but the `shape ` class can describe multi-dimensional tensors. +For example, we can compute a simple convolution:: + + migraphx::program p; + auto* mm = p.get_main_module(); + + // create shape objects for the input tensor and weights + migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + migraphx::shape weights_shape{migraphx::shape::float_type, {3, 3, 3, 3}}; + + // create the parameters and add the "convolution" operation to the module + auto input = mm->add_parameter("X", input_shape); + auto weights = mm->add_parameter("W", weights_shape); + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), input, weights); + +Here we create two parameters for both the ``input`` and ``weights``. +In the previous examples, we created simple literals, however, most programs will take data from allocated buffers (usually on the GPU). +In this case, we can create `argument ` objects directly from the pointers to the buffers:: + + // Compile the program + p.compile(migraphx::ref::target{}); + + // Allocated buffers by the user + std::vector a = ...; + std::vector c = ...; + + // Solution vector + std::vector sol = ...; + + // Create the arguments in a parameter_map + migraphx::parameter_map params; + params["X"] = migraphx::argument(input_shape, a.data()); + params["W"] = migraphx::argument(weights_shape, c.data()); + + // Evaluate and confirm the result + auto result = p.eval(params).back(); + std::vector results_vector(64); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(results_vector, sol)); + +An `argument ` can handle memory buffers from either the GPU or the CPU. +By default when running the `program `, buffers are allocated on the corresponding target. +When compiling for the CPU, the buffers by default will be allocated on the CPU. +When compiling for the GPU, the buffers by default will be allocated on the GPU. +With the option ``offloaf_copy=true`` set while compiling for the GPU, the buffers will be located on the CPU. + + +Importing From ONNX +-------------------- + +A `program ` can be built directly from an onnx file using the MIGraphX ONNX parser. +This makes it easier to use neural networks directly from other frameworks. +In this case, there is an ``parse_onnx`` function:: + + program p = migraphx::parse_onnx("model.onnx"); + p.compile(migraphx::gpu::target{}); + diff --git a/doc/src/driver.rst b/doc/src/driver.rst new file mode 100755 index 0000000000000000000000000000000000000000..baa7047f124c3b120f94ee2208599c76992efa28 --- /dev/null +++ b/doc/src/driver.rst @@ -0,0 +1,81 @@ +MIGraphX Driver +=============== + +read +---- + +.. program:: migraphx-driver read + +Loads and prints input graph. + +.. include:: ./driver/read.rst + +compile +------- + +.. program:: migraphx-driver compile + +Compiles and prints input graph. + +.. include:: ./driver/compile.rst + +run +--- + +.. program:: migraphx-driver run + +Loads and prints input graph. + +.. include:: ./driver/compile.rst + +perf +---- + +.. program:: migraphx-driver perf + +Compiles and runs input graph then prints performance report. + +.. include:: ./driver/compile.rst + +.. option:: --iterations, -n [unsigned int] + +Number of iterations to run for perf report (Default: 100) + +verify +------ + +.. program:: migraphx-driver verify + +Runs reference and CPU or GPU implementations and checks outputs for consistency. + +.. include:: ./driver/compile.rst + +.. option:: --tolerance [double] + +Tolerance for errors (Default: 80) + +.. option:: -i, --per-instruction + +Verify each instruction + +.. option:: -r, --reduce + +Reduce program and verify + +roctx +---- + +.. program:: migraphx-driver roctx + +Provides marker information for each operation, allowing MIGraphX to be used with `rocprof `_ for performance analysis. +This allows user to get GPU-level kernel timing information. +An example command line combined with rocprof for tracing purposes is given below: + +.. code-block:: bash + + /opt/rocm/bin/rocprof --hip-trace --roctx-trace --flush-rate 1ms --timestamp on -d --obj-tracking on /opt/rocm/bin/migraphx-driver roctx + +After `rocprof` is run, the output directory will contain trace information for HIP, HCC and ROCTX in seperate `.txt` files. +To understand the interactions between API calls, it is recommended to utilize `roctx.py` helper script as desribed in :ref:`dev/tools:rocTX` section. + +.. include:: ./driver/compile.rst \ No newline at end of file diff --git a/doc/src/driver/compile.rst b/doc/src/driver/compile.rst new file mode 100755 index 0000000000000000000000000000000000000000..b31a91dbc7e7048d598dbcb95274ad7af0d81890 --- /dev/null +++ b/doc/src/driver/compile.rst @@ -0,0 +1,38 @@ +.. include:: ./driver/read.rst + +.. option:: --fill0 [std::vector] + +Fill parameter with 0s + +.. option:: --fill1 [std::vector] + +Fill parameter with 1s + +.. option:: --gpu + +Compile on the gpu + +.. option:: --cpu + +Compile on the cpu + +.. option:: --ref + +Compile on the reference implementation + +.. option:: --enable-offload-copy + +Enable implicit offload copying + +.. option:: --disable-fast-math + +Disable fast math optimization + +.. option:: --fp16 + +Quantize for fp16 + +.. option:: --int8 + +Quantize for int8 + diff --git a/doc/src/driver/read.rst b/doc/src/driver/read.rst new file mode 100755 index 0000000000000000000000000000000000000000..f0ab175d0efcbd07627ddb974f2d130f5ff49ee6 --- /dev/null +++ b/doc/src/driver/read.rst @@ -0,0 +1,80 @@ +.. option:: + +File to load + +.. option:: --model [resnet50|inceptionv3|alexnet] + +Load model + +.. option:: --onnx + +Load as onnx + +.. option:: --tf + +Load as tensorflow + +.. option:: --migraphx + +Load as MIGraphX + +.. option:: --migraphx-json + +Load as MIGraphX JSON + +.. option:: --batch [unsigned int] (Default: 1) + +Set batch size for model + +.. option:: --nhwc + +Treat tensorflow format as nhwc + +.. option:: --skip-unknown-operators + +Skip unknown operators when parsing and continue to parse. + +.. option:: --nchw + +Treat tensorflow format as nchw + +.. option:: --trim, -t [unsigned int] + +Trim instructions from the end (Default: 0) + +.. option:: --input-dim [std::vector] + +Dim of a parameter (format: "@name d1 d2 dn") + +.. option:: --optimize, -O + +Optimize when reading + +.. option:: --graphviz, -g + +Print out a graphviz representation. + +.. option:: --brief + +Make the output brief. + +.. option:: --cpp + +Print out the program as cpp program. + +.. option:: --json + +Print out program as json. + +.. option:: --text + +Print out program in text format. + +.. option:: --binary + +Print out program in binary format. + +.. option:: --output, -o [std::string] + +Output to file. + diff --git a/doc/src/index.rst b/doc/src/index.rst old mode 100644 new mode 100755 index dfa61c4303a92d1ba716e2b341ebbb6ce08a5f73..781e7bd157ac14b3ab3ba4fba3c2194be3efc100 --- a/doc/src/index.rst +++ b/doc/src/index.rst @@ -10,8 +10,10 @@ Welcome to AMD MIGraphX's documentation! :maxdepth: 3 :caption: Contents: - user_guide - developer_guide + py_user_guide + cpp_user_guide + driver + contributor_guide Indices and tables diff --git a/doc/src/overview.rst b/doc/src/overview.rst deleted file mode 100644 index f5a3f7c170823c21432bd3d87fb27a6a4e34b211..0000000000000000000000000000000000000000 --- a/doc/src/overview.rst +++ /dev/null @@ -1,89 +0,0 @@ -Overview -======== - -MIGraphX provides an optimized execution engine for deep learning neural networks. - -Building a program ------------------- - -A program consists of a set of instructions to be executed when calling `eval `. Each instruction has an associated `operation ` which represents the computation to be performed by the instruction. - -We can start by building a simple program to add two numbers together:: - - program p; - instruction_ref one = p.add_literal(1); - instruction_ref two = p.add_literal(2); - p.add_instruction(add{}, one, two); - -The `add_literal ` function will add an instruction to the program to store a literal number. The `instruction_ref ` is a reference to the instruction in the program, which can be used to compose the output of the instruction with another instruction. - -After creating the literals, we then create the instruction to add the numbers together. This is done by using the `add{} ` operation class along with the `instruction_ref ` for the input arguments of the instruction. - -Finally, we can run this `program ` by compiling it for the cpu and then running it with `eval `:: - - p.compile(cpu::target{}); - argument result = p.eval({}); - -The easiest way to see the result is to print it:: - - std::cout << result; - -Which will print ``3``. - -We can also compile the program for the gpu as well. - -Adding parameters ------------------ - -Of course, this program will always produce the same value which is quite uninteresting. Instead, we want to pass an input to a program and compute a value based on the input. This can be done with a parameter. For example, we can modify the program to take an input ``x``:: - - program p; - instruction_ref x = p.add_parameter("x", {shape::int64_type}); - instruction_ref two = p.add_literal(2); - p.add_instruction(add{}, x, two); - p.compile(cpu::target{}); - -This adds a parameter of type ``int64``, and compiles it for the ``cpu``. To run the program, we need to pass the parameter to it when we call `eval `:: - - argument result = p.eval({ - {"x", literal{1}.get_argument()} - }); - std::cout << result; - -This will print ``3``. - -A parameter is given as an `argument `. In this case, the simplest way of creating an `argument ` is from a `literal `. - -Tensor data ------------ - -In this example we are just creating numbers, but the `shape ` class can describe multi-dimensional tensors. For example, we can build a simple network with convolution and relu:: - - program p; - instruction_ref input = p.add_parameter("x", shape{shape::float_type, {1, 3, 32, 32}}); - instruction_ref weights = p.add_parameter("w", shape{shape::float_type, {1, 3, 5, 5}}); - instruction_ref conv = p.add_instruction(convolution{}, input, weights); - p.add_instruction(activation{"relu"}, conv); - -Here we create two parameters for both the ``input`` and ``weights``. In the previous examples, we just created simple literals, however, most programs will take data from already allocated buffers(usually on the GPU). In this case, we can create `argument ` objects directly from the pointers to the buffers:: - - // Compile the program - p.compile(gpu::target{}); - // Allocated buffers by the user - float* input = ...; - float* weights = ...; - // Create the arguments - argument input_arg{shape{shape::float_type, {1, 3, 32, 32}}, input}; - argument weights_arg{shape{shape::float_type, {1, 3, 32, 32}}, weights}; - p.eval({{"x", input_arg}, {"w", weights_arg}}) - -An `argument ` can handle memory buffers from either the GPU or the CPU, but when running the `program `, buffers should be allocated for the corresponding target. That is, when compiling for the CPU, the buffers should be allocated on the CPU, and when compiling for the GPU the buffers should be allocated on the GPU. - -Importing from onnx -------------------- - -A `program ` can be built directly from an onnx file, which makes it easier to use neural networks directly from other frameworks. In this case, there is an ``parse_onnx`` function:: - - program p = parse_onnx("model.onnx"); - p.compile(gpu::target{}); - diff --git a/doc/src/py_user_guide.rst b/doc/src/py_user_guide.rst new file mode 100644 index 0000000000000000000000000000000000000000..1f70005989a607d2d77dbd77f9327517d453fcb9 --- /dev/null +++ b/doc/src/py_user_guide.rst @@ -0,0 +1,8 @@ +Python User Guide +================= + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + reference/py diff --git a/doc/src/reference/cpp.rst b/doc/src/reference/cpp.rst new file mode 100755 index 0000000000000000000000000000000000000000..26e3d4158558d320f641287557bee4569cb9e55e --- /dev/null +++ b/doc/src/reference/cpp.rst @@ -0,0 +1,78 @@ + +C++ Reference +============= + +shape +----- + +.. doxygenenum:: migraphx_shape_datatype_t + +.. doxygenstruct:: migraphx::shape + +argument +-------- + +.. doxygenstruct:: migraphx::argument + +target +------ + +.. doxygenstruct:: migraphx::target + +program +------- + +.. doxygenstruct:: migraphx::program_parameter_shapes + +.. doxygenstruct:: migraphx::program_parameters + +.. doxygenstruct:: migraphx_compile_options + +.. doxygenstruct:: migraphx::program + +quantize +-------- + +.. doxygenstruct:: migraphx::quantize_op_names + +.. doxygenfunction:: migraphx::quantize_fp16(const program&) + +.. doxygenfunction:: migraphx::quantize_fp16(const program&, const quantize_op_names&) + +.. doxygenstruct:: migraphx::quantize_int8_options + +.. doxygenfunction:: migraphx::quantize_int8 + +parse_onnx +---------- + +.. doxygenstruct:: migraphx::onnx_options + +.. doxygenfunction:: migraphx::parse_onnx(const char *) + +.. doxygenfunction:: migraphx::parse_onnx(const char *, const migraphx::onnx_options&) + +.. doxygenfunction:: migraphx::parse_onnx_buffer(const std::string&) + +.. doxygenfunction:: migraphx::parse_onnx_buffer(const std::string&, const migraphx::onnx_options&) + +.. doxygenfunction:: migraphx::parse_onnx_buffer(const void *, size_t) + +.. doxygenfunction:: migraphx::parse_onnx_buffer(const void *, size_t, const migraphx::onnx_options&) + +load +---- + +.. doxygenstruct:: migraphx_file_options + +.. doxygenfunction:: migraphx::load(const char *) + +.. doxygenfunction:: migraphx::load(const char *, migraphx_file_options) + +save +---- + +.. doxygenfunction:: migraphx::save(const program&, const char *) + +.. doxygenfunction:: migraphx::save(const program&, const char *, migraphx_file_options) + diff --git a/doc/src/reference/data.rst b/doc/src/reference/data.rst deleted file mode 100644 index 341128b268105b85d6cd91b2b373045a55b61401..0000000000000000000000000000000000000000 --- a/doc/src/reference/data.rst +++ /dev/null @@ -1,30 +0,0 @@ -Data types -========== - -shape ------ - -.. doxygenstruct:: migraphx::shape - -literal -------- - -.. doxygenstruct:: migraphx::literal - -argument --------- - -.. doxygenstruct:: migraphx::argument - -raw_data --------- - -.. doxygenstruct:: migraphx::raw_data - -.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::visit_all - - -tensor_view ------------ - -.. doxygenstruct:: migraphx::tensor_view diff --git a/doc/src/reference/operators.rst b/doc/src/reference/operators.rst deleted file mode 100644 index 9e3794fcdd57e2c8a12f093ae850ae9425b44373..0000000000000000000000000000000000000000 --- a/doc/src/reference/operators.rst +++ /dev/null @@ -1,12 +0,0 @@ -Operators -========= - -operation ---------- - -.. doxygenstruct:: migraphx::operation - -operators ---------- - -.. doxygenfile:: operators.hpp diff --git a/doc/src/reference/pass.rst b/doc/src/reference/pass.rst deleted file mode 100644 index 93d8d386d2e8acdf8bc525b572d89b1c9223f2a8..0000000000000000000000000000000000000000 --- a/doc/src/reference/pass.rst +++ /dev/null @@ -1,47 +0,0 @@ -Passes -====== - -pass ----- - -.. doxygenstruct:: migraphx::pass - -dead_code_elimination ---------------------- - -.. doxygenstruct:: migraphx::dead_code_elimination - -common_subexpression_elimination --------------------------------- - -.. doxygenstruct:: migraphx::common_subexpression_elimination - -constant_propagate ------------------- - -.. doxygenstruct:: migraphx::constant_propagate - -eliminate_concat ----------------- - -.. doxygenstruct:: migraphx::eliminate_concat - -eliminate_contiguous --------------------- - -.. doxygenstruct:: migraphx::eliminate_contiguous - -fwd_conv_batchnorm_rewrite --------------------------- - -.. doxygenstruct:: migraphx::fwd_conv_batchnorm_rewrite - -simplify_algebra ----------------- - -.. doxygenstruct:: migraphx::simplify_algebra - -simplify_reshapes ------------------ - -.. doxygenstruct:: migraphx::simplify_reshapes diff --git a/doc/src/reference/program.rst b/doc/src/reference/program.rst deleted file mode 100644 index a5e4fc0a3c7d604751d687956ac9239af3095809..0000000000000000000000000000000000000000 --- a/doc/src/reference/program.rst +++ /dev/null @@ -1,24 +0,0 @@ -Program -======= - -instruction ------------ - -.. doxygenstruct:: migraphx::instruction - -instruction_ref ---------------- - -.. cpp:type:: migraphx::instruction_ref - - References an instruction in the program. - -program -------- - -.. doxygenstruct:: migraphx::program - -parse_onnx ----------- - -.. doxygenfunction:: migraphx::MIGRAPHX_INLINE_NS::parse_onnx diff --git a/doc/src/reference/py.rst b/doc/src/reference/py.rst new file mode 100755 index 0000000000000000000000000000000000000000..0956dff51d91f45e39e8180cfdc3614c08c0710c --- /dev/null +++ b/doc/src/reference/py.rst @@ -0,0 +1,322 @@ +.. py:module:: migraphx + +Python Reference +================ + +shape +----- + +.. py:class:: shape(type, lens, strides=None) + + Describes the shape of a tensor. This includes size, layout, and data type/ + +.. py:method:: type() + + An integer that represents the type. + + :rtype: int + +.. py:method:: lens() + + A list of the lengths of the shape. + + :rtype: list[int] + +.. py:method:: strides() + + A list of the strides of the shape. + + :rtype: list[int] + +.. py:method:: elements() + + The number of elements in the shape. + + :rtype: int + +.. py:method:: bytes() + + The number of bytes the shape uses. + + :rtype: int + +.. py:method:: type_size() + + The number of bytes one element uses + + :rtype: int + +.. py:method:: packed() + + Returns true if the shape is packed. + + :rtype: bool + +.. py:method:: transposed() + + Returns true if the shape is transposed. + + :rtype: bool + +.. py:method:: broadcasted() + + Returns true if the shape is broadcasted. + + :rtype: bool + +.. py:method:: standard() + + Returns true if the shape is a standard shape. That is, the shape is both packed and not transposed. + + :rtype: bool + +.. py:method:: scalar() + + Returns true if all strides are equal to 0 (scalar tensor). + + :rtype: bool + + +argument +-------- + +.. py:class:: argument(data) + + Construct an argument from a python buffer. This can include numpy arrays. + +.. py:method:: get_shape() + + Returns the shape of the argument. + + :rtype: shape + +.. py:method:: tolist() + + Convert the elements of the argument to a python list. + + :rtype: list + + +.. py:function:: generate_argument(s, seed=0) + + Generate an argument with random data. + + :param shape s: Shape of argument to generate. + :param int seed: The seed used for random number generation. + + :rtype: argument + +.. py:function:: fill_argument(s, value) + + Fill argument of shape s with value. + + :param shape s: Shape of argument to fill. + :param int value: Value to fill in the argument. + + :rtype argument + +target +------ + +.. py:class:: target() + + This represents the compilation target. + +.. py:function:: get_target(name) + + Constructs the target. + + :param str name: The name of the target to construct. This can either be 'gpu' or 'ref'. + + :rtype: target + + +module +------ +.. py:method:: print() + + Prints the contents of the module as list of instructions. + +.. py:method:: add_instruction(op, args, mod_args=[]) + + Adds instruction into the module. + + :param operation op: 'migraphx.op' to be added as instruction. + :param list[instruction] args: list of inputs to the op. + :param list[module] mod_args: optional list of module arguments to the operator. + :rtype instruction + +.. py:method:: add_literal(data) + + Adds constant or literal data of provided shape into the module from python buffer which includes numpy array. + + :param py::buffer data: Python buffer or numpy array + :rtype instruction + +.. py:method:: add_parameter(name, shape) + + Adds a parameter to the module with provided name and shape. + + :param str name: name of the parameter. + :param shape shape: shape of the parameter. + :rtype instruction + +.. py:method:: add_return(args) + + Adds a return instruction into the module. + + :param list[instruction] args: instruction arguments which need to be returned from the module. + :rtype instruction + + +program +------- + +.. py:class:: program() + + Represents the computation graph to be compiled and run. + +.. py:method:: clone() + + Make a copy of the program. + + :rtype: program + +.. py:method:: get_parameter_names() + + Get all the input arguments' or parameters' names to the program as a list. + + :rtype list[str] + +.. py:method:: get_parameter_shapes() + + Get the shapes of all the input parameters in the program. + + :rtype: dict[str, shape] + +.. py:method:: get_output_shapes() + + Get the shapes of the final outputs of the program. + + :rtype: list[shape] + +.. py:method:: compile(t, offload_copy=True, fast_math=True) + + Compiles the program for the target and optimizes it. + + :param target t: This is the target to compile the program for. + :param bool offload_copy: For targets with offloaded memory(such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. + :param bool fast_math: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. + +.. py:method:: get_main_module() + + Get main module of the program. + + :rtype module + +.. py:method:: create_module(name) + + Create and add a module of provided name into the program. + + :param str name : name of the new module. + :rtype module + +.. py:method:: run(params) + + Run the program. + + :param params: This is a map of the input parameters which will be used when running the program. + :type params: dict[str, argument] + + :return: The result of the last instruction. + :rtype: list[argument] + +.. py:method:: sort() + + Sort the modules of the program such that instructions appear in topologically sorted order. + +.. py:function:: quantize_fp16(prog, ins_names=["all"]) + + Quantize the program to use fp16. + + :param program prog: Program to quantize. + :param ins_names: List of instructions to quantize. + :type ins_names: list[str] + + +.. py:function:: quantize_int8(prog, t, calibration=[], ins_names=["dot", "convolution"]) + + Quantize the program to use int8. + + :param program prog: Program to quantize. + :param target t: Target that will be used to run the calibration data. + :param calibration: Calibration data used to decide the parameters to the int8 optimization. + :type calibration: list[dict[str, argument]] + :param ins_names: List of instructions to quantize. + :type ins_names: list[str] + + +op +-- +.. py::class:: op(name, kwargs) + + Construct an operation with name and arguments. + + :param str name : name of the operation, must be supported by MIGraphX. + :param dict[str, any] kwargs: arguments to the operation. + :rtype operation + + + +parse_onnx +---------- + +.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10) + + Load and parse an onnx file. + + :param str filename: Path to file. + :param str default_dim_value: default batch size to use (if not specified in onnx file). + :param str map_input_dims: Explicitly specify the dims of an input. + :param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found. + :param str print_program_on_error: Print program if an error occurs. + :param int max_loop_iterations: Maximum iteration number for the loop operator. + :rtype: program + +parse_tf +-------- + +.. py:function:: parse_tf(filename, is_nhwc=True, batch_size=1, map_input_dims=dict(), output_names=[]) + + Load and parse an tensorflow protobuf file file. + + :param str filename: Path to file. + :param bool is_nhwc: Use nhwc as default format. + :param str batch_size: default batch size to use (if not specified in protobuf). + :param dict[str, list[int]] map_input_dims: Optional arg to explictly specify dimensions of the inputs. + :param list[str] output_names: Optional argument specify names of the output nodes. + :rtype: program + +load +---- + +.. py:function:: load(filename, format='msgpack') + + Load a MIGraphX program. + + :param str filename: Path to file. + :param str format: Format of file. Valid options are msgpack or json. + + :rtype: program + +save +---- + +.. py:function:: save(p, filename, format='msgpack') + + Save a MIGraphX program. + + :param program p: Program to save. + :param str filename: Path to file. + :param str format: Format of file. Valid options are msgpack or json. + diff --git a/doc/src/reference/targets.rst b/doc/src/reference/targets.rst deleted file mode 100644 index 51aa549b48b4955168038d5bd3f75aa77a31ec2d..0000000000000000000000000000000000000000 --- a/doc/src/reference/targets.rst +++ /dev/null @@ -1,18 +0,0 @@ -Targets -======= - -target ------- - -.. doxygenstruct:: migraphx::target - -gpu::target ------------ - -.. doxygenstruct:: migraphx::gpu::target - -cpu::target ------------ - -.. doxygenstruct:: migraphx::cpu::target - diff --git a/doc/src/user_guide.rst b/doc/src/user_guide.rst deleted file mode 100644 index 16a343d9d6cfb045a93afbc6839830f4f04dfebc..0000000000000000000000000000000000000000 --- a/doc/src/user_guide.rst +++ /dev/null @@ -1,13 +0,0 @@ -User Guide -========== - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - overview - reference/data - reference/operators - reference/program - reference/targets - reference/pass diff --git a/examples/README.md b/examples/README.md new file mode 100755 index 0000000000000000000000000000000000000000..2bdd32c942a8f23b7af3eae636c0597eea2d454e --- /dev/null +++ b/examples/README.md @@ -0,0 +1,9 @@ +# AMD MIGraphX Examples + +## Description +This directory contains examples of common use cases for MIGraphX. + +## Examples: +- [MIGraphX usage and utilities](./migraphx) +- [Vision inference examples](./vision) +- [Natural language inference examples](./nlp) \ No newline at end of file diff --git a/examples/migraphx/README.md b/examples/migraphx/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf046ca71dbdf3be04ff18d48a056d5a60a542ff --- /dev/null +++ b/examples/migraphx/README.md @@ -0,0 +1,7 @@ +# AMD MIGraphX usage and utilities + +- [C++ Parse, Load, and Save Graph Programs](./cpp_parse_load_save) +- [Exporting Frozen Graphs in TF1](./export_frozen_graph_tf1) +- [Exporting Frozen Graphs in TF2](./export_frozen_graph_tf2) +- [MIGraphX Docker Container](./migraphx_docker) +- [MIGraphX Driver](./migraphx_driver) \ No newline at end of file diff --git a/examples/migraphx/cpp_parse_load_save/CMakeLists.txt b/examples/migraphx/cpp_parse_load_save/CMakeLists.txt new file mode 100755 index 0000000000000000000000000000000000000000..0e7a20a94f6c72eaafe7b11dcffacc5861428214 --- /dev/null +++ b/examples/migraphx/cpp_parse_load_save/CMakeLists.txt @@ -0,0 +1,13 @@ +cmake_minimum_required(VERSION 3.5) +project (PLS) + +set (CMAKE_CXX_STANDARD 14) +set (EXAMPLE parse_load_save) + +list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) +find_package (migraphx) + +message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE}) +add_executable(${EXAMPLE} ${EXAMPLE}.cpp) + +target_link_libraries(${EXAMPLE} migraphx::c) diff --git a/examples/migraphx/cpp_parse_load_save/README.md b/examples/migraphx/cpp_parse_load_save/README.md new file mode 100755 index 0000000000000000000000000000000000000000..184e7e2a23e26ee348d2e4b47dc2135106dd1e50 --- /dev/null +++ b/examples/migraphx/cpp_parse_load_save/README.md @@ -0,0 +1,76 @@ +# Parsing, Loading, and Saving MIGraphX Programs + +## Description +This examples demonstrates how to parse, load, and save a graph program using the MIGraphX C++ API. + +## Parsing +Computation graphs that have been saved in a compatible serialized format, such as [ONNX](https://onnx.ai/get-started.html), can be read in by MIGraphX to create a runable program. + +``` +migraphx::program p; +unsigned batch = 1; //Or read in as argument +migraphx::onnx_options options; +options.set_default_dim_value(batch); +p = parse_onnx(input_file, options); +``` + +## Saving +An instantiated migraphx::program object can then be serialized to MessagePack (.mxr) format and saved so that it can be loaded for future uses. + +A program can be saved with either of the following: +``` +migraphx::program p = ... ; +migraphx::save(p, output_file); +``` + +``` +migraphx::program p = ... ; +migraphx::file_options options; +options.set_file_format("msgpack"); +migraphx::save(p, output_file, options); +``` + +## Loading +Similarly, graphs that have been previously parsed, and possibly compiled, and then saved in either MessagePack or JSON format can be loaded at later time. + +MessagePack is the default format, and can be loaded with either: +``` +migraphx::program p; +p = migraphx::load(input_file); +``` + +``` +migraphx::program p; +migraphx::file_options options; +options.set_file_format("msgpack"); +p = migraphx::load(input_file, options); +``` +To load a program that has been saved in JSON format: +``` +migraphx::program p; +migraphx::file_options options; +options.set_file_format("json"); +p = migraphx::load(input_file, options); +``` + + +## Running the Example +The provided example [`parse_load_save.cpp`](./parse_load_save.cpp) has these features implemented to allow for comparing outputs. + +To compile and run the example from this directory: +``` +$ mkdir build +$ cd build +$ cmake .. +$ make +``` +There will now be an executable named `parse_load_save` with the following usage: +``` +$ ./parse_load_save [options] +options: + --parse onnx + --load json/msgpack + --save +``` + +The program will then attempt to parse or load the graph file, print out its internal graph structure if successful, and optionally save the program to a given file name. diff --git a/examples/migraphx/cpp_parse_load_save/parse_load_save.cpp b/examples/migraphx/cpp_parse_load_save/parse_load_save.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa93682f5f0fee876ee2c5910291b873cfb02839 --- /dev/null +++ b/examples/migraphx/cpp_parse_load_save/parse_load_save.cpp @@ -0,0 +1,105 @@ +#include +#include +#include +#include +#include + +// MIGraphX C++ API +#include + +char* getCmdOption(char**, char**, const std::string&); + +bool cmdOptionExists(char**, char**, const std::string&); + +int main(int argc, char** argv) +{ + if(argc < 2) + { + std::cout << "Usage: " << argv[0] << " " + << "[options]" << std::endl; + std::cout << "options:" << std::endl; + std::cout << "\t--parse onnx" << std::endl; + std::cout << "\t--load json/msgpack" << std::endl; + std::cout << "\t--save " << std::endl; + return 0; + } + + char* load_arg = getCmdOption(argv + 2, argv + argc, "--load"); + char* save_arg = getCmdOption(argv + 2, argv + argc, "--save"); + const char* input_file = argv[1]; + + migraphx::program p; + + if(cmdOptionExists(argv + 2, argv + argc, "--parse") || + !cmdOptionExists(argv + 2, argv + argc, "--load")) + { + std::cout << "Parsing ONNX File" << std::endl; + migraphx::onnx_options options; + p = parse_onnx(input_file, options); + } + else if(load_arg != nullptr) + { + std::cout << "Loading Graph File" << std::endl; + std::string format = load_arg; + if(format == "json") + { + migraphx::file_options options; + options.set_file_format("json"); + p = migraphx::load(input_file, options); + } + else if(format == "msgpack") + { + migraphx::file_options options; + options.set_file_format("msgpack"); + p = migraphx::load(input_file, options); + } + else + p = migraphx::load(input_file); + } + else + { + std::cout << "Error: Incorrect Usage" << std::endl; + std::cout << "Usage: " << argv[0] << " " + << "[options]" << std::endl; + std::cout << "options:" << std::endl; + std::cout << "\t--parse onnx" << std::endl; + std::cout << "\t--load json/msgpack" << std::endl; + std::cout << "\t--save " << std::endl; + return 0; + } + + std::cout << "Input Graph: " << std::endl; + p.print(); + std::cout << std::endl; + + if(cmdOptionExists(argv + 2, argv + argc, "--save")) + { + std::cout << "Saving program..." << std::endl; + std::string output_file; + output_file = save_arg == nullptr ? "out" : save_arg; + output_file.append(".mxr"); + + migraphx::file_options options; + options.set_file_format("msgpack"); + migraphx::save(p, output_file.c_str(), options); + std::cout << "Program has been saved as ./" << output_file << std::endl; + } + + return 0; +} + +char* getCmdOption(char** begin, char** end, const std::string& option) +{ + char** itr = std::find(begin, end, option); + if(itr != end && ++itr != end) + { + return *itr; + } + + return nullptr; +} + +bool cmdOptionExists(char** begin, char** end, const std::string& option) +{ + return std::find(begin, end, option) != end; +} diff --git a/examples/migraphx/export_frozen_graph_tf1/README.md b/examples/migraphx/export_frozen_graph_tf1/README.md new file mode 100755 index 0000000000000000000000000000000000000000..5d70d2a4824fe1375012503cb5e069d21534942c --- /dev/null +++ b/examples/migraphx/export_frozen_graph_tf1/README.md @@ -0,0 +1,144 @@ +# Exporting Frozen Graphs in Tensorflow 1 + +## Description +This example demonstrates how to export a frozen graph protobuf in Tensorflow 1.X that can be used as input to MIGraphX. Specifically, this is an example of exporting a frozen protobuf of a tensorflow BERT model. + +## How to Use this Example + + +In order to support bert from tensorflow's official [repository](https://github.com/google-research/bert), a serving_input_fn for the estimator must be implemented in [run_classifier.py](https://github.com/google-research/bert/blob/master/run_classifier.py). In this script, insert the following function after importing all libraries and setting up flags: + +``` +#... +flags.DEFINE_integer( + "num_tpu_cores", 8, + "Only used if `use_tpu` is True. Total number of TPU cores to use.") + +# insert function here +def serving_input_fn(): + label_ids = tf.placeholder(tf.int32, [None], name='label_ids') + input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids') + input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask') + segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids') + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + 'label_ids': label_ids, + 'input_ids': input_ids, + 'input_mask': input_mask, + 'segment_ids': segment_ids, + }, default_batch_size=1)() + return input_fn +``` + +Since we are passing dynamic shape placeholders in the serving_input_fn, the default_batch_size value will essentially determine the resulting shape in the graph. + +For inference, we will focus on the "probabilities" layer's output, and we can name this layer by modifying the following [line](https://github.com/google-research/bert/blob/master/run_classifier.py#L608): + +``` +probabilities = tf.nn.softmax(logits, axis=-1, name="output") + +``` + +Next, we need to export the saved model after training: +``` +def main(_): +# ... + with tf.gfile.GFile(output_predict_file, "w") as writer: + num_written_lines = 0 + tf.logging.info("***** Predict results *****") + for (i, prediction) in enumerate(result): + probabilities = prediction["probabilities"] + if i >= num_actual_predict_examples: + break + output_line = "\t".join( + str(class_probability) + for class_probability in probabilities) + "\n" + writer.write(output_line) + num_written_lines += 1 + assert num_written_lines == num_actual_predict_examples + +# insert code here + if FLAGS.do_train: # optional to attach export to train flag + estimator._export_to_tpu = False + estimator.export_savedmodel('saved_models', serving_input_fn) +# ... +``` + +Run bert with the suggested arguments: +``` +export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 +export GLUE_DIR=/path/to/glue + +python run_classifier.py \ + --task_name=MRPC \ + --do_train=true \ + --do_eval=true \ + --data_dir=$GLUE_DIR/MRPC \ + --vocab_file=$BERT_BASE_DIR/vocab.txt \ + --bert_config_file=$BERT_BASE_DIR/bert_config.json \ + --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ + --max_seq_length=128 \ + --train_batch_size=32 \ # change to appropriate size that fits on GPU + --learning_rate=2e-5 \ + --num_train_epochs=3.0 \ + --output_dir=/tmp/mrpc_output/ +``` + +When running, search for the following lines in the output: +``` +INFO:tensorflow:Restoring parameters from /tmp/model.ckpt-1603 +INFO:tensorflow:SavedModel written to: saved_models/temp-1564086017/saved_model.pb +``` + +Note the ID followed by "temp-" (in this case, 1564086017). A directory should exist under saved_models/ that is named with the ID. + +We also need to record the name of the output layer in bert. This can be done by inspecting the saved model. +``` +saved_model_cli show --dir saved_models/1564086017 --tag_set serve --signature_def serving_default +``` +The output should look like this: +``` +The given SavedModel SignatureDef contains the following input(s): + inputs['input_ids'] tensor_info: + dtype: DT_INT32 + shape: (1, 128) + name: input_ids_1:0 + inputs['input_mask'] tensor_info: + dtype: DT_INT32 + shape: (1, 128) + name: input_mask_1:0 + inputs['label_ids'] tensor_info: + dtype: DT_INT32 + shape: (1) + name: label_ids_1:0 + inputs['segment_ids'] tensor_info: + dtype: DT_INT32 + shape: (1, 128) + name: segment_ids_1:0 +The given SavedModel SignatureDef contains the following output(s): + outputs['probabilities'] tensor_info: + dtype: DT_FLOAT + shape: (1, 2) + name: loss/output:0 +Method name is: tensorflow/serving/predict + +``` +Here the output name is given as "loss/output:0", but we will strip the ":0" from the end, as we are concerned with the node only. + +We will use tensorflow's freeze graph utility script and the information gathered above to create the frozen protobuf file. + +``` +CKPT_NUM=1603 +MODEL_ID=1564086017 +OUT_NAME=loss/output + +cd /path/to/tensorflow +python tensorflow/python/tools/freeze_graph.py \ + --input_graph=/tmp/mrpc_model/graph.pbtxt \ + --input_binary=false \ + --input_checkpoint=/tmp/mrpc_model/model.ckpt-${CKPT_NUM} \ + --input_saved_model_dir=/path/to/bert/saved_models/${MODEL_ID} \ + --output_graph=/tmp/frozen_bert.pb \ + --output_node_names=${OUT_NAME} +``` + +The final output should be a frozen protobuf that is compatible with MIGraphX. \ No newline at end of file diff --git a/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..7fec51502cbc3200b3d0ffc6bbba1fe85e197f3d --- /dev/null +++ b/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/example-checkpoint.ipynb b/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/example-checkpoint.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..6e5d2bddfd5cf6aa6078f25fc6418f7fd7edfaa9 --- /dev/null +++ b/examples/migraphx/export_frozen_graph_tf2/.ipynb_checkpoints/example-checkpoint.ipynb @@ -0,0 +1,3054 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exporting Frozen Graphs in Tensorflow 2 \n", + "In order to use a trained model as input to MIGraphX, the model must be first be saved in a frozen graph format. This was accomplished in Tensorflow 1 by launching a graph in a tf.Session and then saving the session. However, Tensorflow has decided to deprecate Sessions in favor of functions and SavedModel format. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After importing the necessary libraries, the next step is to instantiate a model. For simplicity, in this example we will use a resnet50 architecture with pre-trained imagenet weights. These weights may also be trained or fine-tuned before freezing. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"resnet50\"\n", + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "input_1 (InputLayer) [(None, 224, 224, 3) 0 \n", + "__________________________________________________________________________________________________\n", + "conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv1_conv (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv1_bn (BatchNormalization) (None, 112, 112, 64) 256 conv1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv1_relu (Activation) (None, 112, 112, 64) 0 conv1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "pool1_pad (ZeroPadding2D) (None, 114, 114, 64) 0 conv1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "pool1_pool (MaxPooling2D) (None, 56, 56, 64) 0 pool1_pad[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_1_conv (Conv2D) (None, 56, 56, 64) 4160 pool1_pool[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_1_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block1_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_1_relu (Activation (None, 56, 56, 64) 0 conv2_block1_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_2_conv (Conv2D) (None, 56, 56, 64) 36928 conv2_block1_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_2_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block1_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_2_relu (Activation (None, 56, 56, 64) 0 conv2_block1_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_0_conv (Conv2D) (None, 56, 56, 256) 16640 pool1_pool[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_3_conv (Conv2D) (None, 56, 56, 256) 16640 conv2_block1_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_0_bn (BatchNormali (None, 56, 56, 256) 1024 conv2_block1_0_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_3_bn (BatchNormali (None, 56, 56, 256) 1024 conv2_block1_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_add (Add) (None, 56, 56, 256) 0 conv2_block1_0_bn[0][0] \n", + " conv2_block1_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block1_out (Activation) (None, 56, 56, 256) 0 conv2_block1_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_1_conv (Conv2D) (None, 56, 56, 64) 16448 conv2_block1_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_1_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block2_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_1_relu (Activation (None, 56, 56, 64) 0 conv2_block2_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_2_conv (Conv2D) (None, 56, 56, 64) 36928 conv2_block2_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_2_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block2_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_2_relu (Activation (None, 56, 56, 64) 0 conv2_block2_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_3_conv (Conv2D) (None, 56, 56, 256) 16640 conv2_block2_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_3_bn (BatchNormali (None, 56, 56, 256) 1024 conv2_block2_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_add (Add) (None, 56, 56, 256) 0 conv2_block1_out[0][0] \n", + " conv2_block2_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block2_out (Activation) (None, 56, 56, 256) 0 conv2_block2_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_1_conv (Conv2D) (None, 56, 56, 64) 16448 conv2_block2_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_1_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block3_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_1_relu (Activation (None, 56, 56, 64) 0 conv2_block3_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_2_conv (Conv2D) (None, 56, 56, 64) 36928 conv2_block3_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_2_bn (BatchNormali (None, 56, 56, 64) 256 conv2_block3_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_2_relu (Activation (None, 56, 56, 64) 0 conv2_block3_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_3_conv (Conv2D) (None, 56, 56, 256) 16640 conv2_block3_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_3_bn (BatchNormali (None, 56, 56, 256) 1024 conv2_block3_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_add (Add) (None, 56, 56, 256) 0 conv2_block2_out[0][0] \n", + " conv2_block3_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv2_block3_out (Activation) (None, 56, 56, 256) 0 conv2_block3_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_1_conv (Conv2D) (None, 28, 28, 128) 32896 conv2_block3_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_1_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block1_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_1_relu (Activation (None, 28, 28, 128) 0 conv3_block1_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_2_conv (Conv2D) (None, 28, 28, 128) 147584 conv3_block1_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_2_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block1_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_2_relu (Activation (None, 28, 28, 128) 0 conv3_block1_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_0_conv (Conv2D) (None, 28, 28, 512) 131584 conv2_block3_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_3_conv (Conv2D) (None, 28, 28, 512) 66048 conv3_block1_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_0_bn (BatchNormali (None, 28, 28, 512) 2048 conv3_block1_0_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_3_bn (BatchNormali (None, 28, 28, 512) 2048 conv3_block1_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_add (Add) (None, 28, 28, 512) 0 conv3_block1_0_bn[0][0] \n", + " conv3_block1_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block1_out (Activation) (None, 28, 28, 512) 0 conv3_block1_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_1_conv (Conv2D) (None, 28, 28, 128) 65664 conv3_block1_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_1_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block2_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_1_relu (Activation (None, 28, 28, 128) 0 conv3_block2_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_2_conv (Conv2D) (None, 28, 28, 128) 147584 conv3_block2_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_2_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block2_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_2_relu (Activation (None, 28, 28, 128) 0 conv3_block2_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_3_conv (Conv2D) (None, 28, 28, 512) 66048 conv3_block2_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_3_bn (BatchNormali (None, 28, 28, 512) 2048 conv3_block2_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_add (Add) (None, 28, 28, 512) 0 conv3_block1_out[0][0] \n", + " conv3_block2_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block2_out (Activation) (None, 28, 28, 512) 0 conv3_block2_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_1_conv (Conv2D) (None, 28, 28, 128) 65664 conv3_block2_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_1_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block3_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_1_relu (Activation (None, 28, 28, 128) 0 conv3_block3_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_2_conv (Conv2D) (None, 28, 28, 128) 147584 conv3_block3_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_2_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block3_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_2_relu (Activation (None, 28, 28, 128) 0 conv3_block3_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_3_conv (Conv2D) (None, 28, 28, 512) 66048 conv3_block3_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_3_bn (BatchNormali (None, 28, 28, 512) 2048 conv3_block3_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_add (Add) (None, 28, 28, 512) 0 conv3_block2_out[0][0] \n", + " conv3_block3_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block3_out (Activation) (None, 28, 28, 512) 0 conv3_block3_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_1_conv (Conv2D) (None, 28, 28, 128) 65664 conv3_block3_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_1_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block4_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_1_relu (Activation (None, 28, 28, 128) 0 conv3_block4_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_2_conv (Conv2D) (None, 28, 28, 128) 147584 conv3_block4_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_2_bn (BatchNormali (None, 28, 28, 128) 512 conv3_block4_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_2_relu (Activation (None, 28, 28, 128) 0 conv3_block4_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_3_conv (Conv2D) (None, 28, 28, 512) 66048 conv3_block4_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_3_bn (BatchNormali (None, 28, 28, 512) 2048 conv3_block4_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_add (Add) (None, 28, 28, 512) 0 conv3_block3_out[0][0] \n", + " conv3_block4_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv3_block4_out (Activation) (None, 28, 28, 512) 0 conv3_block4_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_1_conv (Conv2D) (None, 14, 14, 256) 131328 conv3_block4_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block1_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_1_relu (Activation (None, 14, 14, 256) 0 conv4_block1_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block1_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block1_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_2_relu (Activation (None, 14, 14, 256) 0 conv4_block1_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_0_conv (Conv2D) (None, 14, 14, 1024) 525312 conv3_block4_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block1_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_0_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block1_0_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block1_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_add (Add) (None, 14, 14, 1024) 0 conv4_block1_0_bn[0][0] \n", + " conv4_block1_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block1_out (Activation) (None, 14, 14, 1024) 0 conv4_block1_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_1_conv (Conv2D) (None, 14, 14, 256) 262400 conv4_block1_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block2_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_1_relu (Activation (None, 14, 14, 256) 0 conv4_block2_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block2_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block2_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_2_relu (Activation (None, 14, 14, 256) 0 conv4_block2_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block2_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block2_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_add (Add) (None, 14, 14, 1024) 0 conv4_block1_out[0][0] \n", + " conv4_block2_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block2_out (Activation) (None, 14, 14, 1024) 0 conv4_block2_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_1_conv (Conv2D) (None, 14, 14, 256) 262400 conv4_block2_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block3_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_1_relu (Activation (None, 14, 14, 256) 0 conv4_block3_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block3_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block3_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_2_relu (Activation (None, 14, 14, 256) 0 conv4_block3_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block3_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block3_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_add (Add) (None, 14, 14, 1024) 0 conv4_block2_out[0][0] \n", + " conv4_block3_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block3_out (Activation) (None, 14, 14, 1024) 0 conv4_block3_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_1_conv (Conv2D) (None, 14, 14, 256) 262400 conv4_block3_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block4_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_1_relu (Activation (None, 14, 14, 256) 0 conv4_block4_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block4_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block4_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_2_relu (Activation (None, 14, 14, 256) 0 conv4_block4_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block4_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block4_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_add (Add) (None, 14, 14, 1024) 0 conv4_block3_out[0][0] \n", + " conv4_block4_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block4_out (Activation) (None, 14, 14, 1024) 0 conv4_block4_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_1_conv (Conv2D) (None, 14, 14, 256) 262400 conv4_block4_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block5_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_1_relu (Activation (None, 14, 14, 256) 0 conv4_block5_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block5_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block5_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_2_relu (Activation (None, 14, 14, 256) 0 conv4_block5_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block5_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block5_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_add (Add) (None, 14, 14, 1024) 0 conv4_block4_out[0][0] \n", + " conv4_block5_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block5_out (Activation) (None, 14, 14, 1024) 0 conv4_block5_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_1_conv (Conv2D) (None, 14, 14, 256) 262400 conv4_block5_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_1_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block6_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_1_relu (Activation (None, 14, 14, 256) 0 conv4_block6_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_2_conv (Conv2D) (None, 14, 14, 256) 590080 conv4_block6_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_2_bn (BatchNormali (None, 14, 14, 256) 1024 conv4_block6_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_2_relu (Activation (None, 14, 14, 256) 0 conv4_block6_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_3_conv (Conv2D) (None, 14, 14, 1024) 263168 conv4_block6_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_3_bn (BatchNormali (None, 14, 14, 1024) 4096 conv4_block6_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_add (Add) (None, 14, 14, 1024) 0 conv4_block5_out[0][0] \n", + " conv4_block6_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv4_block6_out (Activation) (None, 14, 14, 1024) 0 conv4_block6_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_1_conv (Conv2D) (None, 7, 7, 512) 524800 conv4_block6_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_1_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block1_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_1_relu (Activation (None, 7, 7, 512) 0 conv5_block1_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_2_conv (Conv2D) (None, 7, 7, 512) 2359808 conv5_block1_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_2_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block1_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_2_relu (Activation (None, 7, 7, 512) 0 conv5_block1_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_0_conv (Conv2D) (None, 7, 7, 2048) 2099200 conv4_block6_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_3_conv (Conv2D) (None, 7, 7, 2048) 1050624 conv5_block1_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_0_bn (BatchNormali (None, 7, 7, 2048) 8192 conv5_block1_0_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_3_bn (BatchNormali (None, 7, 7, 2048) 8192 conv5_block1_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_add (Add) (None, 7, 7, 2048) 0 conv5_block1_0_bn[0][0] \n", + " conv5_block1_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block1_out (Activation) (None, 7, 7, 2048) 0 conv5_block1_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_1_conv (Conv2D) (None, 7, 7, 512) 1049088 conv5_block1_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_1_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block2_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_1_relu (Activation (None, 7, 7, 512) 0 conv5_block2_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_2_conv (Conv2D) (None, 7, 7, 512) 2359808 conv5_block2_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_2_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block2_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_2_relu (Activation (None, 7, 7, 512) 0 conv5_block2_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_3_conv (Conv2D) (None, 7, 7, 2048) 1050624 conv5_block2_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_3_bn (BatchNormali (None, 7, 7, 2048) 8192 conv5_block2_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_add (Add) (None, 7, 7, 2048) 0 conv5_block1_out[0][0] \n", + " conv5_block2_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block2_out (Activation) (None, 7, 7, 2048) 0 conv5_block2_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_1_conv (Conv2D) (None, 7, 7, 512) 1049088 conv5_block2_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_1_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block3_1_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_1_relu (Activation (None, 7, 7, 512) 0 conv5_block3_1_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_2_conv (Conv2D) (None, 7, 7, 512) 2359808 conv5_block3_1_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_2_bn (BatchNormali (None, 7, 7, 512) 2048 conv5_block3_2_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_2_relu (Activation (None, 7, 7, 512) 0 conv5_block3_2_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_3_conv (Conv2D) (None, 7, 7, 2048) 1050624 conv5_block3_2_relu[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_3_bn (BatchNormali (None, 7, 7, 2048) 8192 conv5_block3_3_conv[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_add (Add) (None, 7, 7, 2048) 0 conv5_block2_out[0][0] \n", + " conv5_block3_3_bn[0][0] \n", + "__________________________________________________________________________________________________\n", + "conv5_block3_out (Activation) (None, 7, 7, 2048) 0 conv5_block3_add[0][0] \n", + "__________________________________________________________________________________________________\n", + "avg_pool (GlobalAveragePooling2 (None, 2048) 0 conv5_block3_out[0][0] \n", + "__________________________________________________________________________________________________\n", + "probs (Dense) (None, 1000) 2049000 avg_pool[0][0] \n", + "==================================================================================================\n", + "Total params: 25,636,712\n", + "Trainable params: 25,583,592\n", + "Non-trainable params: 53,120\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution() #May not be required depending on tensorflow version\n", + "from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers\n", + "\n", + "MODEL_NAME = \"resnet50\"\n", + "model = tf.keras.applications.ResNet50(weights=\"imagenet\")\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SavedModel format\n", + "The simplest way to save a model is through saved\\_model.save()\n", + "\n", + "This will create an equivalent tensorflow program which can later be loaded for fine-tuning or inference, although it is not directly compatible with MIGraphX." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /home/amt/anaconda3/envs/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "If using Keras pass *_constraint arguments to layers.\n", + "INFO:tensorflow:Assets written to: ./Saved_Models/resnet50/assets\n" + ] + } + ], + "source": [ + "tf.saved_model.save(model, \"./Saved_Models/{}\".format(MODEL_NAME))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert to ConcreteFunction\n", + "To begin, we need to get the function equivalent of the model and then concretize the function to avoid retracing." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "full_model = tf.function(lambda x: model(x))\n", + "full_model = full_model.get_concrete_function(\n", + " x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Freeze ConcreteFunction and Serialize\n", + "Since we are saving the graph for the purpose of inference, all variables can be made constant (i.e. \"frozen\").\n", + "\n", + "Next, we need to obtain a serialized GraphDef representation of the graph. \n", + "\n", + "\n", + "Optionally, the operators can be printed out layer by layer followed by the inputs and outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------\n", + "Frozen model layers: \n", + "x\n", + "resnet50/conv1_pad/Pad/paddings\n", + "resnet50/conv1_pad/Pad\n", + "resnet50/conv1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv1_conv/Conv2D\n", + "resnet50/conv1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv1_conv/BiasAdd\n", + "resnet50/conv1_bn/ReadVariableOp/resource\n", + "resnet50/conv1_bn/ReadVariableOp\n", + "resnet50/conv1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv1_bn/ReadVariableOp_1\n", + "resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv1_bn/FusedBatchNormV3\n", + "resnet50/conv1_relu/Relu\n", + "resnet50/pool1_pad/Pad/paddings\n", + "resnet50/pool1_pad/Pad\n", + "resnet50/pool1_pool/MaxPool\n", + "resnet50/conv2_block1_0_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block1_0_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block1_0_conv/Conv2D\n", + "resnet50/conv2_block1_0_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block1_0_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block1_0_conv/BiasAdd\n", + "resnet50/conv2_block1_0_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block1_0_bn/ReadVariableOp\n", + "resnet50/conv2_block1_0_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_0_bn/ReadVariableOp_1\n", + "resnet50/conv2_block1_0_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block1_0_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block1_0_bn/FusedBatchNormV3\n", + "resnet50/conv2_block1_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block1_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block1_1_conv/Conv2D\n", + "resnet50/conv2_block1_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block1_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block1_1_conv/BiasAdd\n", + "resnet50/conv2_block1_1_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block1_1_bn/ReadVariableOp\n", + "resnet50/conv2_block1_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_1_bn/ReadVariableOp_1\n", + "resnet50/conv2_block1_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block1_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block1_1_bn/FusedBatchNormV3\n", + "resnet50/conv2_block1_1_relu/Relu\n", + "resnet50/conv2_block1_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block1_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block1_2_conv/Conv2D\n", + "resnet50/conv2_block1_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block1_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block1_2_conv/BiasAdd\n", + "resnet50/conv2_block1_2_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block1_2_bn/ReadVariableOp\n", + "resnet50/conv2_block1_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_2_bn/ReadVariableOp_1\n", + "resnet50/conv2_block1_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block1_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block1_2_bn/FusedBatchNormV3\n", + "resnet50/conv2_block1_2_relu/Relu\n", + "resnet50/conv2_block1_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block1_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block1_3_conv/Conv2D\n", + "resnet50/conv2_block1_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block1_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block1_3_conv/BiasAdd\n", + "resnet50/conv2_block1_3_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block1_3_bn/ReadVariableOp\n", + "resnet50/conv2_block1_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_3_bn/ReadVariableOp_1\n", + "resnet50/conv2_block1_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block1_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block1_3_bn/FusedBatchNormV3\n", + "resnet50/conv2_block1_add/add\n", + "resnet50/conv2_block1_out/Relu\n", + "resnet50/conv2_block2_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block2_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block2_1_conv/Conv2D\n", + "resnet50/conv2_block2_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block2_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block2_1_conv/BiasAdd\n", + "resnet50/conv2_block2_1_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block2_1_bn/ReadVariableOp\n", + "resnet50/conv2_block2_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_1_bn/ReadVariableOp_1\n", + "resnet50/conv2_block2_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block2_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block2_1_bn/FusedBatchNormV3\n", + "resnet50/conv2_block2_1_relu/Relu\n", + "resnet50/conv2_block2_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block2_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block2_2_conv/Conv2D\n", + "resnet50/conv2_block2_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block2_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block2_2_conv/BiasAdd\n", + "resnet50/conv2_block2_2_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block2_2_bn/ReadVariableOp\n", + "resnet50/conv2_block2_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_2_bn/ReadVariableOp_1\n", + "resnet50/conv2_block2_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block2_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block2_2_bn/FusedBatchNormV3\n", + "resnet50/conv2_block2_2_relu/Relu\n", + "resnet50/conv2_block2_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block2_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block2_3_conv/Conv2D\n", + "resnet50/conv2_block2_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block2_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block2_3_conv/BiasAdd\n", + "resnet50/conv2_block2_3_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block2_3_bn/ReadVariableOp\n", + "resnet50/conv2_block2_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_3_bn/ReadVariableOp_1\n", + "resnet50/conv2_block2_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block2_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block2_3_bn/FusedBatchNormV3\n", + "resnet50/conv2_block2_add/add\n", + "resnet50/conv2_block2_out/Relu\n", + "resnet50/conv2_block3_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block3_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block3_1_conv/Conv2D\n", + "resnet50/conv2_block3_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block3_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block3_1_conv/BiasAdd\n", + "resnet50/conv2_block3_1_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block3_1_bn/ReadVariableOp\n", + "resnet50/conv2_block3_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_1_bn/ReadVariableOp_1\n", + "resnet50/conv2_block3_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block3_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block3_1_bn/FusedBatchNormV3\n", + "resnet50/conv2_block3_1_relu/Relu\n", + "resnet50/conv2_block3_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block3_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block3_2_conv/Conv2D\n", + "resnet50/conv2_block3_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block3_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block3_2_conv/BiasAdd\n", + "resnet50/conv2_block3_2_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block3_2_bn/ReadVariableOp\n", + "resnet50/conv2_block3_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_2_bn/ReadVariableOp_1\n", + "resnet50/conv2_block3_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block3_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block3_2_bn/FusedBatchNormV3\n", + "resnet50/conv2_block3_2_relu/Relu\n", + "resnet50/conv2_block3_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv2_block3_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv2_block3_3_conv/Conv2D\n", + "resnet50/conv2_block3_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv2_block3_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv2_block3_3_conv/BiasAdd\n", + "resnet50/conv2_block3_3_bn/ReadVariableOp/resource\n", + "resnet50/conv2_block3_3_bn/ReadVariableOp\n", + "resnet50/conv2_block3_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_3_bn/ReadVariableOp_1\n", + "resnet50/conv2_block3_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv2_block3_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv2_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv2_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv2_block3_3_bn/FusedBatchNormV3\n", + "resnet50/conv2_block3_add/add\n", + "resnet50/conv2_block3_out/Relu\n", + "resnet50/conv3_block1_0_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block1_0_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block1_0_conv/Conv2D\n", + "resnet50/conv3_block1_0_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block1_0_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block1_0_conv/BiasAdd\n", + "resnet50/conv3_block1_0_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block1_0_bn/ReadVariableOp\n", + "resnet50/conv3_block1_0_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_0_bn/ReadVariableOp_1\n", + "resnet50/conv3_block1_0_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block1_0_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block1_0_bn/FusedBatchNormV3\n", + "resnet50/conv3_block1_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block1_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block1_1_conv/Conv2D\n", + "resnet50/conv3_block1_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block1_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block1_1_conv/BiasAdd\n", + "resnet50/conv3_block1_1_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block1_1_bn/ReadVariableOp\n", + "resnet50/conv3_block1_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_1_bn/ReadVariableOp_1\n", + "resnet50/conv3_block1_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block1_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block1_1_bn/FusedBatchNormV3\n", + "resnet50/conv3_block1_1_relu/Relu\n", + "resnet50/conv3_block1_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block1_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block1_2_conv/Conv2D\n", + "resnet50/conv3_block1_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block1_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block1_2_conv/BiasAdd\n", + "resnet50/conv3_block1_2_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block1_2_bn/ReadVariableOp\n", + "resnet50/conv3_block1_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_2_bn/ReadVariableOp_1\n", + "resnet50/conv3_block1_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block1_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block1_2_bn/FusedBatchNormV3\n", + "resnet50/conv3_block1_2_relu/Relu\n", + "resnet50/conv3_block1_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block1_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block1_3_conv/Conv2D\n", + "resnet50/conv3_block1_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block1_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block1_3_conv/BiasAdd\n", + "resnet50/conv3_block1_3_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block1_3_bn/ReadVariableOp\n", + "resnet50/conv3_block1_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_3_bn/ReadVariableOp_1\n", + "resnet50/conv3_block1_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block1_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block1_3_bn/FusedBatchNormV3\n", + "resnet50/conv3_block1_add/add\n", + "resnet50/conv3_block1_out/Relu\n", + "resnet50/conv3_block2_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block2_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block2_1_conv/Conv2D\n", + "resnet50/conv3_block2_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block2_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block2_1_conv/BiasAdd\n", + "resnet50/conv3_block2_1_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block2_1_bn/ReadVariableOp\n", + "resnet50/conv3_block2_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_1_bn/ReadVariableOp_1\n", + "resnet50/conv3_block2_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block2_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block2_1_bn/FusedBatchNormV3\n", + "resnet50/conv3_block2_1_relu/Relu\n", + "resnet50/conv3_block2_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block2_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block2_2_conv/Conv2D\n", + "resnet50/conv3_block2_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block2_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block2_2_conv/BiasAdd\n", + "resnet50/conv3_block2_2_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block2_2_bn/ReadVariableOp\n", + "resnet50/conv3_block2_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_2_bn/ReadVariableOp_1\n", + "resnet50/conv3_block2_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block2_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block2_2_bn/FusedBatchNormV3\n", + "resnet50/conv3_block2_2_relu/Relu\n", + "resnet50/conv3_block2_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block2_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block2_3_conv/Conv2D\n", + "resnet50/conv3_block2_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block2_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block2_3_conv/BiasAdd\n", + "resnet50/conv3_block2_3_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block2_3_bn/ReadVariableOp\n", + "resnet50/conv3_block2_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_3_bn/ReadVariableOp_1\n", + "resnet50/conv3_block2_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block2_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block2_3_bn/FusedBatchNormV3\n", + "resnet50/conv3_block2_add/add\n", + "resnet50/conv3_block2_out/Relu\n", + "resnet50/conv3_block3_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block3_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block3_1_conv/Conv2D\n", + "resnet50/conv3_block3_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block3_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block3_1_conv/BiasAdd\n", + "resnet50/conv3_block3_1_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block3_1_bn/ReadVariableOp\n", + "resnet50/conv3_block3_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_1_bn/ReadVariableOp_1\n", + "resnet50/conv3_block3_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block3_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block3_1_bn/FusedBatchNormV3\n", + "resnet50/conv3_block3_1_relu/Relu\n", + "resnet50/conv3_block3_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block3_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block3_2_conv/Conv2D\n", + "resnet50/conv3_block3_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block3_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block3_2_conv/BiasAdd\n", + "resnet50/conv3_block3_2_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block3_2_bn/ReadVariableOp\n", + "resnet50/conv3_block3_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_2_bn/ReadVariableOp_1\n", + "resnet50/conv3_block3_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block3_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block3_2_bn/FusedBatchNormV3\n", + "resnet50/conv3_block3_2_relu/Relu\n", + "resnet50/conv3_block3_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block3_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block3_3_conv/Conv2D\n", + "resnet50/conv3_block3_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block3_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block3_3_conv/BiasAdd\n", + "resnet50/conv3_block3_3_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block3_3_bn/ReadVariableOp\n", + "resnet50/conv3_block3_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_3_bn/ReadVariableOp_1\n", + "resnet50/conv3_block3_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block3_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block3_3_bn/FusedBatchNormV3\n", + "resnet50/conv3_block3_add/add\n", + "resnet50/conv3_block3_out/Relu\n", + "resnet50/conv3_block4_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block4_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block4_1_conv/Conv2D\n", + "resnet50/conv3_block4_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block4_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block4_1_conv/BiasAdd\n", + "resnet50/conv3_block4_1_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block4_1_bn/ReadVariableOp\n", + "resnet50/conv3_block4_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_1_bn/ReadVariableOp_1\n", + "resnet50/conv3_block4_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block4_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block4_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block4_1_bn/FusedBatchNormV3\n", + "resnet50/conv3_block4_1_relu/Relu\n", + "resnet50/conv3_block4_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block4_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block4_2_conv/Conv2D\n", + "resnet50/conv3_block4_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block4_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block4_2_conv/BiasAdd\n", + "resnet50/conv3_block4_2_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block4_2_bn/ReadVariableOp\n", + "resnet50/conv3_block4_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_2_bn/ReadVariableOp_1\n", + "resnet50/conv3_block4_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block4_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block4_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block4_2_bn/FusedBatchNormV3\n", + "resnet50/conv3_block4_2_relu/Relu\n", + "resnet50/conv3_block4_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv3_block4_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv3_block4_3_conv/Conv2D\n", + "resnet50/conv3_block4_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv3_block4_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv3_block4_3_conv/BiasAdd\n", + "resnet50/conv3_block4_3_bn/ReadVariableOp/resource\n", + "resnet50/conv3_block4_3_bn/ReadVariableOp\n", + "resnet50/conv3_block4_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_3_bn/ReadVariableOp_1\n", + "resnet50/conv3_block4_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv3_block4_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv3_block4_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv3_block4_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv3_block4_3_bn/FusedBatchNormV3\n", + "resnet50/conv3_block4_add/add\n", + "resnet50/conv3_block4_out/Relu\n", + "resnet50/conv4_block1_0_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block1_0_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block1_0_conv/Conv2D\n", + "resnet50/conv4_block1_0_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block1_0_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block1_0_conv/BiasAdd\n", + "resnet50/conv4_block1_0_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block1_0_bn/ReadVariableOp\n", + "resnet50/conv4_block1_0_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_0_bn/ReadVariableOp_1\n", + "resnet50/conv4_block1_0_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block1_0_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block1_0_bn/FusedBatchNormV3\n", + "resnet50/conv4_block1_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block1_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block1_1_conv/Conv2D\n", + "resnet50/conv4_block1_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block1_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block1_1_conv/BiasAdd\n", + "resnet50/conv4_block1_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block1_1_bn/ReadVariableOp\n", + "resnet50/conv4_block1_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block1_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block1_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block1_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block1_1_relu/Relu\n", + "resnet50/conv4_block1_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block1_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block1_2_conv/Conv2D\n", + "resnet50/conv4_block1_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block1_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block1_2_conv/BiasAdd\n", + "resnet50/conv4_block1_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block1_2_bn/ReadVariableOp\n", + "resnet50/conv4_block1_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block1_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block1_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block1_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block1_2_relu/Relu\n", + "resnet50/conv4_block1_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block1_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block1_3_conv/Conv2D\n", + "resnet50/conv4_block1_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block1_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block1_3_conv/BiasAdd\n", + "resnet50/conv4_block1_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block1_3_bn/ReadVariableOp\n", + "resnet50/conv4_block1_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block1_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block1_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block1_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block1_add/add\n", + "resnet50/conv4_block1_out/Relu\n", + "resnet50/conv4_block2_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block2_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block2_1_conv/Conv2D\n", + "resnet50/conv4_block2_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block2_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block2_1_conv/BiasAdd\n", + "resnet50/conv4_block2_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block2_1_bn/ReadVariableOp\n", + "resnet50/conv4_block2_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block2_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block2_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block2_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block2_1_relu/Relu\n", + "resnet50/conv4_block2_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block2_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block2_2_conv/Conv2D\n", + "resnet50/conv4_block2_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block2_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block2_2_conv/BiasAdd\n", + "resnet50/conv4_block2_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block2_2_bn/ReadVariableOp\n", + "resnet50/conv4_block2_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block2_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block2_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block2_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block2_2_relu/Relu\n", + "resnet50/conv4_block2_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block2_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block2_3_conv/Conv2D\n", + "resnet50/conv4_block2_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block2_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block2_3_conv/BiasAdd\n", + "resnet50/conv4_block2_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block2_3_bn/ReadVariableOp\n", + "resnet50/conv4_block2_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block2_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block2_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block2_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block2_add/add\n", + "resnet50/conv4_block2_out/Relu\n", + "resnet50/conv4_block3_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block3_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block3_1_conv/Conv2D\n", + "resnet50/conv4_block3_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block3_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block3_1_conv/BiasAdd\n", + "resnet50/conv4_block3_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block3_1_bn/ReadVariableOp\n", + "resnet50/conv4_block3_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block3_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block3_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block3_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block3_1_relu/Relu\n", + "resnet50/conv4_block3_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block3_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block3_2_conv/Conv2D\n", + "resnet50/conv4_block3_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block3_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block3_2_conv/BiasAdd\n", + "resnet50/conv4_block3_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block3_2_bn/ReadVariableOp\n", + "resnet50/conv4_block3_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block3_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block3_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block3_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block3_2_relu/Relu\n", + "resnet50/conv4_block3_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block3_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block3_3_conv/Conv2D\n", + "resnet50/conv4_block3_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block3_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block3_3_conv/BiasAdd\n", + "resnet50/conv4_block3_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block3_3_bn/ReadVariableOp\n", + "resnet50/conv4_block3_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block3_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block3_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block3_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block3_add/add\n", + "resnet50/conv4_block3_out/Relu\n", + "resnet50/conv4_block4_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block4_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block4_1_conv/Conv2D\n", + "resnet50/conv4_block4_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block4_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block4_1_conv/BiasAdd\n", + "resnet50/conv4_block4_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block4_1_bn/ReadVariableOp\n", + "resnet50/conv4_block4_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block4_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block4_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block4_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block4_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block4_1_relu/Relu\n", + "resnet50/conv4_block4_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block4_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block4_2_conv/Conv2D\n", + "resnet50/conv4_block4_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block4_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block4_2_conv/BiasAdd\n", + "resnet50/conv4_block4_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block4_2_bn/ReadVariableOp\n", + "resnet50/conv4_block4_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block4_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block4_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block4_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block4_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block4_2_relu/Relu\n", + "resnet50/conv4_block4_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block4_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block4_3_conv/Conv2D\n", + "resnet50/conv4_block4_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block4_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block4_3_conv/BiasAdd\n", + "resnet50/conv4_block4_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block4_3_bn/ReadVariableOp\n", + "resnet50/conv4_block4_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block4_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block4_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block4_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block4_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block4_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block4_add/add\n", + "resnet50/conv4_block4_out/Relu\n", + "resnet50/conv4_block5_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block5_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block5_1_conv/Conv2D\n", + "resnet50/conv4_block5_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block5_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block5_1_conv/BiasAdd\n", + "resnet50/conv4_block5_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block5_1_bn/ReadVariableOp\n", + "resnet50/conv4_block5_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block5_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block5_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block5_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block5_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block5_1_relu/Relu\n", + "resnet50/conv4_block5_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block5_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block5_2_conv/Conv2D\n", + "resnet50/conv4_block5_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block5_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block5_2_conv/BiasAdd\n", + "resnet50/conv4_block5_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block5_2_bn/ReadVariableOp\n", + "resnet50/conv4_block5_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block5_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block5_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block5_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block5_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block5_2_relu/Relu\n", + "resnet50/conv4_block5_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block5_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block5_3_conv/Conv2D\n", + "resnet50/conv4_block5_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block5_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block5_3_conv/BiasAdd\n", + "resnet50/conv4_block5_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block5_3_bn/ReadVariableOp\n", + "resnet50/conv4_block5_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block5_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block5_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block5_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block5_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block5_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block5_add/add\n", + "resnet50/conv4_block5_out/Relu\n", + "resnet50/conv4_block6_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block6_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block6_1_conv/Conv2D\n", + "resnet50/conv4_block6_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block6_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block6_1_conv/BiasAdd\n", + "resnet50/conv4_block6_1_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block6_1_bn/ReadVariableOp\n", + "resnet50/conv4_block6_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_1_bn/ReadVariableOp_1\n", + "resnet50/conv4_block6_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block6_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block6_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block6_1_bn/FusedBatchNormV3\n", + "resnet50/conv4_block6_1_relu/Relu\n", + "resnet50/conv4_block6_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block6_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block6_2_conv/Conv2D\n", + "resnet50/conv4_block6_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block6_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block6_2_conv/BiasAdd\n", + "resnet50/conv4_block6_2_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block6_2_bn/ReadVariableOp\n", + "resnet50/conv4_block6_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_2_bn/ReadVariableOp_1\n", + "resnet50/conv4_block6_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block6_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block6_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block6_2_bn/FusedBatchNormV3\n", + "resnet50/conv4_block6_2_relu/Relu\n", + "resnet50/conv4_block6_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv4_block6_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv4_block6_3_conv/Conv2D\n", + "resnet50/conv4_block6_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv4_block6_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv4_block6_3_conv/BiasAdd\n", + "resnet50/conv4_block6_3_bn/ReadVariableOp/resource\n", + "resnet50/conv4_block6_3_bn/ReadVariableOp\n", + "resnet50/conv4_block6_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_3_bn/ReadVariableOp_1\n", + "resnet50/conv4_block6_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv4_block6_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv4_block6_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv4_block6_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv4_block6_3_bn/FusedBatchNormV3\n", + "resnet50/conv4_block6_add/add\n", + "resnet50/conv4_block6_out/Relu\n", + "resnet50/conv5_block1_0_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block1_0_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block1_0_conv/Conv2D\n", + "resnet50/conv5_block1_0_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block1_0_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block1_0_conv/BiasAdd\n", + "resnet50/conv5_block1_0_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block1_0_bn/ReadVariableOp\n", + "resnet50/conv5_block1_0_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_0_bn/ReadVariableOp_1\n", + "resnet50/conv5_block1_0_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block1_0_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_0_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block1_0_bn/FusedBatchNormV3\n", + "resnet50/conv5_block1_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block1_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block1_1_conv/Conv2D\n", + "resnet50/conv5_block1_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block1_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block1_1_conv/BiasAdd\n", + "resnet50/conv5_block1_1_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block1_1_bn/ReadVariableOp\n", + "resnet50/conv5_block1_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_1_bn/ReadVariableOp_1\n", + "resnet50/conv5_block1_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block1_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block1_1_bn/FusedBatchNormV3\n", + "resnet50/conv5_block1_1_relu/Relu\n", + "resnet50/conv5_block1_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block1_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block1_2_conv/Conv2D\n", + "resnet50/conv5_block1_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block1_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block1_2_conv/BiasAdd\n", + "resnet50/conv5_block1_2_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block1_2_bn/ReadVariableOp\n", + "resnet50/conv5_block1_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_2_bn/ReadVariableOp_1\n", + "resnet50/conv5_block1_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block1_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block1_2_bn/FusedBatchNormV3\n", + "resnet50/conv5_block1_2_relu/Relu\n", + "resnet50/conv5_block1_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block1_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block1_3_conv/Conv2D\n", + "resnet50/conv5_block1_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block1_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block1_3_conv/BiasAdd\n", + "resnet50/conv5_block1_3_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block1_3_bn/ReadVariableOp\n", + "resnet50/conv5_block1_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_3_bn/ReadVariableOp_1\n", + "resnet50/conv5_block1_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block1_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block1_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block1_3_bn/FusedBatchNormV3\n", + "resnet50/conv5_block1_add/add\n", + "resnet50/conv5_block1_out/Relu\n", + "resnet50/conv5_block2_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block2_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block2_1_conv/Conv2D\n", + "resnet50/conv5_block2_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block2_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block2_1_conv/BiasAdd\n", + "resnet50/conv5_block2_1_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block2_1_bn/ReadVariableOp\n", + "resnet50/conv5_block2_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_1_bn/ReadVariableOp_1\n", + "resnet50/conv5_block2_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block2_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block2_1_bn/FusedBatchNormV3\n", + "resnet50/conv5_block2_1_relu/Relu\n", + "resnet50/conv5_block2_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block2_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block2_2_conv/Conv2D\n", + "resnet50/conv5_block2_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block2_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block2_2_conv/BiasAdd\n", + "resnet50/conv5_block2_2_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block2_2_bn/ReadVariableOp\n", + "resnet50/conv5_block2_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_2_bn/ReadVariableOp_1\n", + "resnet50/conv5_block2_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block2_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block2_2_bn/FusedBatchNormV3\n", + "resnet50/conv5_block2_2_relu/Relu\n", + "resnet50/conv5_block2_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block2_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block2_3_conv/Conv2D\n", + "resnet50/conv5_block2_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block2_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block2_3_conv/BiasAdd\n", + "resnet50/conv5_block2_3_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block2_3_bn/ReadVariableOp\n", + "resnet50/conv5_block2_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_3_bn/ReadVariableOp_1\n", + "resnet50/conv5_block2_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block2_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block2_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block2_3_bn/FusedBatchNormV3\n", + "resnet50/conv5_block2_add/add\n", + "resnet50/conv5_block2_out/Relu\n", + "resnet50/conv5_block3_1_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block3_1_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block3_1_conv/Conv2D\n", + "resnet50/conv5_block3_1_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block3_1_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block3_1_conv/BiasAdd\n", + "resnet50/conv5_block3_1_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block3_1_bn/ReadVariableOp\n", + "resnet50/conv5_block3_1_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_1_bn/ReadVariableOp_1\n", + "resnet50/conv5_block3_1_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block3_1_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_1_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block3_1_bn/FusedBatchNormV3\n", + "resnet50/conv5_block3_1_relu/Relu\n", + "resnet50/conv5_block3_2_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block3_2_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block3_2_conv/Conv2D\n", + "resnet50/conv5_block3_2_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block3_2_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block3_2_conv/BiasAdd\n", + "resnet50/conv5_block3_2_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block3_2_bn/ReadVariableOp\n", + "resnet50/conv5_block3_2_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_2_bn/ReadVariableOp_1\n", + "resnet50/conv5_block3_2_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block3_2_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_2_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block3_2_bn/FusedBatchNormV3\n", + "resnet50/conv5_block3_2_relu/Relu\n", + "resnet50/conv5_block3_3_conv/Conv2D/ReadVariableOp/resource\n", + "resnet50/conv5_block3_3_conv/Conv2D/ReadVariableOp\n", + "resnet50/conv5_block3_3_conv/Conv2D\n", + "resnet50/conv5_block3_3_conv/BiasAdd/ReadVariableOp/resource\n", + "resnet50/conv5_block3_3_conv/BiasAdd/ReadVariableOp\n", + "resnet50/conv5_block3_3_conv/BiasAdd\n", + "resnet50/conv5_block3_3_bn/ReadVariableOp/resource\n", + "resnet50/conv5_block3_3_bn/ReadVariableOp\n", + "resnet50/conv5_block3_3_bn/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_3_bn/ReadVariableOp_1\n", + "resnet50/conv5_block3_3_bn/FusedBatchNormV3/ReadVariableOp/resource\n", + "resnet50/conv5_block3_3_bn/FusedBatchNormV3/ReadVariableOp\n", + "resnet50/conv5_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1/resource\n", + "resnet50/conv5_block3_3_bn/FusedBatchNormV3/ReadVariableOp_1\n", + "resnet50/conv5_block3_3_bn/FusedBatchNormV3\n", + "resnet50/conv5_block3_add/add\n", + "resnet50/conv5_block3_out/Relu\n", + "resnet50/avg_pool/Mean/reduction_indices\n", + "resnet50/avg_pool/Mean\n", + "resnet50/probs/MatMul/ReadVariableOp/resource\n", + "resnet50/probs/MatMul/ReadVariableOp\n", + "resnet50/probs/MatMul\n", + "resnet50/probs/BiasAdd/ReadVariableOp/resource\n", + "resnet50/probs/BiasAdd/ReadVariableOp\n", + "resnet50/probs/BiasAdd\n", + "resnet50/probs/Softmax\n", + "Identity\n", + "--------------------------------------------------\n", + "Frozen model inputs: \n", + "[]\n", + "Frozen model outputs: \n", + "[]\n" + ] + } + ], + "source": [ + "frozen_func = convert_variables_to_constants_v2(full_model)\n", + "frozen_func.graph.as_graph_def()\n", + "\n", + "layers = [op.name for op in frozen_func.graph.get_operations()]\n", + "print(\"-\" * 50)\n", + "print(\"Frozen model layers: \")\n", + "for layer in layers:\n", + " print(layer)\n", + "\n", + "print(\"-\" * 50)\n", + "print(\"Frozen model inputs: \")\n", + "print(frozen_func.inputs)\n", + "print(\"Frozen model outputs: \")\n", + "print(frozen_func.outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Frozen Graph as Protobuf\n", + "Finally, we can save to hard drive, and now the frozen graph will be stored as `./frozen_models/_frozen_graph.pb`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'./frozen_models/resnet50_frozen_graph.pb'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tf.io.write_graph(graph_or_graph_def=frozen_func.graph,\n", + " logdir=\"./frozen_models\",\n", + " name=\"{}_frozen_graph.pb\".format(MODEL_NAME),\n", + " as_text=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Assuming MIGraphX has already been built and installed on your system, the driver can be used to verify that the frozen graph has been correctly exported. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading: ./frozen_models/resnet50_frozen_graph.pb\n", + "@0 = @literal{ ... } -> float_type, {1000}, {1}\n", + "@1 = @literal{ ... } -> float_type, {2048, 1000}, {1000, 1}\n", + "@2 = @literal{1, 2} -> int32_type, {2}, {1}\n", + "@3 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@4 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@5 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@6 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@7 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@8 = @literal{ ... } -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@9 = @literal{ ... } -> float_type, {512}, {1}\n", + "@10 = @literal{ ... } -> float_type, {512}, {1}\n", + "@11 = @literal{ ... } -> float_type, {512}, {1}\n", + "@12 = @literal{ ... } -> float_type, {512}, {1}\n", + "@13 = @literal{ ... } -> float_type, {512}, {1}\n", + "@14 = @literal{ ... } -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@15 = @literal{ ... } -> float_type, {512}, {1}\n", + "@16 = @literal{ ... } -> float_type, {512}, {1}\n", + "@17 = @literal{ ... } -> float_type, {512}, {1}\n", + "@18 = @literal{ ... } -> float_type, {512}, {1}\n", + "@19 = @literal{ ... } -> float_type, {512}, {1}\n", + "@20 = @literal{ ... } -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@21 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@22 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@23 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@24 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@25 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@26 = @literal{ ... } -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@27 = @literal{ ... } -> float_type, {512}, {1}\n", + "@28 = @literal{ ... } -> float_type, {512}, {1}\n", + "@29 = @literal{ ... } -> float_type, {512}, {1}\n", + "@30 = @literal{ ... } -> float_type, {512}, {1}\n", + "@31 = @literal{ ... } -> float_type, {512}, {1}\n", + "@32 = @literal{ ... } -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@33 = @literal{ ... } -> float_type, {512}, {1}\n", + "@34 = @literal{ ... } -> float_type, {512}, {1}\n", + "@35 = @literal{ ... } -> float_type, {512}, {1}\n", + "@36 = @literal{ ... } -> float_type, {512}, {1}\n", + "@37 = @literal{ ... } -> float_type, {512}, {1}\n", + "@38 = @literal{ ... } -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@39 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@40 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@41 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@42 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@43 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@44 = @literal{ ... } -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@45 = @literal{ ... } -> float_type, {512}, {1}\n", + "@46 = @literal{ ... } -> float_type, {512}, {1}\n", + "@47 = @literal{ ... } -> float_type, {512}, {1}\n", + "@48 = @literal{ ... } -> float_type, {512}, {1}\n", + "@49 = @literal{ ... } -> float_type, {512}, {1}\n", + "@50 = @literal{ ... } -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@51 = @literal{ ... } -> float_type, {512}, {1}\n", + "@52 = @literal{ ... } -> float_type, {512}, {1}\n", + "@53 = @literal{ ... } -> float_type, {512}, {1}\n", + "@54 = @literal{ ... } -> float_type, {512}, {1}\n", + "@55 = @literal{ ... } -> float_type, {512}, {1}\n", + "@56 = @literal{ ... } -> float_type, {1, 1, 1024, 512}, {524288, 524288, 512, 1}\n", + "@57 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@58 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@59 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@60 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@61 = @literal{ ... } -> float_type, {2048}, {1}\n", + "@62 = @literal{ ... } -> float_type, {1, 1, 1024, 2048}, {2097152, 2097152, 2048, 1}\n", + "@63 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@64 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@65 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@66 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@67 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@68 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@69 = @literal{ ... } -> float_type, {256}, {1}\n", + "@70 = @literal{ ... } -> float_type, {256}, {1}\n", + "@71 = @literal{ ... } -> float_type, {256}, {1}\n", + "@72 = @literal{ ... } -> float_type, {256}, {1}\n", + "@73 = @literal{ ... } -> float_type, {256}, {1}\n", + "@74 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@75 = @literal{ ... } -> float_type, {256}, {1}\n", + "@76 = @literal{ ... } -> float_type, {256}, {1}\n", + "@77 = @literal{ ... } -> float_type, {256}, {1}\n", + "@78 = @literal{ ... } -> float_type, {256}, {1}\n", + "@79 = @literal{ ... } -> float_type, {256}, {1}\n", + "@80 = @literal{ ... } -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@81 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@82 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@83 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@84 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@85 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@86 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@87 = @literal{ ... } -> float_type, {256}, {1}\n", + "@88 = @literal{ ... } -> float_type, {256}, {1}\n", + "@89 = @literal{ ... } -> float_type, {256}, {1}\n", + "@90 = @literal{ ... } -> float_type, {256}, {1}\n", + "@91 = @literal{ ... } -> float_type, {256}, {1}\n", + "@92 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@93 = @literal{ ... } -> float_type, {256}, {1}\n", + "@94 = @literal{ ... } -> float_type, {256}, {1}\n", + "@95 = @literal{ ... } -> float_type, {256}, {1}\n", + "@96 = @literal{ ... } -> float_type, {256}, {1}\n", + "@97 = @literal{ ... } -> float_type, {256}, {1}\n", + "@98 = @literal{ ... } -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@99 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@100 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@101 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@102 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@103 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@104 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@105 = @literal{ ... } -> float_type, {256}, {1}\n", + "@106 = @literal{ ... } -> float_type, {256}, {1}\n", + "@107 = @literal{ ... } -> float_type, {256}, {1}\n", + "@108 = @literal{ ... } -> float_type, {256}, {1}\n", + "@109 = @literal{ ... } -> float_type, {256}, {1}\n", + "@110 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@111 = @literal{ ... } -> float_type, {256}, {1}\n", + "@112 = @literal{ ... } -> float_type, {256}, {1}\n", + "@113 = @literal{ ... } -> float_type, {256}, {1}\n", + "@114 = @literal{ ... } -> float_type, {256}, {1}\n", + "@115 = @literal{ ... } -> float_type, {256}, {1}\n", + "@116 = @literal{ ... } -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@117 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@118 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@119 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@120 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@121 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@122 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@123 = @literal{ ... } -> float_type, {256}, {1}\n", + "@124 = @literal{ ... } -> float_type, {256}, {1}\n", + "@125 = @literal{ ... } -> float_type, {256}, {1}\n", + "@126 = @literal{ ... } -> float_type, {256}, {1}\n", + "@127 = @literal{ ... } -> float_type, {256}, {1}\n", + "@128 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@129 = @literal{ ... } -> float_type, {256}, {1}\n", + "@130 = @literal{ ... } -> float_type, {256}, {1}\n", + "@131 = @literal{ ... } -> float_type, {256}, {1}\n", + "@132 = @literal{ ... } -> float_type, {256}, {1}\n", + "@133 = @literal{ ... } -> float_type, {256}, {1}\n", + "@134 = @literal{ ... } -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@135 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@136 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@137 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@138 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@139 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@140 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@141 = @literal{ ... } -> float_type, {256}, {1}\n", + "@142 = @literal{ ... } -> float_type, {256}, {1}\n", + "@143 = @literal{ ... } -> float_type, {256}, {1}\n", + "@144 = @literal{ ... } -> float_type, {256}, {1}\n", + "@145 = @literal{ ... } -> float_type, {256}, {1}\n", + "@146 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@147 = @literal{ ... } -> float_type, {256}, {1}\n", + "@148 = @literal{ ... } -> float_type, {256}, {1}\n", + "@149 = @literal{ ... } -> float_type, {256}, {1}\n", + "@150 = @literal{ ... } -> float_type, {256}, {1}\n", + "@151 = @literal{ ... } -> float_type, {256}, {1}\n", + "@152 = @literal{ ... } -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@153 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@154 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@155 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@156 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@157 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@158 = @literal{ ... } -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@159 = @literal{ ... } -> float_type, {256}, {1}\n", + "@160 = @literal{ ... } -> float_type, {256}, {1}\n", + "@161 = @literal{ ... } -> float_type, {256}, {1}\n", + "@162 = @literal{ ... } -> float_type, {256}, {1}\n", + "@163 = @literal{ ... } -> float_type, {256}, {1}\n", + "@164 = @literal{ ... } -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@165 = @literal{ ... } -> float_type, {256}, {1}\n", + "@166 = @literal{ ... } -> float_type, {256}, {1}\n", + "@167 = @literal{ ... } -> float_type, {256}, {1}\n", + "@168 = @literal{ ... } -> float_type, {256}, {1}\n", + "@169 = @literal{ ... } -> float_type, {256}, {1}\n", + "@170 = @literal{ ... } -> float_type, {1, 1, 512, 256}, {131072, 131072, 256, 1}\n", + "@171 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@172 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@173 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@174 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@175 = @literal{ ... } -> float_type, {1024}, {1}\n", + "@176 = @literal{ ... } -> float_type, {1, 1, 512, 1024}, {524288, 524288, 1024, 1}\n", + "@177 = @literal{ ... } -> float_type, {512}, {1}\n", + "@178 = @literal{ ... } -> float_type, {512}, {1}\n", + "@179 = @literal{ ... } -> float_type, {512}, {1}\n", + "@180 = @literal{ ... } -> float_type, {512}, {1}\n", + "@181 = @literal{ ... } -> float_type, {512}, {1}\n", + "@182 = @literal{ ... } -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@183 = @literal{ ... } -> float_type, {128}, {1}\n", + "@184 = @literal{ ... } -> float_type, {128}, {1}\n", + "@185 = @literal{ ... } -> float_type, {128}, {1}\n", + "@186 = @literal{ ... } -> float_type, {128}, {1}\n", + "@187 = @literal{ ... } -> float_type, {128}, {1}\n", + "@188 = @literal{ ... } -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@189 = @literal{ ... } -> float_type, {128}, {1}\n", + "@190 = @literal{ ... } -> float_type, {128}, {1}\n", + "@191 = @literal{ ... } -> float_type, {128}, {1}\n", + "@192 = @literal{ ... } -> float_type, {128}, {1}\n", + "@193 = @literal{ ... } -> float_type, {128}, {1}\n", + "@194 = @literal{ ... } -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@195 = @literal{ ... } -> float_type, {512}, {1}\n", + "@196 = @literal{ ... } -> float_type, {512}, {1}\n", + "@197 = @literal{ ... } -> float_type, {512}, {1}\n", + "@198 = @literal{ ... } -> float_type, {512}, {1}\n", + "@199 = @literal{ ... } -> float_type, {512}, {1}\n", + "@200 = @literal{ ... } -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@201 = @literal{ ... } -> float_type, {128}, {1}\n", + "@202 = @literal{ ... } -> float_type, {128}, {1}\n", + "@203 = @literal{ ... } -> float_type, {128}, {1}\n", + "@204 = @literal{ ... } -> float_type, {128}, {1}\n", + "@205 = @literal{ ... } -> float_type, {128}, {1}\n", + "@206 = @literal{ ... } -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@207 = @literal{ ... } -> float_type, {128}, {1}\n", + "@208 = @literal{ ... } -> float_type, {128}, {1}\n", + "@209 = @literal{ ... } -> float_type, {128}, {1}\n", + "@210 = @literal{ ... } -> float_type, {128}, {1}\n", + "@211 = @literal{ ... } -> float_type, {128}, {1}\n", + "@212 = @literal{ ... } -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@213 = @literal{ ... } -> float_type, {512}, {1}\n", + "@214 = @literal{ ... } -> float_type, {512}, {1}\n", + "@215 = @literal{ ... } -> float_type, {512}, {1}\n", + "@216 = @literal{ ... } -> float_type, {512}, {1}\n", + "@217 = @literal{ ... } -> float_type, {512}, {1}\n", + "@218 = @literal{ ... } -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@219 = @literal{ ... } -> float_type, {128}, {1}\n", + "@220 = @literal{ ... } -> float_type, {128}, {1}\n", + "@221 = @literal{ ... } -> float_type, {128}, {1}\n", + "@222 = @literal{ ... } -> float_type, {128}, {1}\n", + "@223 = @literal{ ... } -> float_type, {128}, {1}\n", + "@224 = @literal{ ... } -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@225 = @literal{ ... } -> float_type, {128}, {1}\n", + "@226 = @literal{ ... } -> float_type, {128}, {1}\n", + "@227 = @literal{ ... } -> float_type, {128}, {1}\n", + "@228 = @literal{ ... } -> float_type, {128}, {1}\n", + "@229 = @literal{ ... } -> float_type, {128}, {1}\n", + "@230 = @literal{ ... } -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@231 = @literal{ ... } -> float_type, {512}, {1}\n", + "@232 = @literal{ ... } -> float_type, {512}, {1}\n", + "@233 = @literal{ ... } -> float_type, {512}, {1}\n", + "@234 = @literal{ ... } -> float_type, {512}, {1}\n", + "@235 = @literal{ ... } -> float_type, {512}, {1}\n", + "@236 = @literal{ ... } -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@237 = @literal{ ... } -> float_type, {128}, {1}\n", + "@238 = @literal{ ... } -> float_type, {128}, {1}\n", + "@239 = @literal{ ... } -> float_type, {128}, {1}\n", + "@240 = @literal{ ... } -> float_type, {128}, {1}\n", + "@241 = @literal{ ... } -> float_type, {128}, {1}\n", + "@242 = @literal{ ... } -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@243 = @literal{ ... } -> float_type, {128}, {1}\n", + "@244 = @literal{ ... } -> float_type, {128}, {1}\n", + "@245 = @literal{ ... } -> float_type, {128}, {1}\n", + "@246 = @literal{ ... } -> float_type, {128}, {1}\n", + "@247 = @literal{ ... } -> float_type, {128}, {1}\n", + "@248 = @literal{ ... } -> float_type, {1, 1, 256, 128}, {32768, 32768, 128, 1}\n", + "@249 = @literal{ ... } -> float_type, {512}, {1}\n", + "@250 = @literal{ ... } -> float_type, {512}, {1}\n", + "@251 = @literal{ ... } -> float_type, {512}, {1}\n", + "@252 = @literal{ ... } -> float_type, {512}, {1}\n", + "@253 = @literal{ ... } -> float_type, {512}, {1}\n", + "@254 = @literal{ ... } -> float_type, {1, 1, 256, 512}, {131072, 131072, 512, 1}\n", + "@255 = @literal{ ... } -> float_type, {256}, {1}\n", + "@256 = @literal{ ... } -> float_type, {256}, {1}\n", + "@257 = @literal{ ... } -> float_type, {256}, {1}\n", + "@258 = @literal{ ... } -> float_type, {256}, {1}\n", + "@259 = @literal{ ... } -> float_type, {256}, {1}\n", + "@260 = @literal{ ... } -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@261 = @literal{ ... } -> float_type, {64}, {1}\n", + "@262 = @literal{ ... } -> float_type, {64}, {1}\n", + "@263 = @literal{ ... } -> float_type, {64}, {1}\n", + "@264 = @literal{ ... } -> float_type, {64}, {1}\n", + "@265 = @literal{ ... } -> float_type, {64}, {1}\n", + "@266 = @literal{ ... } -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@267 = @literal{ ... } -> float_type, {64}, {1}\n", + "@268 = @literal{ ... } -> float_type, {64}, {1}\n", + "@269 = @literal{ ... } -> float_type, {64}, {1}\n", + "@270 = @literal{ ... } -> float_type, {64}, {1}\n", + "@271 = @literal{ ... } -> float_type, {64}, {1}\n", + "@272 = @literal{ ... } -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@273 = @literal{ ... } -> float_type, {256}, {1}\n", + "@274 = @literal{ ... } -> float_type, {256}, {1}\n", + "@275 = @literal{ ... } -> float_type, {256}, {1}\n", + "@276 = @literal{ ... } -> float_type, {256}, {1}\n", + "@277 = @literal{ ... } -> float_type, {256}, {1}\n", + "@278 = @literal{ ... } -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@279 = @literal{ ... } -> float_type, {64}, {1}\n", + "@280 = @literal{ ... } -> float_type, {64}, {1}\n", + "@281 = @literal{ ... } -> float_type, {64}, {1}\n", + "@282 = @literal{ ... } -> float_type, {64}, {1}\n", + "@283 = @literal{ ... } -> float_type, {64}, {1}\n", + "@284 = @literal{ ... } -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@285 = @literal{ ... } -> float_type, {64}, {1}\n", + "@286 = @literal{ ... } -> float_type, {64}, {1}\n", + "@287 = @literal{ ... } -> float_type, {64}, {1}\n", + "@288 = @literal{ ... } -> float_type, {64}, {1}\n", + "@289 = @literal{ ... } -> float_type, {64}, {1}\n", + "@290 = @literal{ ... } -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@291 = @literal{ ... } -> float_type, {256}, {1}\n", + "@292 = @literal{ ... } -> float_type, {256}, {1}\n", + "@293 = @literal{ ... } -> float_type, {256}, {1}\n", + "@294 = @literal{ ... } -> float_type, {256}, {1}\n", + "@295 = @literal{ ... } -> float_type, {256}, {1}\n", + "@296 = @literal{ ... } -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@297 = @literal{ ... } -> float_type, {64}, {1}\n", + "@298 = @literal{ ... } -> float_type, {64}, {1}\n", + "@299 = @literal{ ... } -> float_type, {64}, {1}\n", + "@300 = @literal{ ... } -> float_type, {64}, {1}\n", + "@301 = @literal{ ... } -> float_type, {64}, {1}\n", + "@302 = @literal{ ... } -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@303 = @literal{ ... } -> float_type, {64}, {1}\n", + "@304 = @literal{ ... } -> float_type, {64}, {1}\n", + "@305 = @literal{ ... } -> float_type, {64}, {1}\n", + "@306 = @literal{ ... } -> float_type, {64}, {1}\n", + "@307 = @literal{ ... } -> float_type, {64}, {1}\n", + "@308 = @literal{ ... } -> float_type, {1, 1, 64, 64}, {4096, 4096, 64, 1}\n", + "@309 = @literal{ ... } -> float_type, {256}, {1}\n", + "@310 = @literal{ ... } -> float_type, {256}, {1}\n", + "@311 = @literal{ ... } -> float_type, {256}, {1}\n", + "@312 = @literal{ ... } -> float_type, {256}, {1}\n", + "@313 = @literal{ ... } -> float_type, {256}, {1}\n", + "@314 = @literal{ ... } -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@315 = @literal{0, 0, 1, 1, 1, 1, 0, 0} -> int32_type, {4, 2}, {2, 1}\n", + "@316 = @literal{ ... } -> float_type, {64}, {1}\n", + "@317 = @literal{ ... } -> float_type, {64}, {1}\n", + "@318 = @literal{ ... } -> float_type, {64}, {1}\n", + "@319 = @literal{ ... } -> float_type, {64}, {1}\n", + "@320 = @literal{ ... } -> float_type, {64}, {1}\n", + "@321 = @literal{ ... } -> float_type, {7, 7, 3, 64}, {1344, 192, 64, 1}\n", + "@322 = @literal{0, 0, 3, 3, 3, 3, 0, 0} -> int32_type, {4, 2}, {2, 1}\n", + "x = @param:x -> float_type, {1, 3, 224, 224}, {150528, 50176, 224, 1}\n", + "@323 = transpose[dims={0, 2, 3, 1}](x) -> float_type, {1, 224, 224, 3}, {150528, 224, 1, 50176}\n", + "@324 = transpose[dims={0, 3, 1, 2}](@323) -> float_type, {1, 3, 224, 224}, {150528, 50176, 224, 1}\n", + "@325 = pad[mode=0,pads={0, 0, 3, 3, 0, 0, 3, 3},value=0](@324) -> float_type, {1, 3, 230, 230}, {158700, 52900, 230, 1}\n", + "@326 = transpose[dims={0, 2, 3, 1}](@325) -> float_type, {1, 230, 230, 3}, {158700, 230, 1, 52900}\n", + "@327 = transpose[dims={0, 2, 3, 1}](@321) -> float_type, {7, 3, 64, 7}, {1344, 64, 1, 192}\n", + "@328 = transpose[dims={0, 3, 1, 2}](@327) -> float_type, {7, 7, 3, 64}, {1344, 192, 64, 1}\n", + "@329 = identity(@328) -> float_type, {7, 7, 3, 64}, {1344, 192, 64, 1}\n", + "@330 = transpose[dims={0, 2, 3, 1}](@329) -> float_type, {7, 3, 64, 7}, {1344, 64, 1, 192}\n", + "@331 = transpose[dims={0, 3, 1, 2}](@326) -> float_type, {1, 3, 230, 230}, {158700, 52900, 230, 1}\n", + "@332 = transpose[dims={0, 3, 1, 2}](@330) -> float_type, {7, 7, 3, 64}, {1344, 192, 64, 1}\n", + "@333 = transpose[dims={3, 2, 0, 1}](@332) -> float_type, {64, 3, 7, 7}, {1, 64, 1344, 192}\n", + "@334 = transpose[dims={3, 2, 0, 1}](@332) -> float_type, {64, 3, 7, 7}, {1, 64, 1344, 192}\n", + "@335 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@331,@334) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@336 = transpose[dims={0, 2, 3, 1}](@335) -> float_type, {1, 112, 112, 64}, {802816, 112, 1, 12544}\n", + "@337 = identity(@320) -> float_type, {64}, {1}\n", + "@338 = transpose[dims={0, 3, 1, 2}](@336) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@339 = broadcast[axis=1,dims={1, 64, 112, 112}](@337) -> float_type, {1, 64, 112, 112}, {0, 1, 0, 0}\n", + "@340 = add(@338,@339) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@341 = transpose[dims={0, 2, 3, 1}](@340) -> float_type, {1, 112, 112, 64}, {802816, 112, 1, 12544}\n", + "@342 = identity(@319) -> float_type, {64}, {1}\n", + "@343 = identity(@318) -> float_type, {64}, {1}\n", + "@344 = identity(@317) -> float_type, {64}, {1}\n", + "@345 = identity(@316) -> float_type, {64}, {1}\n", + "@346 = unknown:FusedBatchNormV3(@341,@342,@343,@344,@345) -> float_type, {1, 112, 112, 64}, {802816, 112, 1, 12544}\n", + "@347 = transpose[dims={0, 3, 1, 2}](@346) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@348 = relu(@347) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@349 = transpose[dims={0, 2, 3, 1}](@348) -> float_type, {1, 112, 112, 64}, {802816, 112, 1, 12544}\n", + "@350 = transpose[dims={0, 3, 1, 2}](@349) -> float_type, {1, 64, 112, 112}, {802816, 12544, 112, 1}\n", + "@351 = pad[mode=0,pads={0, 0, 1, 1, 0, 0, 1, 1},value=0](@350) -> float_type, {1, 64, 114, 114}, {831744, 12996, 114, 1}\n", + "@352 = transpose[dims={0, 2, 3, 1}](@351) -> float_type, {1, 114, 114, 64}, {831744, 114, 1, 12996}\n", + "@353 = transpose[dims={0, 3, 1, 2}](@352) -> float_type, {1, 64, 114, 114}, {831744, 12996, 114, 1}\n", + "@354 = pooling[mode=max,padding={0, 0},stride={2, 2},lengths={3, 3},ceil_mode=0](@353) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@355 = transpose[dims={0, 2, 3, 1}](@354) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@356 = transpose[dims={0, 2, 3, 1}](@314) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@357 = transpose[dims={0, 3, 1, 2}](@356) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@358 = identity(@357) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@359 = transpose[dims={0, 2, 3, 1}](@358) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@360 = transpose[dims={0, 3, 1, 2}](@355) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@361 = transpose[dims={0, 3, 1, 2}](@359) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@362 = transpose[dims={3, 2, 0, 1}](@361) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@363 = transpose[dims={3, 2, 0, 1}](@361) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@364 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@360,@363) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@365 = transpose[dims={0, 2, 3, 1}](@364) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@366 = identity(@313) -> float_type, {256}, {1}\n", + "@367 = transpose[dims={0, 3, 1, 2}](@365) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@368 = broadcast[axis=1,dims={1, 256, 56, 56}](@366) -> float_type, {1, 256, 56, 56}, {0, 1, 0, 0}\n", + "@369 = add(@367,@368) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@370 = transpose[dims={0, 2, 3, 1}](@369) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@371 = identity(@312) -> float_type, {256}, {1}\n", + "@372 = identity(@311) -> float_type, {256}, {1}\n", + "@373 = identity(@310) -> float_type, {256}, {1}\n", + "@374 = identity(@309) -> float_type, {256}, {1}\n", + "@375 = unknown:FusedBatchNormV3(@370,@371,@372,@373,@374) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@376 = transpose[dims={0, 2, 3, 1}](@308) -> float_type, {1, 64, 64, 1}, {4096, 64, 1, 4096}\n", + "@377 = transpose[dims={0, 3, 1, 2}](@376) -> float_type, {1, 1, 64, 64}, {4096, 4096, 64, 1}\n", + "@378 = identity(@377) -> float_type, {1, 1, 64, 64}, {4096, 4096, 64, 1}\n", + "@379 = transpose[dims={0, 2, 3, 1}](@378) -> float_type, {1, 64, 64, 1}, {4096, 64, 1, 4096}\n", + "@380 = transpose[dims={0, 3, 1, 2}](@355) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@381 = transpose[dims={0, 3, 1, 2}](@379) -> float_type, {1, 1, 64, 64}, {4096, 4096, 64, 1}\n", + "@382 = transpose[dims={3, 2, 0, 1}](@381) -> float_type, {64, 64, 1, 1}, {1, 64, 4096, 4096}\n", + "@383 = transpose[dims={3, 2, 0, 1}](@381) -> float_type, {64, 64, 1, 1}, {1, 64, 4096, 4096}\n", + "@384 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@380,@383) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@385 = transpose[dims={0, 2, 3, 1}](@384) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@386 = identity(@307) -> float_type, {64}, {1}\n", + "@387 = transpose[dims={0, 3, 1, 2}](@385) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@388 = broadcast[axis=1,dims={1, 64, 56, 56}](@386) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@389 = add(@387,@388) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@390 = transpose[dims={0, 2, 3, 1}](@389) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@391 = identity(@306) -> float_type, {64}, {1}\n", + "@392 = identity(@305) -> float_type, {64}, {1}\n", + "@393 = identity(@304) -> float_type, {64}, {1}\n", + "@394 = identity(@303) -> float_type, {64}, {1}\n", + "@395 = unknown:FusedBatchNormV3(@390,@391,@392,@393,@394) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@396 = transpose[dims={0, 3, 1, 2}](@395) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@397 = relu(@396) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@398 = transpose[dims={0, 2, 3, 1}](@397) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@399 = transpose[dims={0, 2, 3, 1}](@302) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@400 = transpose[dims={0, 3, 1, 2}](@399) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@401 = identity(@400) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@402 = transpose[dims={0, 2, 3, 1}](@401) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@403 = transpose[dims={0, 3, 1, 2}](@398) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@404 = transpose[dims={0, 3, 1, 2}](@402) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@405 = transpose[dims={3, 2, 0, 1}](@404) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@406 = transpose[dims={3, 2, 0, 1}](@404) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@407 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@403,@406) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@408 = transpose[dims={0, 2, 3, 1}](@407) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@409 = identity(@301) -> float_type, {64}, {1}\n", + "@410 = transpose[dims={0, 3, 1, 2}](@408) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@411 = broadcast[axis=1,dims={1, 64, 56, 56}](@409) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@412 = add(@410,@411) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@413 = transpose[dims={0, 2, 3, 1}](@412) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@414 = identity(@300) -> float_type, {64}, {1}\n", + "@415 = identity(@299) -> float_type, {64}, {1}\n", + "@416 = identity(@298) -> float_type, {64}, {1}\n", + "@417 = identity(@297) -> float_type, {64}, {1}\n", + "@418 = unknown:FusedBatchNormV3(@413,@414,@415,@416,@417) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@419 = transpose[dims={0, 3, 1, 2}](@418) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@420 = relu(@419) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@421 = transpose[dims={0, 2, 3, 1}](@420) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@422 = transpose[dims={0, 2, 3, 1}](@296) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@423 = transpose[dims={0, 3, 1, 2}](@422) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@424 = identity(@423) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@425 = transpose[dims={0, 2, 3, 1}](@424) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@426 = transpose[dims={0, 3, 1, 2}](@421) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@427 = transpose[dims={0, 3, 1, 2}](@425) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@428 = transpose[dims={3, 2, 0, 1}](@427) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@429 = transpose[dims={3, 2, 0, 1}](@427) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@430 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@426,@429) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@431 = transpose[dims={0, 2, 3, 1}](@430) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@432 = identity(@295) -> float_type, {256}, {1}\n", + "@433 = transpose[dims={0, 3, 1, 2}](@431) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@434 = broadcast[axis=1,dims={1, 256, 56, 56}](@432) -> float_type, {1, 256, 56, 56}, {0, 1, 0, 0}\n", + "@435 = add(@433,@434) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@436 = transpose[dims={0, 2, 3, 1}](@435) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@437 = identity(@294) -> float_type, {256}, {1}\n", + "@438 = identity(@293) -> float_type, {256}, {1}\n", + "@439 = identity(@292) -> float_type, {256}, {1}\n", + "@440 = identity(@291) -> float_type, {256}, {1}\n", + "@441 = unknown:FusedBatchNormV3(@436,@437,@438,@439,@440) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@442 = unknown:AddV2(@375,@441) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@443 = transpose[dims={0, 3, 1, 2}](@442) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@444 = relu(@443) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@445 = transpose[dims={0, 2, 3, 1}](@444) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@446 = transpose[dims={0, 2, 3, 1}](@290) -> float_type, {1, 256, 64, 1}, {16384, 64, 1, 16384}\n", + "@447 = transpose[dims={0, 3, 1, 2}](@446) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@448 = identity(@447) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@449 = transpose[dims={0, 2, 3, 1}](@448) -> float_type, {1, 256, 64, 1}, {16384, 64, 1, 16384}\n", + "@450 = transpose[dims={0, 3, 1, 2}](@445) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@451 = transpose[dims={0, 3, 1, 2}](@449) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@452 = transpose[dims={3, 2, 0, 1}](@451) -> float_type, {64, 256, 1, 1}, {1, 64, 16384, 16384}\n", + "@453 = transpose[dims={3, 2, 0, 1}](@451) -> float_type, {64, 256, 1, 1}, {1, 64, 16384, 16384}\n", + "@454 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@450,@453) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@455 = transpose[dims={0, 2, 3, 1}](@454) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@456 = identity(@289) -> float_type, {64}, {1}\n", + "@457 = transpose[dims={0, 3, 1, 2}](@455) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@458 = broadcast[axis=1,dims={1, 64, 56, 56}](@456) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@459 = add(@457,@458) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@460 = transpose[dims={0, 2, 3, 1}](@459) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@461 = identity(@288) -> float_type, {64}, {1}\n", + "@462 = identity(@287) -> float_type, {64}, {1}\n", + "@463 = identity(@286) -> float_type, {64}, {1}\n", + "@464 = identity(@285) -> float_type, {64}, {1}\n", + "@465 = unknown:FusedBatchNormV3(@460,@461,@462,@463,@464) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@466 = transpose[dims={0, 3, 1, 2}](@465) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@467 = relu(@466) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@468 = transpose[dims={0, 2, 3, 1}](@467) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@469 = transpose[dims={0, 2, 3, 1}](@284) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@470 = transpose[dims={0, 3, 1, 2}](@469) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@471 = identity(@470) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@472 = transpose[dims={0, 2, 3, 1}](@471) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@473 = transpose[dims={0, 3, 1, 2}](@468) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@474 = transpose[dims={0, 3, 1, 2}](@472) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@475 = transpose[dims={3, 2, 0, 1}](@474) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@476 = transpose[dims={3, 2, 0, 1}](@474) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@477 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@473,@476) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@478 = transpose[dims={0, 2, 3, 1}](@477) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@479 = identity(@283) -> float_type, {64}, {1}\n", + "@480 = transpose[dims={0, 3, 1, 2}](@478) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@481 = broadcast[axis=1,dims={1, 64, 56, 56}](@479) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@482 = add(@480,@481) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@483 = transpose[dims={0, 2, 3, 1}](@482) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@484 = identity(@282) -> float_type, {64}, {1}\n", + "@485 = identity(@281) -> float_type, {64}, {1}\n", + "@486 = identity(@280) -> float_type, {64}, {1}\n", + "@487 = identity(@279) -> float_type, {64}, {1}\n", + "@488 = unknown:FusedBatchNormV3(@483,@484,@485,@486,@487) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@489 = transpose[dims={0, 3, 1, 2}](@488) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@490 = relu(@489) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@491 = transpose[dims={0, 2, 3, 1}](@490) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@492 = transpose[dims={0, 2, 3, 1}](@278) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@493 = transpose[dims={0, 3, 1, 2}](@492) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@494 = identity(@493) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@495 = transpose[dims={0, 2, 3, 1}](@494) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@496 = transpose[dims={0, 3, 1, 2}](@491) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@497 = transpose[dims={0, 3, 1, 2}](@495) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@498 = transpose[dims={3, 2, 0, 1}](@497) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@499 = transpose[dims={3, 2, 0, 1}](@497) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@500 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@496,@499) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@501 = transpose[dims={0, 2, 3, 1}](@500) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@502 = identity(@277) -> float_type, {256}, {1}\n", + "@503 = transpose[dims={0, 3, 1, 2}](@501) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@504 = broadcast[axis=1,dims={1, 256, 56, 56}](@502) -> float_type, {1, 256, 56, 56}, {0, 1, 0, 0}\n", + "@505 = add(@503,@504) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@506 = transpose[dims={0, 2, 3, 1}](@505) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@507 = identity(@276) -> float_type, {256}, {1}\n", + "@508 = identity(@275) -> float_type, {256}, {1}\n", + "@509 = identity(@274) -> float_type, {256}, {1}\n", + "@510 = identity(@273) -> float_type, {256}, {1}\n", + "@511 = unknown:FusedBatchNormV3(@506,@507,@508,@509,@510) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@512 = unknown:AddV2(@445,@511) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@513 = transpose[dims={0, 3, 1, 2}](@512) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@514 = relu(@513) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@515 = transpose[dims={0, 2, 3, 1}](@514) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@516 = transpose[dims={0, 2, 3, 1}](@272) -> float_type, {1, 256, 64, 1}, {16384, 64, 1, 16384}\n", + "@517 = transpose[dims={0, 3, 1, 2}](@516) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@518 = identity(@517) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@519 = transpose[dims={0, 2, 3, 1}](@518) -> float_type, {1, 256, 64, 1}, {16384, 64, 1, 16384}\n", + "@520 = transpose[dims={0, 3, 1, 2}](@515) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@521 = transpose[dims={0, 3, 1, 2}](@519) -> float_type, {1, 1, 256, 64}, {16384, 16384, 64, 1}\n", + "@522 = transpose[dims={3, 2, 0, 1}](@521) -> float_type, {64, 256, 1, 1}, {1, 64, 16384, 16384}\n", + "@523 = transpose[dims={3, 2, 0, 1}](@521) -> float_type, {64, 256, 1, 1}, {1, 64, 16384, 16384}\n", + "@524 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@520,@523) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@525 = transpose[dims={0, 2, 3, 1}](@524) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@526 = identity(@271) -> float_type, {64}, {1}\n", + "@527 = transpose[dims={0, 3, 1, 2}](@525) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@528 = broadcast[axis=1,dims={1, 64, 56, 56}](@526) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@529 = add(@527,@528) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@530 = transpose[dims={0, 2, 3, 1}](@529) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@531 = identity(@270) -> float_type, {64}, {1}\n", + "@532 = identity(@269) -> float_type, {64}, {1}\n", + "@533 = identity(@268) -> float_type, {64}, {1}\n", + "@534 = identity(@267) -> float_type, {64}, {1}\n", + "@535 = unknown:FusedBatchNormV3(@530,@531,@532,@533,@534) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@536 = transpose[dims={0, 3, 1, 2}](@535) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@537 = relu(@536) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@538 = transpose[dims={0, 2, 3, 1}](@537) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@539 = transpose[dims={0, 2, 3, 1}](@266) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@540 = transpose[dims={0, 3, 1, 2}](@539) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@541 = identity(@540) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@542 = transpose[dims={0, 2, 3, 1}](@541) -> float_type, {3, 64, 64, 3}, {12288, 64, 1, 4096}\n", + "@543 = transpose[dims={0, 3, 1, 2}](@538) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@544 = transpose[dims={0, 3, 1, 2}](@542) -> float_type, {3, 3, 64, 64}, {12288, 4096, 64, 1}\n", + "@545 = transpose[dims={3, 2, 0, 1}](@544) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@546 = transpose[dims={3, 2, 0, 1}](@544) -> float_type, {64, 64, 3, 3}, {1, 64, 12288, 4096}\n", + "@547 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@543,@546) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@548 = transpose[dims={0, 2, 3, 1}](@547) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@549 = identity(@265) -> float_type, {64}, {1}\n", + "@550 = transpose[dims={0, 3, 1, 2}](@548) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@551 = broadcast[axis=1,dims={1, 64, 56, 56}](@549) -> float_type, {1, 64, 56, 56}, {0, 1, 0, 0}\n", + "@552 = add(@550,@551) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@553 = transpose[dims={0, 2, 3, 1}](@552) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@554 = identity(@264) -> float_type, {64}, {1}\n", + "@555 = identity(@263) -> float_type, {64}, {1}\n", + "@556 = identity(@262) -> float_type, {64}, {1}\n", + "@557 = identity(@261) -> float_type, {64}, {1}\n", + "@558 = unknown:FusedBatchNormV3(@553,@554,@555,@556,@557) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@559 = transpose[dims={0, 3, 1, 2}](@558) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@560 = relu(@559) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@561 = transpose[dims={0, 2, 3, 1}](@560) -> float_type, {1, 56, 56, 64}, {200704, 56, 1, 3136}\n", + "@562 = transpose[dims={0, 2, 3, 1}](@260) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@563 = transpose[dims={0, 3, 1, 2}](@562) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@564 = identity(@563) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@565 = transpose[dims={0, 2, 3, 1}](@564) -> float_type, {1, 64, 256, 1}, {16384, 256, 1, 16384}\n", + "@566 = transpose[dims={0, 3, 1, 2}](@561) -> float_type, {1, 64, 56, 56}, {200704, 3136, 56, 1}\n", + "@567 = transpose[dims={0, 3, 1, 2}](@565) -> float_type, {1, 1, 64, 256}, {16384, 16384, 256, 1}\n", + "@568 = transpose[dims={3, 2, 0, 1}](@567) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@569 = transpose[dims={3, 2, 0, 1}](@567) -> float_type, {256, 64, 1, 1}, {1, 256, 16384, 16384}\n", + "@570 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@566,@569) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@571 = transpose[dims={0, 2, 3, 1}](@570) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@572 = identity(@259) -> float_type, {256}, {1}\n", + "@573 = transpose[dims={0, 3, 1, 2}](@571) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@574 = broadcast[axis=1,dims={1, 256, 56, 56}](@572) -> float_type, {1, 256, 56, 56}, {0, 1, 0, 0}\n", + "@575 = add(@573,@574) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@576 = transpose[dims={0, 2, 3, 1}](@575) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@577 = identity(@258) -> float_type, {256}, {1}\n", + "@578 = identity(@257) -> float_type, {256}, {1}\n", + "@579 = identity(@256) -> float_type, {256}, {1}\n", + "@580 = identity(@255) -> float_type, {256}, {1}\n", + "@581 = unknown:FusedBatchNormV3(@576,@577,@578,@579,@580) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@582 = unknown:AddV2(@515,@581) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@583 = transpose[dims={0, 3, 1, 2}](@582) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@584 = relu(@583) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@585 = transpose[dims={0, 2, 3, 1}](@584) -> float_type, {1, 56, 56, 256}, {802816, 56, 1, 3136}\n", + "@586 = transpose[dims={0, 2, 3, 1}](@254) -> float_type, {1, 256, 512, 1}, {131072, 512, 1, 131072}\n", + "@587 = transpose[dims={0, 3, 1, 2}](@586) -> float_type, {1, 1, 256, 512}, {131072, 131072, 512, 1}\n", + "@588 = identity(@587) -> float_type, {1, 1, 256, 512}, {131072, 131072, 512, 1}\n", + "@589 = transpose[dims={0, 2, 3, 1}](@588) -> float_type, {1, 256, 512, 1}, {131072, 512, 1, 131072}\n", + "@590 = transpose[dims={0, 3, 1, 2}](@585) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@591 = transpose[dims={0, 3, 1, 2}](@589) -> float_type, {1, 1, 256, 512}, {131072, 131072, 512, 1}\n", + "@592 = transpose[dims={3, 2, 0, 1}](@591) -> float_type, {512, 256, 1, 1}, {1, 512, 131072, 131072}\n", + "@593 = transpose[dims={3, 2, 0, 1}](@591) -> float_type, {512, 256, 1, 1}, {1, 512, 131072, 131072}\n", + "@594 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@590,@593) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@595 = transpose[dims={0, 2, 3, 1}](@594) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@596 = identity(@253) -> float_type, {512}, {1}\n", + "@597 = transpose[dims={0, 3, 1, 2}](@595) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@598 = broadcast[axis=1,dims={1, 512, 28, 28}](@596) -> float_type, {1, 512, 28, 28}, {0, 1, 0, 0}\n", + "@599 = add(@597,@598) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@600 = transpose[dims={0, 2, 3, 1}](@599) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@601 = identity(@252) -> float_type, {512}, {1}\n", + "@602 = identity(@251) -> float_type, {512}, {1}\n", + "@603 = identity(@250) -> float_type, {512}, {1}\n", + "@604 = identity(@249) -> float_type, {512}, {1}\n", + "@605 = unknown:FusedBatchNormV3(@600,@601,@602,@603,@604) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@606 = transpose[dims={0, 2, 3, 1}](@248) -> float_type, {1, 256, 128, 1}, {32768, 128, 1, 32768}\n", + "@607 = transpose[dims={0, 3, 1, 2}](@606) -> float_type, {1, 1, 256, 128}, {32768, 32768, 128, 1}\n", + "@608 = identity(@607) -> float_type, {1, 1, 256, 128}, {32768, 32768, 128, 1}\n", + "@609 = transpose[dims={0, 2, 3, 1}](@608) -> float_type, {1, 256, 128, 1}, {32768, 128, 1, 32768}\n", + "@610 = transpose[dims={0, 3, 1, 2}](@585) -> float_type, {1, 256, 56, 56}, {802816, 3136, 56, 1}\n", + "@611 = transpose[dims={0, 3, 1, 2}](@609) -> float_type, {1, 1, 256, 128}, {32768, 32768, 128, 1}\n", + "@612 = transpose[dims={3, 2, 0, 1}](@611) -> float_type, {128, 256, 1, 1}, {1, 128, 32768, 32768}\n", + "@613 = transpose[dims={3, 2, 0, 1}](@611) -> float_type, {128, 256, 1, 1}, {1, 128, 32768, 32768}\n", + "@614 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@610,@613) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@615 = transpose[dims={0, 2, 3, 1}](@614) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@616 = identity(@247) -> float_type, {128}, {1}\n", + "@617 = transpose[dims={0, 3, 1, 2}](@615) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@618 = broadcast[axis=1,dims={1, 128, 28, 28}](@616) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@619 = add(@617,@618) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@620 = transpose[dims={0, 2, 3, 1}](@619) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@621 = identity(@246) -> float_type, {128}, {1}\n", + "@622 = identity(@245) -> float_type, {128}, {1}\n", + "@623 = identity(@244) -> float_type, {128}, {1}\n", + "@624 = identity(@243) -> float_type, {128}, {1}\n", + "@625 = unknown:FusedBatchNormV3(@620,@621,@622,@623,@624) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@626 = transpose[dims={0, 3, 1, 2}](@625) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@627 = relu(@626) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@628 = transpose[dims={0, 2, 3, 1}](@627) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@629 = transpose[dims={0, 2, 3, 1}](@242) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@630 = transpose[dims={0, 3, 1, 2}](@629) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@631 = identity(@630) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@632 = transpose[dims={0, 2, 3, 1}](@631) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@633 = transpose[dims={0, 3, 1, 2}](@628) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@634 = transpose[dims={0, 3, 1, 2}](@632) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@635 = transpose[dims={3, 2, 0, 1}](@634) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@636 = transpose[dims={3, 2, 0, 1}](@634) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@637 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@633,@636) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@638 = transpose[dims={0, 2, 3, 1}](@637) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@639 = identity(@241) -> float_type, {128}, {1}\n", + "@640 = transpose[dims={0, 3, 1, 2}](@638) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@641 = broadcast[axis=1,dims={1, 128, 28, 28}](@639) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@642 = add(@640,@641) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@643 = transpose[dims={0, 2, 3, 1}](@642) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@644 = identity(@240) -> float_type, {128}, {1}\n", + "@645 = identity(@239) -> float_type, {128}, {1}\n", + "@646 = identity(@238) -> float_type, {128}, {1}\n", + "@647 = identity(@237) -> float_type, {128}, {1}\n", + "@648 = unknown:FusedBatchNormV3(@643,@644,@645,@646,@647) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@649 = transpose[dims={0, 3, 1, 2}](@648) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@650 = relu(@649) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@651 = transpose[dims={0, 2, 3, 1}](@650) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@652 = transpose[dims={0, 2, 3, 1}](@236) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@653 = transpose[dims={0, 3, 1, 2}](@652) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@654 = identity(@653) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@655 = transpose[dims={0, 2, 3, 1}](@654) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@656 = transpose[dims={0, 3, 1, 2}](@651) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@657 = transpose[dims={0, 3, 1, 2}](@655) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@658 = transpose[dims={3, 2, 0, 1}](@657) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@659 = transpose[dims={3, 2, 0, 1}](@657) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@660 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@656,@659) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@661 = transpose[dims={0, 2, 3, 1}](@660) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@662 = identity(@235) -> float_type, {512}, {1}\n", + "@663 = transpose[dims={0, 3, 1, 2}](@661) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@664 = broadcast[axis=1,dims={1, 512, 28, 28}](@662) -> float_type, {1, 512, 28, 28}, {0, 1, 0, 0}\n", + "@665 = add(@663,@664) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@666 = transpose[dims={0, 2, 3, 1}](@665) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@667 = identity(@234) -> float_type, {512}, {1}\n", + "@668 = identity(@233) -> float_type, {512}, {1}\n", + "@669 = identity(@232) -> float_type, {512}, {1}\n", + "@670 = identity(@231) -> float_type, {512}, {1}\n", + "@671 = unknown:FusedBatchNormV3(@666,@667,@668,@669,@670) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@672 = unknown:AddV2(@605,@671) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@673 = transpose[dims={0, 3, 1, 2}](@672) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@674 = relu(@673) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@675 = transpose[dims={0, 2, 3, 1}](@674) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@676 = transpose[dims={0, 2, 3, 1}](@230) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@677 = transpose[dims={0, 3, 1, 2}](@676) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@678 = identity(@677) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@679 = transpose[dims={0, 2, 3, 1}](@678) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@680 = transpose[dims={0, 3, 1, 2}](@675) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@681 = transpose[dims={0, 3, 1, 2}](@679) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@682 = transpose[dims={3, 2, 0, 1}](@681) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@683 = transpose[dims={3, 2, 0, 1}](@681) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@684 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@680,@683) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@685 = transpose[dims={0, 2, 3, 1}](@684) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@686 = identity(@229) -> float_type, {128}, {1}\n", + "@687 = transpose[dims={0, 3, 1, 2}](@685) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@688 = broadcast[axis=1,dims={1, 128, 28, 28}](@686) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@689 = add(@687,@688) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@690 = transpose[dims={0, 2, 3, 1}](@689) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@691 = identity(@228) -> float_type, {128}, {1}\n", + "@692 = identity(@227) -> float_type, {128}, {1}\n", + "@693 = identity(@226) -> float_type, {128}, {1}\n", + "@694 = identity(@225) -> float_type, {128}, {1}\n", + "@695 = unknown:FusedBatchNormV3(@690,@691,@692,@693,@694) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@696 = transpose[dims={0, 3, 1, 2}](@695) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@697 = relu(@696) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@698 = transpose[dims={0, 2, 3, 1}](@697) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@699 = transpose[dims={0, 2, 3, 1}](@224) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@700 = transpose[dims={0, 3, 1, 2}](@699) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@701 = identity(@700) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@702 = transpose[dims={0, 2, 3, 1}](@701) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@703 = transpose[dims={0, 3, 1, 2}](@698) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@704 = transpose[dims={0, 3, 1, 2}](@702) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@705 = transpose[dims={3, 2, 0, 1}](@704) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@706 = transpose[dims={3, 2, 0, 1}](@704) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@707 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@703,@706) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@708 = transpose[dims={0, 2, 3, 1}](@707) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@709 = identity(@223) -> float_type, {128}, {1}\n", + "@710 = transpose[dims={0, 3, 1, 2}](@708) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@711 = broadcast[axis=1,dims={1, 128, 28, 28}](@709) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@712 = add(@710,@711) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@713 = transpose[dims={0, 2, 3, 1}](@712) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@714 = identity(@222) -> float_type, {128}, {1}\n", + "@715 = identity(@221) -> float_type, {128}, {1}\n", + "@716 = identity(@220) -> float_type, {128}, {1}\n", + "@717 = identity(@219) -> float_type, {128}, {1}\n", + "@718 = unknown:FusedBatchNormV3(@713,@714,@715,@716,@717) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@719 = transpose[dims={0, 3, 1, 2}](@718) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@720 = relu(@719) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@721 = transpose[dims={0, 2, 3, 1}](@720) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@722 = transpose[dims={0, 2, 3, 1}](@218) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@723 = transpose[dims={0, 3, 1, 2}](@722) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@724 = identity(@723) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@725 = transpose[dims={0, 2, 3, 1}](@724) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@726 = transpose[dims={0, 3, 1, 2}](@721) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@727 = transpose[dims={0, 3, 1, 2}](@725) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@728 = transpose[dims={3, 2, 0, 1}](@727) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@729 = transpose[dims={3, 2, 0, 1}](@727) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@730 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@726,@729) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@731 = transpose[dims={0, 2, 3, 1}](@730) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@732 = identity(@217) -> float_type, {512}, {1}\n", + "@733 = transpose[dims={0, 3, 1, 2}](@731) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@734 = broadcast[axis=1,dims={1, 512, 28, 28}](@732) -> float_type, {1, 512, 28, 28}, {0, 1, 0, 0}\n", + "@735 = add(@733,@734) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@736 = transpose[dims={0, 2, 3, 1}](@735) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@737 = identity(@216) -> float_type, {512}, {1}\n", + "@738 = identity(@215) -> float_type, {512}, {1}\n", + "@739 = identity(@214) -> float_type, {512}, {1}\n", + "@740 = identity(@213) -> float_type, {512}, {1}\n", + "@741 = unknown:FusedBatchNormV3(@736,@737,@738,@739,@740) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@742 = unknown:AddV2(@675,@741) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@743 = transpose[dims={0, 3, 1, 2}](@742) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@744 = relu(@743) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@745 = transpose[dims={0, 2, 3, 1}](@744) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@746 = transpose[dims={0, 2, 3, 1}](@212) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@747 = transpose[dims={0, 3, 1, 2}](@746) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@748 = identity(@747) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@749 = transpose[dims={0, 2, 3, 1}](@748) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@750 = transpose[dims={0, 3, 1, 2}](@745) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@751 = transpose[dims={0, 3, 1, 2}](@749) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@752 = transpose[dims={3, 2, 0, 1}](@751) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@753 = transpose[dims={3, 2, 0, 1}](@751) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@754 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@750,@753) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@755 = transpose[dims={0, 2, 3, 1}](@754) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@756 = identity(@211) -> float_type, {128}, {1}\n", + "@757 = transpose[dims={0, 3, 1, 2}](@755) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@758 = broadcast[axis=1,dims={1, 128, 28, 28}](@756) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@759 = add(@757,@758) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@760 = transpose[dims={0, 2, 3, 1}](@759) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@761 = identity(@210) -> float_type, {128}, {1}\n", + "@762 = identity(@209) -> float_type, {128}, {1}\n", + "@763 = identity(@208) -> float_type, {128}, {1}\n", + "@764 = identity(@207) -> float_type, {128}, {1}\n", + "@765 = unknown:FusedBatchNormV3(@760,@761,@762,@763,@764) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@766 = transpose[dims={0, 3, 1, 2}](@765) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@767 = relu(@766) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@768 = transpose[dims={0, 2, 3, 1}](@767) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@769 = transpose[dims={0, 2, 3, 1}](@206) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@770 = transpose[dims={0, 3, 1, 2}](@769) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@771 = identity(@770) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@772 = transpose[dims={0, 2, 3, 1}](@771) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@773 = transpose[dims={0, 3, 1, 2}](@768) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@774 = transpose[dims={0, 3, 1, 2}](@772) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@775 = transpose[dims={3, 2, 0, 1}](@774) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@776 = transpose[dims={3, 2, 0, 1}](@774) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@777 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@773,@776) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@778 = transpose[dims={0, 2, 3, 1}](@777) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@779 = identity(@205) -> float_type, {128}, {1}\n", + "@780 = transpose[dims={0, 3, 1, 2}](@778) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@781 = broadcast[axis=1,dims={1, 128, 28, 28}](@779) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@782 = add(@780,@781) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@783 = transpose[dims={0, 2, 3, 1}](@782) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@784 = identity(@204) -> float_type, {128}, {1}\n", + "@785 = identity(@203) -> float_type, {128}, {1}\n", + "@786 = identity(@202) -> float_type, {128}, {1}\n", + "@787 = identity(@201) -> float_type, {128}, {1}\n", + "@788 = unknown:FusedBatchNormV3(@783,@784,@785,@786,@787) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@789 = transpose[dims={0, 3, 1, 2}](@788) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@790 = relu(@789) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@791 = transpose[dims={0, 2, 3, 1}](@790) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@792 = transpose[dims={0, 2, 3, 1}](@200) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@793 = transpose[dims={0, 3, 1, 2}](@792) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@794 = identity(@793) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@795 = transpose[dims={0, 2, 3, 1}](@794) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@796 = transpose[dims={0, 3, 1, 2}](@791) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@797 = transpose[dims={0, 3, 1, 2}](@795) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@798 = transpose[dims={3, 2, 0, 1}](@797) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@799 = transpose[dims={3, 2, 0, 1}](@797) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@800 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@796,@799) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@801 = transpose[dims={0, 2, 3, 1}](@800) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@802 = identity(@199) -> float_type, {512}, {1}\n", + "@803 = transpose[dims={0, 3, 1, 2}](@801) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@804 = broadcast[axis=1,dims={1, 512, 28, 28}](@802) -> float_type, {1, 512, 28, 28}, {0, 1, 0, 0}\n", + "@805 = add(@803,@804) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@806 = transpose[dims={0, 2, 3, 1}](@805) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@807 = identity(@198) -> float_type, {512}, {1}\n", + "@808 = identity(@197) -> float_type, {512}, {1}\n", + "@809 = identity(@196) -> float_type, {512}, {1}\n", + "@810 = identity(@195) -> float_type, {512}, {1}\n", + "@811 = unknown:FusedBatchNormV3(@806,@807,@808,@809,@810) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@812 = unknown:AddV2(@745,@811) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@813 = transpose[dims={0, 3, 1, 2}](@812) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@814 = relu(@813) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@815 = transpose[dims={0, 2, 3, 1}](@814) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@816 = transpose[dims={0, 2, 3, 1}](@194) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@817 = transpose[dims={0, 3, 1, 2}](@816) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@818 = identity(@817) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@819 = transpose[dims={0, 2, 3, 1}](@818) -> float_type, {1, 512, 128, 1}, {65536, 128, 1, 65536}\n", + "@820 = transpose[dims={0, 3, 1, 2}](@815) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@821 = transpose[dims={0, 3, 1, 2}](@819) -> float_type, {1, 1, 512, 128}, {65536, 65536, 128, 1}\n", + "@822 = transpose[dims={3, 2, 0, 1}](@821) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@823 = transpose[dims={3, 2, 0, 1}](@821) -> float_type, {128, 512, 1, 1}, {1, 128, 65536, 65536}\n", + "@824 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@820,@823) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@825 = transpose[dims={0, 2, 3, 1}](@824) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@826 = identity(@193) -> float_type, {128}, {1}\n", + "@827 = transpose[dims={0, 3, 1, 2}](@825) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@828 = broadcast[axis=1,dims={1, 128, 28, 28}](@826) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@829 = add(@827,@828) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@830 = transpose[dims={0, 2, 3, 1}](@829) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@831 = identity(@192) -> float_type, {128}, {1}\n", + "@832 = identity(@191) -> float_type, {128}, {1}\n", + "@833 = identity(@190) -> float_type, {128}, {1}\n", + "@834 = identity(@189) -> float_type, {128}, {1}\n", + "@835 = unknown:FusedBatchNormV3(@830,@831,@832,@833,@834) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@836 = transpose[dims={0, 3, 1, 2}](@835) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@837 = relu(@836) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@838 = transpose[dims={0, 2, 3, 1}](@837) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@839 = transpose[dims={0, 2, 3, 1}](@188) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@840 = transpose[dims={0, 3, 1, 2}](@839) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@841 = identity(@840) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@842 = transpose[dims={0, 2, 3, 1}](@841) -> float_type, {3, 128, 128, 3}, {49152, 128, 1, 16384}\n", + "@843 = transpose[dims={0, 3, 1, 2}](@838) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@844 = transpose[dims={0, 3, 1, 2}](@842) -> float_type, {3, 3, 128, 128}, {49152, 16384, 128, 1}\n", + "@845 = transpose[dims={3, 2, 0, 1}](@844) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@846 = transpose[dims={3, 2, 0, 1}](@844) -> float_type, {128, 128, 3, 3}, {1, 128, 49152, 16384}\n", + "@847 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@843,@846) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@848 = transpose[dims={0, 2, 3, 1}](@847) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@849 = identity(@187) -> float_type, {128}, {1}\n", + "@850 = transpose[dims={0, 3, 1, 2}](@848) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@851 = broadcast[axis=1,dims={1, 128, 28, 28}](@849) -> float_type, {1, 128, 28, 28}, {0, 1, 0, 0}\n", + "@852 = add(@850,@851) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@853 = transpose[dims={0, 2, 3, 1}](@852) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@854 = identity(@186) -> float_type, {128}, {1}\n", + "@855 = identity(@185) -> float_type, {128}, {1}\n", + "@856 = identity(@184) -> float_type, {128}, {1}\n", + "@857 = identity(@183) -> float_type, {128}, {1}\n", + "@858 = unknown:FusedBatchNormV3(@853,@854,@855,@856,@857) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@859 = transpose[dims={0, 3, 1, 2}](@858) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@860 = relu(@859) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@861 = transpose[dims={0, 2, 3, 1}](@860) -> float_type, {1, 28, 28, 128}, {100352, 28, 1, 784}\n", + "@862 = transpose[dims={0, 2, 3, 1}](@182) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@863 = transpose[dims={0, 3, 1, 2}](@862) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@864 = identity(@863) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@865 = transpose[dims={0, 2, 3, 1}](@864) -> float_type, {1, 128, 512, 1}, {65536, 512, 1, 65536}\n", + "@866 = transpose[dims={0, 3, 1, 2}](@861) -> float_type, {1, 128, 28, 28}, {100352, 784, 28, 1}\n", + "@867 = transpose[dims={0, 3, 1, 2}](@865) -> float_type, {1, 1, 128, 512}, {65536, 65536, 512, 1}\n", + "@868 = transpose[dims={3, 2, 0, 1}](@867) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@869 = transpose[dims={3, 2, 0, 1}](@867) -> float_type, {512, 128, 1, 1}, {1, 512, 65536, 65536}\n", + "@870 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@866,@869) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@871 = transpose[dims={0, 2, 3, 1}](@870) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@872 = identity(@181) -> float_type, {512}, {1}\n", + "@873 = transpose[dims={0, 3, 1, 2}](@871) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@874 = broadcast[axis=1,dims={1, 512, 28, 28}](@872) -> float_type, {1, 512, 28, 28}, {0, 1, 0, 0}\n", + "@875 = add(@873,@874) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@876 = transpose[dims={0, 2, 3, 1}](@875) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@877 = identity(@180) -> float_type, {512}, {1}\n", + "@878 = identity(@179) -> float_type, {512}, {1}\n", + "@879 = identity(@178) -> float_type, {512}, {1}\n", + "@880 = identity(@177) -> float_type, {512}, {1}\n", + "@881 = unknown:FusedBatchNormV3(@876,@877,@878,@879,@880) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@882 = unknown:AddV2(@815,@881) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@883 = transpose[dims={0, 3, 1, 2}](@882) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@884 = relu(@883) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@885 = transpose[dims={0, 2, 3, 1}](@884) -> float_type, {1, 28, 28, 512}, {401408, 28, 1, 784}\n", + "@886 = transpose[dims={0, 2, 3, 1}](@176) -> float_type, {1, 512, 1024, 1}, {524288, 1024, 1, 524288}\n", + "@887 = transpose[dims={0, 3, 1, 2}](@886) -> float_type, {1, 1, 512, 1024}, {524288, 524288, 1024, 1}\n", + "@888 = identity(@887) -> float_type, {1, 1, 512, 1024}, {524288, 524288, 1024, 1}\n", + "@889 = transpose[dims={0, 2, 3, 1}](@888) -> float_type, {1, 512, 1024, 1}, {524288, 1024, 1, 524288}\n", + "@890 = transpose[dims={0, 3, 1, 2}](@885) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@891 = transpose[dims={0, 3, 1, 2}](@889) -> float_type, {1, 1, 512, 1024}, {524288, 524288, 1024, 1}\n", + "@892 = transpose[dims={3, 2, 0, 1}](@891) -> float_type, {1024, 512, 1, 1}, {1, 1024, 524288, 524288}\n", + "@893 = transpose[dims={3, 2, 0, 1}](@891) -> float_type, {1024, 512, 1, 1}, {1, 1024, 524288, 524288}\n", + "@894 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@890,@893) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@895 = transpose[dims={0, 2, 3, 1}](@894) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@896 = identity(@175) -> float_type, {1024}, {1}\n", + "@897 = transpose[dims={0, 3, 1, 2}](@895) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@898 = broadcast[axis=1,dims={1, 1024, 14, 14}](@896) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@899 = add(@897,@898) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@900 = transpose[dims={0, 2, 3, 1}](@899) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@901 = identity(@174) -> float_type, {1024}, {1}\n", + "@902 = identity(@173) -> float_type, {1024}, {1}\n", + "@903 = identity(@172) -> float_type, {1024}, {1}\n", + "@904 = identity(@171) -> float_type, {1024}, {1}\n", + "@905 = unknown:FusedBatchNormV3(@900,@901,@902,@903,@904) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@906 = transpose[dims={0, 2, 3, 1}](@170) -> float_type, {1, 512, 256, 1}, {131072, 256, 1, 131072}\n", + "@907 = transpose[dims={0, 3, 1, 2}](@906) -> float_type, {1, 1, 512, 256}, {131072, 131072, 256, 1}\n", + "@908 = identity(@907) -> float_type, {1, 1, 512, 256}, {131072, 131072, 256, 1}\n", + "@909 = transpose[dims={0, 2, 3, 1}](@908) -> float_type, {1, 512, 256, 1}, {131072, 256, 1, 131072}\n", + "@910 = transpose[dims={0, 3, 1, 2}](@885) -> float_type, {1, 512, 28, 28}, {401408, 784, 28, 1}\n", + "@911 = transpose[dims={0, 3, 1, 2}](@909) -> float_type, {1, 1, 512, 256}, {131072, 131072, 256, 1}\n", + "@912 = transpose[dims={3, 2, 0, 1}](@911) -> float_type, {256, 512, 1, 1}, {1, 256, 131072, 131072}\n", + "@913 = transpose[dims={3, 2, 0, 1}](@911) -> float_type, {256, 512, 1, 1}, {1, 256, 131072, 131072}\n", + "@914 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@910,@913) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@915 = transpose[dims={0, 2, 3, 1}](@914) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@916 = identity(@169) -> float_type, {256}, {1}\n", + "@917 = transpose[dims={0, 3, 1, 2}](@915) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@918 = broadcast[axis=1,dims={1, 256, 14, 14}](@916) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@919 = add(@917,@918) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@920 = transpose[dims={0, 2, 3, 1}](@919) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@921 = identity(@168) -> float_type, {256}, {1}\n", + "@922 = identity(@167) -> float_type, {256}, {1}\n", + "@923 = identity(@166) -> float_type, {256}, {1}\n", + "@924 = identity(@165) -> float_type, {256}, {1}\n", + "@925 = unknown:FusedBatchNormV3(@920,@921,@922,@923,@924) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@926 = transpose[dims={0, 3, 1, 2}](@925) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@927 = relu(@926) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@928 = transpose[dims={0, 2, 3, 1}](@927) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@929 = transpose[dims={0, 2, 3, 1}](@164) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@930 = transpose[dims={0, 3, 1, 2}](@929) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@931 = identity(@930) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@932 = transpose[dims={0, 2, 3, 1}](@931) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@933 = transpose[dims={0, 3, 1, 2}](@928) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@934 = transpose[dims={0, 3, 1, 2}](@932) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@935 = transpose[dims={3, 2, 0, 1}](@934) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@936 = transpose[dims={3, 2, 0, 1}](@934) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@937 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@933,@936) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@938 = transpose[dims={0, 2, 3, 1}](@937) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@939 = identity(@163) -> float_type, {256}, {1}\n", + "@940 = transpose[dims={0, 3, 1, 2}](@938) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@941 = broadcast[axis=1,dims={1, 256, 14, 14}](@939) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@942 = add(@940,@941) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@943 = transpose[dims={0, 2, 3, 1}](@942) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@944 = identity(@162) -> float_type, {256}, {1}\n", + "@945 = identity(@161) -> float_type, {256}, {1}\n", + "@946 = identity(@160) -> float_type, {256}, {1}\n", + "@947 = identity(@159) -> float_type, {256}, {1}\n", + "@948 = unknown:FusedBatchNormV3(@943,@944,@945,@946,@947) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@949 = transpose[dims={0, 3, 1, 2}](@948) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@950 = relu(@949) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@951 = transpose[dims={0, 2, 3, 1}](@950) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@952 = transpose[dims={0, 2, 3, 1}](@158) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@953 = transpose[dims={0, 3, 1, 2}](@952) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@954 = identity(@953) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@955 = transpose[dims={0, 2, 3, 1}](@954) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@956 = transpose[dims={0, 3, 1, 2}](@951) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@957 = transpose[dims={0, 3, 1, 2}](@955) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@958 = transpose[dims={3, 2, 0, 1}](@957) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@959 = transpose[dims={3, 2, 0, 1}](@957) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@960 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@956,@959) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@961 = transpose[dims={0, 2, 3, 1}](@960) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@962 = identity(@157) -> float_type, {1024}, {1}\n", + "@963 = transpose[dims={0, 3, 1, 2}](@961) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@964 = broadcast[axis=1,dims={1, 1024, 14, 14}](@962) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@965 = add(@963,@964) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@966 = transpose[dims={0, 2, 3, 1}](@965) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@967 = identity(@156) -> float_type, {1024}, {1}\n", + "@968 = identity(@155) -> float_type, {1024}, {1}\n", + "@969 = identity(@154) -> float_type, {1024}, {1}\n", + "@970 = identity(@153) -> float_type, {1024}, {1}\n", + "@971 = unknown:FusedBatchNormV3(@966,@967,@968,@969,@970) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@972 = unknown:AddV2(@905,@971) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@973 = transpose[dims={0, 3, 1, 2}](@972) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@974 = relu(@973) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@975 = transpose[dims={0, 2, 3, 1}](@974) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@976 = transpose[dims={0, 2, 3, 1}](@152) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@977 = transpose[dims={0, 3, 1, 2}](@976) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@978 = identity(@977) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@979 = transpose[dims={0, 2, 3, 1}](@978) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@980 = transpose[dims={0, 3, 1, 2}](@975) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@981 = transpose[dims={0, 3, 1, 2}](@979) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@982 = transpose[dims={3, 2, 0, 1}](@981) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@983 = transpose[dims={3, 2, 0, 1}](@981) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@984 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@980,@983) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@985 = transpose[dims={0, 2, 3, 1}](@984) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@986 = identity(@151) -> float_type, {256}, {1}\n", + "@987 = transpose[dims={0, 3, 1, 2}](@985) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@988 = broadcast[axis=1,dims={1, 256, 14, 14}](@986) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@989 = add(@987,@988) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@990 = transpose[dims={0, 2, 3, 1}](@989) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@991 = identity(@150) -> float_type, {256}, {1}\n", + "@992 = identity(@149) -> float_type, {256}, {1}\n", + "@993 = identity(@148) -> float_type, {256}, {1}\n", + "@994 = identity(@147) -> float_type, {256}, {1}\n", + "@995 = unknown:FusedBatchNormV3(@990,@991,@992,@993,@994) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@996 = transpose[dims={0, 3, 1, 2}](@995) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@997 = relu(@996) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@998 = transpose[dims={0, 2, 3, 1}](@997) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@999 = transpose[dims={0, 2, 3, 1}](@146) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1000 = transpose[dims={0, 3, 1, 2}](@999) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1001 = identity(@1000) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1002 = transpose[dims={0, 2, 3, 1}](@1001) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1003 = transpose[dims={0, 3, 1, 2}](@998) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1004 = transpose[dims={0, 3, 1, 2}](@1002) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1005 = transpose[dims={3, 2, 0, 1}](@1004) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1006 = transpose[dims={3, 2, 0, 1}](@1004) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1007 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1003,@1006) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1008 = transpose[dims={0, 2, 3, 1}](@1007) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1009 = identity(@145) -> float_type, {256}, {1}\n", + "@1010 = transpose[dims={0, 3, 1, 2}](@1008) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1011 = broadcast[axis=1,dims={1, 256, 14, 14}](@1009) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1012 = add(@1010,@1011) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1013 = transpose[dims={0, 2, 3, 1}](@1012) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1014 = identity(@144) -> float_type, {256}, {1}\n", + "@1015 = identity(@143) -> float_type, {256}, {1}\n", + "@1016 = identity(@142) -> float_type, {256}, {1}\n", + "@1017 = identity(@141) -> float_type, {256}, {1}\n", + "@1018 = unknown:FusedBatchNormV3(@1013,@1014,@1015,@1016,@1017) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1019 = transpose[dims={0, 3, 1, 2}](@1018) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1020 = relu(@1019) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1021 = transpose[dims={0, 2, 3, 1}](@1020) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1022 = transpose[dims={0, 2, 3, 1}](@140) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1023 = transpose[dims={0, 3, 1, 2}](@1022) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1024 = identity(@1023) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1025 = transpose[dims={0, 2, 3, 1}](@1024) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1026 = transpose[dims={0, 3, 1, 2}](@1021) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1027 = transpose[dims={0, 3, 1, 2}](@1025) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1028 = transpose[dims={3, 2, 0, 1}](@1027) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1029 = transpose[dims={3, 2, 0, 1}](@1027) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1030 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1026,@1029) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1031 = transpose[dims={0, 2, 3, 1}](@1030) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1032 = identity(@139) -> float_type, {1024}, {1}\n", + "@1033 = transpose[dims={0, 3, 1, 2}](@1031) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1034 = broadcast[axis=1,dims={1, 1024, 14, 14}](@1032) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@1035 = add(@1033,@1034) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1036 = transpose[dims={0, 2, 3, 1}](@1035) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1037 = identity(@138) -> float_type, {1024}, {1}\n", + "@1038 = identity(@137) -> float_type, {1024}, {1}\n", + "@1039 = identity(@136) -> float_type, {1024}, {1}\n", + "@1040 = identity(@135) -> float_type, {1024}, {1}\n", + "@1041 = unknown:FusedBatchNormV3(@1036,@1037,@1038,@1039,@1040) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1042 = unknown:AddV2(@975,@1041) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1043 = transpose[dims={0, 3, 1, 2}](@1042) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1044 = relu(@1043) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1045 = transpose[dims={0, 2, 3, 1}](@1044) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1046 = transpose[dims={0, 2, 3, 1}](@134) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1047 = transpose[dims={0, 3, 1, 2}](@1046) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1048 = identity(@1047) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1049 = transpose[dims={0, 2, 3, 1}](@1048) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1050 = transpose[dims={0, 3, 1, 2}](@1045) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1051 = transpose[dims={0, 3, 1, 2}](@1049) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1052 = transpose[dims={3, 2, 0, 1}](@1051) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1053 = transpose[dims={3, 2, 0, 1}](@1051) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1054 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1050,@1053) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1055 = transpose[dims={0, 2, 3, 1}](@1054) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1056 = identity(@133) -> float_type, {256}, {1}\n", + "@1057 = transpose[dims={0, 3, 1, 2}](@1055) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1058 = broadcast[axis=1,dims={1, 256, 14, 14}](@1056) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1059 = add(@1057,@1058) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1060 = transpose[dims={0, 2, 3, 1}](@1059) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1061 = identity(@132) -> float_type, {256}, {1}\n", + "@1062 = identity(@131) -> float_type, {256}, {1}\n", + "@1063 = identity(@130) -> float_type, {256}, {1}\n", + "@1064 = identity(@129) -> float_type, {256}, {1}\n", + "@1065 = unknown:FusedBatchNormV3(@1060,@1061,@1062,@1063,@1064) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1066 = transpose[dims={0, 3, 1, 2}](@1065) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1067 = relu(@1066) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1068 = transpose[dims={0, 2, 3, 1}](@1067) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1069 = transpose[dims={0, 2, 3, 1}](@128) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1070 = transpose[dims={0, 3, 1, 2}](@1069) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1071 = identity(@1070) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1072 = transpose[dims={0, 2, 3, 1}](@1071) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1073 = transpose[dims={0, 3, 1, 2}](@1068) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1074 = transpose[dims={0, 3, 1, 2}](@1072) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1075 = transpose[dims={3, 2, 0, 1}](@1074) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1076 = transpose[dims={3, 2, 0, 1}](@1074) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1077 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1073,@1076) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1078 = transpose[dims={0, 2, 3, 1}](@1077) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1079 = identity(@127) -> float_type, {256}, {1}\n", + "@1080 = transpose[dims={0, 3, 1, 2}](@1078) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1081 = broadcast[axis=1,dims={1, 256, 14, 14}](@1079) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1082 = add(@1080,@1081) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1083 = transpose[dims={0, 2, 3, 1}](@1082) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1084 = identity(@126) -> float_type, {256}, {1}\n", + "@1085 = identity(@125) -> float_type, {256}, {1}\n", + "@1086 = identity(@124) -> float_type, {256}, {1}\n", + "@1087 = identity(@123) -> float_type, {256}, {1}\n", + "@1088 = unknown:FusedBatchNormV3(@1083,@1084,@1085,@1086,@1087) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1089 = transpose[dims={0, 3, 1, 2}](@1088) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1090 = relu(@1089) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1091 = transpose[dims={0, 2, 3, 1}](@1090) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1092 = transpose[dims={0, 2, 3, 1}](@122) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1093 = transpose[dims={0, 3, 1, 2}](@1092) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1094 = identity(@1093) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1095 = transpose[dims={0, 2, 3, 1}](@1094) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1096 = transpose[dims={0, 3, 1, 2}](@1091) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1097 = transpose[dims={0, 3, 1, 2}](@1095) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1098 = transpose[dims={3, 2, 0, 1}](@1097) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1099 = transpose[dims={3, 2, 0, 1}](@1097) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1100 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1096,@1099) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1101 = transpose[dims={0, 2, 3, 1}](@1100) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1102 = identity(@121) -> float_type, {1024}, {1}\n", + "@1103 = transpose[dims={0, 3, 1, 2}](@1101) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1104 = broadcast[axis=1,dims={1, 1024, 14, 14}](@1102) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@1105 = add(@1103,@1104) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1106 = transpose[dims={0, 2, 3, 1}](@1105) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1107 = identity(@120) -> float_type, {1024}, {1}\n", + "@1108 = identity(@119) -> float_type, {1024}, {1}\n", + "@1109 = identity(@118) -> float_type, {1024}, {1}\n", + "@1110 = identity(@117) -> float_type, {1024}, {1}\n", + "@1111 = unknown:FusedBatchNormV3(@1106,@1107,@1108,@1109,@1110) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1112 = unknown:AddV2(@1045,@1111) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1113 = transpose[dims={0, 3, 1, 2}](@1112) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1114 = relu(@1113) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1115 = transpose[dims={0, 2, 3, 1}](@1114) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1116 = transpose[dims={0, 2, 3, 1}](@116) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1117 = transpose[dims={0, 3, 1, 2}](@1116) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1118 = identity(@1117) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1119 = transpose[dims={0, 2, 3, 1}](@1118) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1120 = transpose[dims={0, 3, 1, 2}](@1115) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1121 = transpose[dims={0, 3, 1, 2}](@1119) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1122 = transpose[dims={3, 2, 0, 1}](@1121) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1123 = transpose[dims={3, 2, 0, 1}](@1121) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1124 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1120,@1123) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1125 = transpose[dims={0, 2, 3, 1}](@1124) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1126 = identity(@115) -> float_type, {256}, {1}\n", + "@1127 = transpose[dims={0, 3, 1, 2}](@1125) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1128 = broadcast[axis=1,dims={1, 256, 14, 14}](@1126) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1129 = add(@1127,@1128) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1130 = transpose[dims={0, 2, 3, 1}](@1129) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1131 = identity(@114) -> float_type, {256}, {1}\n", + "@1132 = identity(@113) -> float_type, {256}, {1}\n", + "@1133 = identity(@112) -> float_type, {256}, {1}\n", + "@1134 = identity(@111) -> float_type, {256}, {1}\n", + "@1135 = unknown:FusedBatchNormV3(@1130,@1131,@1132,@1133,@1134) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1136 = transpose[dims={0, 3, 1, 2}](@1135) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1137 = relu(@1136) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1138 = transpose[dims={0, 2, 3, 1}](@1137) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1139 = transpose[dims={0, 2, 3, 1}](@110) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1140 = transpose[dims={0, 3, 1, 2}](@1139) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1141 = identity(@1140) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1142 = transpose[dims={0, 2, 3, 1}](@1141) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1143 = transpose[dims={0, 3, 1, 2}](@1138) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1144 = transpose[dims={0, 3, 1, 2}](@1142) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1145 = transpose[dims={3, 2, 0, 1}](@1144) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1146 = transpose[dims={3, 2, 0, 1}](@1144) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1147 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1143,@1146) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1148 = transpose[dims={0, 2, 3, 1}](@1147) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1149 = identity(@109) -> float_type, {256}, {1}\n", + "@1150 = transpose[dims={0, 3, 1, 2}](@1148) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1151 = broadcast[axis=1,dims={1, 256, 14, 14}](@1149) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1152 = add(@1150,@1151) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1153 = transpose[dims={0, 2, 3, 1}](@1152) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1154 = identity(@108) -> float_type, {256}, {1}\n", + "@1155 = identity(@107) -> float_type, {256}, {1}\n", + "@1156 = identity(@106) -> float_type, {256}, {1}\n", + "@1157 = identity(@105) -> float_type, {256}, {1}\n", + "@1158 = unknown:FusedBatchNormV3(@1153,@1154,@1155,@1156,@1157) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1159 = transpose[dims={0, 3, 1, 2}](@1158) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1160 = relu(@1159) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1161 = transpose[dims={0, 2, 3, 1}](@1160) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1162 = transpose[dims={0, 2, 3, 1}](@104) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1163 = transpose[dims={0, 3, 1, 2}](@1162) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1164 = identity(@1163) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1165 = transpose[dims={0, 2, 3, 1}](@1164) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1166 = transpose[dims={0, 3, 1, 2}](@1161) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1167 = transpose[dims={0, 3, 1, 2}](@1165) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1168 = transpose[dims={3, 2, 0, 1}](@1167) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1169 = transpose[dims={3, 2, 0, 1}](@1167) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1170 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1166,@1169) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1171 = transpose[dims={0, 2, 3, 1}](@1170) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1172 = identity(@103) -> float_type, {1024}, {1}\n", + "@1173 = transpose[dims={0, 3, 1, 2}](@1171) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1174 = broadcast[axis=1,dims={1, 1024, 14, 14}](@1172) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@1175 = add(@1173,@1174) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1176 = transpose[dims={0, 2, 3, 1}](@1175) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1177 = identity(@102) -> float_type, {1024}, {1}\n", + "@1178 = identity(@101) -> float_type, {1024}, {1}\n", + "@1179 = identity(@100) -> float_type, {1024}, {1}\n", + "@1180 = identity(@99) -> float_type, {1024}, {1}\n", + "@1181 = unknown:FusedBatchNormV3(@1176,@1177,@1178,@1179,@1180) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1182 = unknown:AddV2(@1115,@1181) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1183 = transpose[dims={0, 3, 1, 2}](@1182) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1184 = relu(@1183) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1185 = transpose[dims={0, 2, 3, 1}](@1184) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1186 = transpose[dims={0, 2, 3, 1}](@98) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1187 = transpose[dims={0, 3, 1, 2}](@1186) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1188 = identity(@1187) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1189 = transpose[dims={0, 2, 3, 1}](@1188) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1190 = transpose[dims={0, 3, 1, 2}](@1185) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1191 = transpose[dims={0, 3, 1, 2}](@1189) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1192 = transpose[dims={3, 2, 0, 1}](@1191) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1193 = transpose[dims={3, 2, 0, 1}](@1191) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1194 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1190,@1193) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1195 = transpose[dims={0, 2, 3, 1}](@1194) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1196 = identity(@97) -> float_type, {256}, {1}\n", + "@1197 = transpose[dims={0, 3, 1, 2}](@1195) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1198 = broadcast[axis=1,dims={1, 256, 14, 14}](@1196) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1199 = add(@1197,@1198) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1200 = transpose[dims={0, 2, 3, 1}](@1199) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1201 = identity(@96) -> float_type, {256}, {1}\n", + "@1202 = identity(@95) -> float_type, {256}, {1}\n", + "@1203 = identity(@94) -> float_type, {256}, {1}\n", + "@1204 = identity(@93) -> float_type, {256}, {1}\n", + "@1205 = unknown:FusedBatchNormV3(@1200,@1201,@1202,@1203,@1204) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1206 = transpose[dims={0, 3, 1, 2}](@1205) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1207 = relu(@1206) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1208 = transpose[dims={0, 2, 3, 1}](@1207) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1209 = transpose[dims={0, 2, 3, 1}](@92) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1210 = transpose[dims={0, 3, 1, 2}](@1209) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1211 = identity(@1210) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1212 = transpose[dims={0, 2, 3, 1}](@1211) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1213 = transpose[dims={0, 3, 1, 2}](@1208) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1214 = transpose[dims={0, 3, 1, 2}](@1212) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1215 = transpose[dims={3, 2, 0, 1}](@1214) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1216 = transpose[dims={3, 2, 0, 1}](@1214) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1217 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1213,@1216) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1218 = transpose[dims={0, 2, 3, 1}](@1217) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1219 = identity(@91) -> float_type, {256}, {1}\n", + "@1220 = transpose[dims={0, 3, 1, 2}](@1218) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1221 = broadcast[axis=1,dims={1, 256, 14, 14}](@1219) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1222 = add(@1220,@1221) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1223 = transpose[dims={0, 2, 3, 1}](@1222) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1224 = identity(@90) -> float_type, {256}, {1}\n", + "@1225 = identity(@89) -> float_type, {256}, {1}\n", + "@1226 = identity(@88) -> float_type, {256}, {1}\n", + "@1227 = identity(@87) -> float_type, {256}, {1}\n", + "@1228 = unknown:FusedBatchNormV3(@1223,@1224,@1225,@1226,@1227) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1229 = transpose[dims={0, 3, 1, 2}](@1228) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1230 = relu(@1229) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1231 = transpose[dims={0, 2, 3, 1}](@1230) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1232 = transpose[dims={0, 2, 3, 1}](@86) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1233 = transpose[dims={0, 3, 1, 2}](@1232) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1234 = identity(@1233) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1235 = transpose[dims={0, 2, 3, 1}](@1234) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1236 = transpose[dims={0, 3, 1, 2}](@1231) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1237 = transpose[dims={0, 3, 1, 2}](@1235) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1238 = transpose[dims={3, 2, 0, 1}](@1237) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1239 = transpose[dims={3, 2, 0, 1}](@1237) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1240 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1236,@1239) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1241 = transpose[dims={0, 2, 3, 1}](@1240) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1242 = identity(@85) -> float_type, {1024}, {1}\n", + "@1243 = transpose[dims={0, 3, 1, 2}](@1241) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1244 = broadcast[axis=1,dims={1, 1024, 14, 14}](@1242) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@1245 = add(@1243,@1244) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1246 = transpose[dims={0, 2, 3, 1}](@1245) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1247 = identity(@84) -> float_type, {1024}, {1}\n", + "@1248 = identity(@83) -> float_type, {1024}, {1}\n", + "@1249 = identity(@82) -> float_type, {1024}, {1}\n", + "@1250 = identity(@81) -> float_type, {1024}, {1}\n", + "@1251 = unknown:FusedBatchNormV3(@1246,@1247,@1248,@1249,@1250) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1252 = unknown:AddV2(@1185,@1251) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1253 = transpose[dims={0, 3, 1, 2}](@1252) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1254 = relu(@1253) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1255 = transpose[dims={0, 2, 3, 1}](@1254) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1256 = transpose[dims={0, 2, 3, 1}](@80) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1257 = transpose[dims={0, 3, 1, 2}](@1256) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1258 = identity(@1257) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1259 = transpose[dims={0, 2, 3, 1}](@1258) -> float_type, {1, 1024, 256, 1}, {262144, 256, 1, 262144}\n", + "@1260 = transpose[dims={0, 3, 1, 2}](@1255) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1261 = transpose[dims={0, 3, 1, 2}](@1259) -> float_type, {1, 1, 1024, 256}, {262144, 262144, 256, 1}\n", + "@1262 = transpose[dims={3, 2, 0, 1}](@1261) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1263 = transpose[dims={3, 2, 0, 1}](@1261) -> float_type, {256, 1024, 1, 1}, {1, 256, 262144, 262144}\n", + "@1264 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1260,@1263) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1265 = transpose[dims={0, 2, 3, 1}](@1264) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1266 = identity(@79) -> float_type, {256}, {1}\n", + "@1267 = transpose[dims={0, 3, 1, 2}](@1265) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1268 = broadcast[axis=1,dims={1, 256, 14, 14}](@1266) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1269 = add(@1267,@1268) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1270 = transpose[dims={0, 2, 3, 1}](@1269) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1271 = identity(@78) -> float_type, {256}, {1}\n", + "@1272 = identity(@77) -> float_type, {256}, {1}\n", + "@1273 = identity(@76) -> float_type, {256}, {1}\n", + "@1274 = identity(@75) -> float_type, {256}, {1}\n", + "@1275 = unknown:FusedBatchNormV3(@1270,@1271,@1272,@1273,@1274) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1276 = transpose[dims={0, 3, 1, 2}](@1275) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1277 = relu(@1276) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1278 = transpose[dims={0, 2, 3, 1}](@1277) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1279 = transpose[dims={0, 2, 3, 1}](@74) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1280 = transpose[dims={0, 3, 1, 2}](@1279) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1281 = identity(@1280) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1282 = transpose[dims={0, 2, 3, 1}](@1281) -> float_type, {3, 256, 256, 3}, {196608, 256, 1, 65536}\n", + "@1283 = transpose[dims={0, 3, 1, 2}](@1278) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1284 = transpose[dims={0, 3, 1, 2}](@1282) -> float_type, {3, 3, 256, 256}, {196608, 65536, 256, 1}\n", + "@1285 = transpose[dims={3, 2, 0, 1}](@1284) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1286 = transpose[dims={3, 2, 0, 1}](@1284) -> float_type, {256, 256, 3, 3}, {1, 256, 196608, 65536}\n", + "@1287 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1283,@1286) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1288 = transpose[dims={0, 2, 3, 1}](@1287) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1289 = identity(@73) -> float_type, {256}, {1}\n", + "@1290 = transpose[dims={0, 3, 1, 2}](@1288) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1291 = broadcast[axis=1,dims={1, 256, 14, 14}](@1289) -> float_type, {1, 256, 14, 14}, {0, 1, 0, 0}\n", + "@1292 = add(@1290,@1291) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1293 = transpose[dims={0, 2, 3, 1}](@1292) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1294 = identity(@72) -> float_type, {256}, {1}\n", + "@1295 = identity(@71) -> float_type, {256}, {1}\n", + "@1296 = identity(@70) -> float_type, {256}, {1}\n", + "@1297 = identity(@69) -> float_type, {256}, {1}\n", + "@1298 = unknown:FusedBatchNormV3(@1293,@1294,@1295,@1296,@1297) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1299 = transpose[dims={0, 3, 1, 2}](@1298) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1300 = relu(@1299) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1301 = transpose[dims={0, 2, 3, 1}](@1300) -> float_type, {1, 14, 14, 256}, {50176, 14, 1, 196}\n", + "@1302 = transpose[dims={0, 2, 3, 1}](@68) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1303 = transpose[dims={0, 3, 1, 2}](@1302) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1304 = identity(@1303) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1305 = transpose[dims={0, 2, 3, 1}](@1304) -> float_type, {1, 256, 1024, 1}, {262144, 1024, 1, 262144}\n", + "@1306 = transpose[dims={0, 3, 1, 2}](@1301) -> float_type, {1, 256, 14, 14}, {50176, 196, 14, 1}\n", + "@1307 = transpose[dims={0, 3, 1, 2}](@1305) -> float_type, {1, 1, 256, 1024}, {262144, 262144, 1024, 1}\n", + "@1308 = transpose[dims={3, 2, 0, 1}](@1307) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1309 = transpose[dims={3, 2, 0, 1}](@1307) -> float_type, {1024, 256, 1, 1}, {1, 1024, 262144, 262144}\n", + "@1310 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1306,@1309) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1311 = transpose[dims={0, 2, 3, 1}](@1310) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1312 = identity(@67) -> float_type, {1024}, {1}\n", + "@1313 = transpose[dims={0, 3, 1, 2}](@1311) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1314 = broadcast[axis=1,dims={1, 1024, 14, 14}](@1312) -> float_type, {1, 1024, 14, 14}, {0, 1, 0, 0}\n", + "@1315 = add(@1313,@1314) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1316 = transpose[dims={0, 2, 3, 1}](@1315) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1317 = identity(@66) -> float_type, {1024}, {1}\n", + "@1318 = identity(@65) -> float_type, {1024}, {1}\n", + "@1319 = identity(@64) -> float_type, {1024}, {1}\n", + "@1320 = identity(@63) -> float_type, {1024}, {1}\n", + "@1321 = unknown:FusedBatchNormV3(@1316,@1317,@1318,@1319,@1320) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1322 = unknown:AddV2(@1255,@1321) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1323 = transpose[dims={0, 3, 1, 2}](@1322) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1324 = relu(@1323) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1325 = transpose[dims={0, 2, 3, 1}](@1324) -> float_type, {1, 14, 14, 1024}, {200704, 14, 1, 196}\n", + "@1326 = transpose[dims={0, 2, 3, 1}](@62) -> float_type, {1, 1024, 2048, 1}, {2097152, 2048, 1, 2097152}\n", + "@1327 = transpose[dims={0, 3, 1, 2}](@1326) -> float_type, {1, 1, 1024, 2048}, {2097152, 2097152, 2048, 1}\n", + "@1328 = identity(@1327) -> float_type, {1, 1, 1024, 2048}, {2097152, 2097152, 2048, 1}\n", + "@1329 = transpose[dims={0, 2, 3, 1}](@1328) -> float_type, {1, 1024, 2048, 1}, {2097152, 2048, 1, 2097152}\n", + "@1330 = transpose[dims={0, 3, 1, 2}](@1325) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1331 = transpose[dims={0, 3, 1, 2}](@1329) -> float_type, {1, 1, 1024, 2048}, {2097152, 2097152, 2048, 1}\n", + "@1332 = transpose[dims={3, 2, 0, 1}](@1331) -> float_type, {2048, 1024, 1, 1}, {1, 2048, 2097152, 2097152}\n", + "@1333 = transpose[dims={3, 2, 0, 1}](@1331) -> float_type, {2048, 1024, 1, 1}, {1, 2048, 2097152, 2097152}\n", + "@1334 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@1330,@1333) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1335 = transpose[dims={0, 2, 3, 1}](@1334) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1336 = identity(@61) -> float_type, {2048}, {1}\n", + "@1337 = transpose[dims={0, 3, 1, 2}](@1335) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1338 = broadcast[axis=1,dims={1, 2048, 7, 7}](@1336) -> float_type, {1, 2048, 7, 7}, {0, 1, 0, 0}\n", + "@1339 = add(@1337,@1338) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1340 = transpose[dims={0, 2, 3, 1}](@1339) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1341 = identity(@60) -> float_type, {2048}, {1}\n", + "@1342 = identity(@59) -> float_type, {2048}, {1}\n", + "@1343 = identity(@58) -> float_type, {2048}, {1}\n", + "@1344 = identity(@57) -> float_type, {2048}, {1}\n", + "@1345 = unknown:FusedBatchNormV3(@1340,@1341,@1342,@1343,@1344) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1346 = transpose[dims={0, 2, 3, 1}](@56) -> float_type, {1, 1024, 512, 1}, {524288, 512, 1, 524288}\n", + "@1347 = transpose[dims={0, 3, 1, 2}](@1346) -> float_type, {1, 1, 1024, 512}, {524288, 524288, 512, 1}\n", + "@1348 = identity(@1347) -> float_type, {1, 1, 1024, 512}, {524288, 524288, 512, 1}\n", + "@1349 = transpose[dims={0, 2, 3, 1}](@1348) -> float_type, {1, 1024, 512, 1}, {524288, 512, 1, 524288}\n", + "@1350 = transpose[dims={0, 3, 1, 2}](@1325) -> float_type, {1, 1024, 14, 14}, {200704, 196, 14, 1}\n", + "@1351 = transpose[dims={0, 3, 1, 2}](@1349) -> float_type, {1, 1, 1024, 512}, {524288, 524288, 512, 1}\n", + "@1352 = transpose[dims={3, 2, 0, 1}](@1351) -> float_type, {512, 1024, 1, 1}, {1, 512, 524288, 524288}\n", + "@1353 = transpose[dims={3, 2, 0, 1}](@1351) -> float_type, {512, 1024, 1, 1}, {1, 512, 524288, 524288}\n", + "@1354 = convolution[padding={0, 0},stride={2, 2},dilation={1, 1},group=1,padding_mode=2](@1350,@1353) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1355 = transpose[dims={0, 2, 3, 1}](@1354) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1356 = identity(@55) -> float_type, {512}, {1}\n", + "@1357 = transpose[dims={0, 3, 1, 2}](@1355) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1358 = broadcast[axis=1,dims={1, 512, 7, 7}](@1356) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1359 = add(@1357,@1358) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1360 = transpose[dims={0, 2, 3, 1}](@1359) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1361 = identity(@54) -> float_type, {512}, {1}\n", + "@1362 = identity(@53) -> float_type, {512}, {1}\n", + "@1363 = identity(@52) -> float_type, {512}, {1}\n", + "@1364 = identity(@51) -> float_type, {512}, {1}\n", + "@1365 = unknown:FusedBatchNormV3(@1360,@1361,@1362,@1363,@1364) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1366 = transpose[dims={0, 3, 1, 2}](@1365) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1367 = relu(@1366) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1368 = transpose[dims={0, 2, 3, 1}](@1367) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1369 = transpose[dims={0, 2, 3, 1}](@50) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1370 = transpose[dims={0, 3, 1, 2}](@1369) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1371 = identity(@1370) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1372 = transpose[dims={0, 2, 3, 1}](@1371) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1373 = transpose[dims={0, 3, 1, 2}](@1368) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1374 = transpose[dims={0, 3, 1, 2}](@1372) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1375 = transpose[dims={3, 2, 0, 1}](@1374) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1376 = transpose[dims={3, 2, 0, 1}](@1374) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1377 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1373,@1376) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1378 = transpose[dims={0, 2, 3, 1}](@1377) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1379 = identity(@49) -> float_type, {512}, {1}\n", + "@1380 = transpose[dims={0, 3, 1, 2}](@1378) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1381 = broadcast[axis=1,dims={1, 512, 7, 7}](@1379) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1382 = add(@1380,@1381) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1383 = transpose[dims={0, 2, 3, 1}](@1382) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1384 = identity(@48) -> float_type, {512}, {1}\n", + "@1385 = identity(@47) -> float_type, {512}, {1}\n", + "@1386 = identity(@46) -> float_type, {512}, {1}\n", + "@1387 = identity(@45) -> float_type, {512}, {1}\n", + "@1388 = unknown:FusedBatchNormV3(@1383,@1384,@1385,@1386,@1387) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1389 = transpose[dims={0, 3, 1, 2}](@1388) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1390 = relu(@1389) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1391 = transpose[dims={0, 2, 3, 1}](@1390) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1392 = transpose[dims={0, 2, 3, 1}](@44) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1393 = transpose[dims={0, 3, 1, 2}](@1392) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1394 = identity(@1393) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1395 = transpose[dims={0, 2, 3, 1}](@1394) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1396 = transpose[dims={0, 3, 1, 2}](@1391) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1397 = transpose[dims={0, 3, 1, 2}](@1395) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1398 = transpose[dims={3, 2, 0, 1}](@1397) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1399 = transpose[dims={3, 2, 0, 1}](@1397) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1400 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1396,@1399) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1401 = transpose[dims={0, 2, 3, 1}](@1400) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1402 = identity(@43) -> float_type, {2048}, {1}\n", + "@1403 = transpose[dims={0, 3, 1, 2}](@1401) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1404 = broadcast[axis=1,dims={1, 2048, 7, 7}](@1402) -> float_type, {1, 2048, 7, 7}, {0, 1, 0, 0}\n", + "@1405 = add(@1403,@1404) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1406 = transpose[dims={0, 2, 3, 1}](@1405) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1407 = identity(@42) -> float_type, {2048}, {1}\n", + "@1408 = identity(@41) -> float_type, {2048}, {1}\n", + "@1409 = identity(@40) -> float_type, {2048}, {1}\n", + "@1410 = identity(@39) -> float_type, {2048}, {1}\n", + "@1411 = unknown:FusedBatchNormV3(@1406,@1407,@1408,@1409,@1410) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1412 = unknown:AddV2(@1345,@1411) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1413 = transpose[dims={0, 3, 1, 2}](@1412) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1414 = relu(@1413) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1415 = transpose[dims={0, 2, 3, 1}](@1414) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1416 = transpose[dims={0, 2, 3, 1}](@38) -> float_type, {1, 2048, 512, 1}, {1048576, 512, 1, 1048576}\n", + "@1417 = transpose[dims={0, 3, 1, 2}](@1416) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1418 = identity(@1417) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1419 = transpose[dims={0, 2, 3, 1}](@1418) -> float_type, {1, 2048, 512, 1}, {1048576, 512, 1, 1048576}\n", + "@1420 = transpose[dims={0, 3, 1, 2}](@1415) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1421 = transpose[dims={0, 3, 1, 2}](@1419) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1422 = transpose[dims={3, 2, 0, 1}](@1421) -> float_type, {512, 2048, 1, 1}, {1, 512, 1048576, 1048576}\n", + "@1423 = transpose[dims={3, 2, 0, 1}](@1421) -> float_type, {512, 2048, 1, 1}, {1, 512, 1048576, 1048576}\n", + "@1424 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1420,@1423) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1425 = transpose[dims={0, 2, 3, 1}](@1424) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1426 = identity(@37) -> float_type, {512}, {1}\n", + "@1427 = transpose[dims={0, 3, 1, 2}](@1425) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1428 = broadcast[axis=1,dims={1, 512, 7, 7}](@1426) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1429 = add(@1427,@1428) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1430 = transpose[dims={0, 2, 3, 1}](@1429) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1431 = identity(@36) -> float_type, {512}, {1}\n", + "@1432 = identity(@35) -> float_type, {512}, {1}\n", + "@1433 = identity(@34) -> float_type, {512}, {1}\n", + "@1434 = identity(@33) -> float_type, {512}, {1}\n", + "@1435 = unknown:FusedBatchNormV3(@1430,@1431,@1432,@1433,@1434) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1436 = transpose[dims={0, 3, 1, 2}](@1435) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1437 = relu(@1436) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1438 = transpose[dims={0, 2, 3, 1}](@1437) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1439 = transpose[dims={0, 2, 3, 1}](@32) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1440 = transpose[dims={0, 3, 1, 2}](@1439) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1441 = identity(@1440) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1442 = transpose[dims={0, 2, 3, 1}](@1441) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1443 = transpose[dims={0, 3, 1, 2}](@1438) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1444 = transpose[dims={0, 3, 1, 2}](@1442) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1445 = transpose[dims={3, 2, 0, 1}](@1444) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1446 = transpose[dims={3, 2, 0, 1}](@1444) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1447 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1443,@1446) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1448 = transpose[dims={0, 2, 3, 1}](@1447) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1449 = identity(@31) -> float_type, {512}, {1}\n", + "@1450 = transpose[dims={0, 3, 1, 2}](@1448) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1451 = broadcast[axis=1,dims={1, 512, 7, 7}](@1449) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1452 = add(@1450,@1451) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1453 = transpose[dims={0, 2, 3, 1}](@1452) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1454 = identity(@30) -> float_type, {512}, {1}\n", + "@1455 = identity(@29) -> float_type, {512}, {1}\n", + "@1456 = identity(@28) -> float_type, {512}, {1}\n", + "@1457 = identity(@27) -> float_type, {512}, {1}\n", + "@1458 = unknown:FusedBatchNormV3(@1453,@1454,@1455,@1456,@1457) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1459 = transpose[dims={0, 3, 1, 2}](@1458) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1460 = relu(@1459) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1461 = transpose[dims={0, 2, 3, 1}](@1460) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1462 = transpose[dims={0, 2, 3, 1}](@26) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1463 = transpose[dims={0, 3, 1, 2}](@1462) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1464 = identity(@1463) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1465 = transpose[dims={0, 2, 3, 1}](@1464) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1466 = transpose[dims={0, 3, 1, 2}](@1461) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1467 = transpose[dims={0, 3, 1, 2}](@1465) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1468 = transpose[dims={3, 2, 0, 1}](@1467) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1469 = transpose[dims={3, 2, 0, 1}](@1467) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1470 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1466,@1469) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1471 = transpose[dims={0, 2, 3, 1}](@1470) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1472 = identity(@25) -> float_type, {2048}, {1}\n", + "@1473 = transpose[dims={0, 3, 1, 2}](@1471) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1474 = broadcast[axis=1,dims={1, 2048, 7, 7}](@1472) -> float_type, {1, 2048, 7, 7}, {0, 1, 0, 0}\n", + "@1475 = add(@1473,@1474) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1476 = transpose[dims={0, 2, 3, 1}](@1475) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1477 = identity(@24) -> float_type, {2048}, {1}\n", + "@1478 = identity(@23) -> float_type, {2048}, {1}\n", + "@1479 = identity(@22) -> float_type, {2048}, {1}\n", + "@1480 = identity(@21) -> float_type, {2048}, {1}\n", + "@1481 = unknown:FusedBatchNormV3(@1476,@1477,@1478,@1479,@1480) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1482 = unknown:AddV2(@1415,@1481) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1483 = transpose[dims={0, 3, 1, 2}](@1482) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1484 = relu(@1483) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1485 = transpose[dims={0, 2, 3, 1}](@1484) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1486 = transpose[dims={0, 2, 3, 1}](@20) -> float_type, {1, 2048, 512, 1}, {1048576, 512, 1, 1048576}\n", + "@1487 = transpose[dims={0, 3, 1, 2}](@1486) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1488 = identity(@1487) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1489 = transpose[dims={0, 2, 3, 1}](@1488) -> float_type, {1, 2048, 512, 1}, {1048576, 512, 1, 1048576}\n", + "@1490 = transpose[dims={0, 3, 1, 2}](@1485) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1491 = transpose[dims={0, 3, 1, 2}](@1489) -> float_type, {1, 1, 2048, 512}, {1048576, 1048576, 512, 1}\n", + "@1492 = transpose[dims={3, 2, 0, 1}](@1491) -> float_type, {512, 2048, 1, 1}, {1, 512, 1048576, 1048576}\n", + "@1493 = transpose[dims={3, 2, 0, 1}](@1491) -> float_type, {512, 2048, 1, 1}, {1, 512, 1048576, 1048576}\n", + "@1494 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1490,@1493) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1495 = transpose[dims={0, 2, 3, 1}](@1494) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1496 = identity(@19) -> float_type, {512}, {1}\n", + "@1497 = transpose[dims={0, 3, 1, 2}](@1495) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1498 = broadcast[axis=1,dims={1, 512, 7, 7}](@1496) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1499 = add(@1497,@1498) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1500 = transpose[dims={0, 2, 3, 1}](@1499) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1501 = identity(@18) -> float_type, {512}, {1}\n", + "@1502 = identity(@17) -> float_type, {512}, {1}\n", + "@1503 = identity(@16) -> float_type, {512}, {1}\n", + "@1504 = identity(@15) -> float_type, {512}, {1}\n", + "@1505 = unknown:FusedBatchNormV3(@1500,@1501,@1502,@1503,@1504) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1506 = transpose[dims={0, 3, 1, 2}](@1505) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1507 = relu(@1506) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1508 = transpose[dims={0, 2, 3, 1}](@1507) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1509 = transpose[dims={0, 2, 3, 1}](@14) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1510 = transpose[dims={0, 3, 1, 2}](@1509) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1511 = identity(@1510) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1512 = transpose[dims={0, 2, 3, 1}](@1511) -> float_type, {3, 512, 512, 3}, {786432, 512, 1, 262144}\n", + "@1513 = transpose[dims={0, 3, 1, 2}](@1508) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1514 = transpose[dims={0, 3, 1, 2}](@1512) -> float_type, {3, 3, 512, 512}, {786432, 262144, 512, 1}\n", + "@1515 = transpose[dims={3, 2, 0, 1}](@1514) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1516 = transpose[dims={3, 2, 0, 1}](@1514) -> float_type, {512, 512, 3, 3}, {1, 512, 786432, 262144}\n", + "@1517 = convolution[padding={1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=1](@1513,@1516) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1518 = transpose[dims={0, 2, 3, 1}](@1517) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1519 = identity(@13) -> float_type, {512}, {1}\n", + "@1520 = transpose[dims={0, 3, 1, 2}](@1518) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1521 = broadcast[axis=1,dims={1, 512, 7, 7}](@1519) -> float_type, {1, 512, 7, 7}, {0, 1, 0, 0}\n", + "@1522 = add(@1520,@1521) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1523 = transpose[dims={0, 2, 3, 1}](@1522) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1524 = identity(@12) -> float_type, {512}, {1}\n", + "@1525 = identity(@11) -> float_type, {512}, {1}\n", + "@1526 = identity(@10) -> float_type, {512}, {1}\n", + "@1527 = identity(@9) -> float_type, {512}, {1}\n", + "@1528 = unknown:FusedBatchNormV3(@1523,@1524,@1525,@1526,@1527) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1529 = transpose[dims={0, 3, 1, 2}](@1528) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1530 = relu(@1529) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1531 = transpose[dims={0, 2, 3, 1}](@1530) -> float_type, {1, 7, 7, 512}, {25088, 7, 1, 49}\n", + "@1532 = transpose[dims={0, 2, 3, 1}](@8) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1533 = transpose[dims={0, 3, 1, 2}](@1532) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1534 = identity(@1533) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1535 = transpose[dims={0, 2, 3, 1}](@1534) -> float_type, {1, 512, 2048, 1}, {1048576, 2048, 1, 1048576}\n", + "@1536 = transpose[dims={0, 3, 1, 2}](@1531) -> float_type, {1, 512, 7, 7}, {25088, 49, 7, 1}\n", + "@1537 = transpose[dims={0, 3, 1, 2}](@1535) -> float_type, {1, 1, 512, 2048}, {1048576, 1048576, 2048, 1}\n", + "@1538 = transpose[dims={3, 2, 0, 1}](@1537) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1539 = transpose[dims={3, 2, 0, 1}](@1537) -> float_type, {2048, 512, 1, 1}, {1, 2048, 1048576, 1048576}\n", + "@1540 = convolution[padding={0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=2](@1536,@1539) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1541 = transpose[dims={0, 2, 3, 1}](@1540) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1542 = identity(@7) -> float_type, {2048}, {1}\n", + "@1543 = transpose[dims={0, 3, 1, 2}](@1541) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1544 = broadcast[axis=1,dims={1, 2048, 7, 7}](@1542) -> float_type, {1, 2048, 7, 7}, {0, 1, 0, 0}\n", + "@1545 = add(@1543,@1544) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1546 = transpose[dims={0, 2, 3, 1}](@1545) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1547 = identity(@6) -> float_type, {2048}, {1}\n", + "@1548 = identity(@5) -> float_type, {2048}, {1}\n", + "@1549 = identity(@4) -> float_type, {2048}, {1}\n", + "@1550 = identity(@3) -> float_type, {2048}, {1}\n", + "@1551 = unknown:FusedBatchNormV3(@1546,@1547,@1548,@1549,@1550) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1552 = unknown:AddV2(@1485,@1551) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1553 = transpose[dims={0, 3, 1, 2}](@1552) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1554 = relu(@1553) -> float_type, {1, 2048, 7, 7}, {100352, 49, 7, 1}\n", + "@1555 = transpose[dims={0, 2, 3, 1}](@1554) -> float_type, {1, 7, 7, 2048}, {100352, 7, 1, 49}\n", + "@1556 = reduce_mean[axes={1, 2}](@1555) -> float_type, {1, 1, 1, 2048}, {2048, 2048, 2048, 1}\n", + "@1557 = squeeze[axes={1, 2}](@1556) -> float_type, {1, 2048}, {2048, 1}\n", + "@1558 = identity(@1) -> float_type, {2048, 1000}, {1000, 1}\n", + "@1559 = dot[alpha=1,beta=1](@1557,@1558) -> float_type, {1, 1000}, {1000, 1}\n", + "@1560 = identity(@0) -> float_type, {1000}, {1}\n", + "@1561 = broadcast[axis=1,dims={1, 1000}](@1560) -> float_type, {1, 1000}, {0, 1}\n", + "@1562 = add(@1559,@1561) -> float_type, {1, 1000}, {1000, 1}\n", + "@1563 = softmax[axis=1](@1562) -> float_type, {1, 1000}, {1000, 1}\n", + "@1564 = identity(@1563) -> float_type, {1, 1000}, {1000, 1}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "import subprocess\n", + "driver = \"/opt/rocm/bin/migraphx-driver\"\n", + "command = \"read\"\n", + "model_path = \"./frozen_models/{}_frozen_graph.pb\".format(MODEL_NAME)\n", + "process = subprocess.run([driver, command, model_path], \n", + " stdout=subprocess.PIPE, \n", + " universal_newlines=True)\n", + "\n", + "print(process.stdout)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tensorflow", + "language": "python", + "name": "tensorflow" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/migraphx/export_frozen_graph_tf2/README.md b/examples/migraphx/export_frozen_graph_tf2/README.md new file mode 100755 index 0000000000000000000000000000000000000000..7af3976891b0f9bdb357a744adc704859022f8f7 --- /dev/null +++ b/examples/migraphx/export_frozen_graph_tf2/README.md @@ -0,0 +1,19 @@ +# Exporting Frozen Graphs in Tensorflow 2 + +## Description +This example demonstrates how to export a frozen graph protobuf in Tensorflow 2.X that can be used as input for MIGraphX. The method for accomplishing this has changed from Tensorflow 1.X. Please refer to [export_frozen_graphs_tf1]() if you are not yet using Tensorflow 2. + +## How to Use this Example +If you do not already have Jupyter Notebooks installed, please refer to this [page](https://jupyter.org/install) for instructions. + +Once Jupyter Notebooks is installed, you can navigate to this directory and issue the command: + +``` +$ jupyter notebook +``` + +From the browser window that is launched, click on `example.ipynb` +You should now be able to run the notebook from your browser. + +To use this on your own models you wish to save, simply edit the first cell to include any additional libraries and modify `MODEL_NAME` and `model` to the model of your choosing. Additionally, training and fine-tuning can be performed before moving on to cells 2 and beyond. + diff --git a/examples/migraphx/export_frozen_graph_tf2/example.ipynb b/examples/migraphx/export_frozen_graph_tf2/example.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..f6a8533eb51861e896101e6524dc1e258f744aef --- /dev/null +++ b/examples/migraphx/export_frozen_graph_tf2/example.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exporting Frozen Graphs in Tensorflow 2 \n", + "In order to use a trained model as input to MIGraphX, the model must be first be saved in a frozen graph format. This was accomplished in Tensorflow 1 by launching a graph in a tf.Session and then saving the session. However, Tensorflow has decided to deprecate Sessions in favor of functions and SavedModel format. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After importing the necessary libraries, the next step is to instantiate a model. For simplicity, in this example we will use a resnet50 architecture with pre-trained imagenet weights. These weights may also be trained or fine-tuned before freezing. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution() #May not be required depending on tensorflow version\n", + "from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers\n", + "\n", + "MODEL_NAME = \"resnet50\"\n", + "model = tf.keras.applications.ResNet50(weights=\"imagenet\")\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SavedModel format\n", + "The simplest way to save a model is through saved\\_model.save()\n", + "\n", + "This will create an equivalent tensorflow program which can later be loaded for fine-tuning or inference, although it is not directly compatible with MIGraphX." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf.saved_model.save(model, \"./Saved_Models/{}\".format(MODEL_NAME))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert to ConcreteFunction\n", + "To begin, we need to get the function equivalent of the model and then concretize the function to avoid retracing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "full_model = tf.function(lambda x: model(x))\n", + "full_model = full_model.get_concrete_function(\n", + " x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Freeze ConcreteFunction and Serialize\n", + "Since we are saving the graph for the purpose of inference, all variables can be made constant (i.e. \"frozen\").\n", + "\n", + "Next, we need to obtain a serialized GraphDef representation of the graph. \n", + "\n", + "\n", + "Optionally, the operators can be printed out layer by layer followed by the inputs and outputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "frozen_func = convert_variables_to_constants_v2(full_model)\n", + "frozen_func.graph.as_graph_def()\n", + "\n", + "layers = [op.name for op in frozen_func.graph.get_operations()]\n", + "print(\"-\" * 50)\n", + "print(\"Frozen model layers: \")\n", + "for layer in layers:\n", + " print(layer)\n", + "\n", + "print(\"-\" * 50)\n", + "print(\"Frozen model inputs: \")\n", + "print(frozen_func.inputs)\n", + "print(\"Frozen model outputs: \")\n", + "print(frozen_func.outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Frozen Graph as Protobuf\n", + "Finally, we can save to hard drive, and now the frozen graph will be stored as `./frozen_models/_frozen_graph.pb`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tf.io.write_graph(graph_or_graph_def=frozen_func.graph,\n", + " logdir=\"./frozen_models\",\n", + " name=\"{}_frozen_graph.pb\".format(MODEL_NAME),\n", + " as_text=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Assuming MIGraphX has already been built and installed on your system, the driver can be used to verify that the frozen graph has been correctly exported. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "driver = \"/opt/rocm/bin/migraphx-driver\"\n", + "command = \"read\"\n", + "model_path = \"./frozen_models/{}_frozen_graph.pb\".format(MODEL_NAME)\n", + "process = subprocess.run([driver, command, model_path], \n", + " stdout=subprocess.PIPE, \n", + " universal_newlines=True)\n", + "\n", + "print(process.stdout)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tensorflow", + "language": "python", + "name": "tensorflow" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/migraphx/migraphx_docker/README.md b/examples/migraphx/migraphx_docker/README.md new file mode 100755 index 0000000000000000000000000000000000000000..55571a42b8809df2a4f59a263eda18f7850208e7 --- /dev/null +++ b/examples/migraphx/migraphx_docker/README.md @@ -0,0 +1,3 @@ +# MIGraphX Dockerfile + +Instructions for building and running the MIGraphX docker container can be found [here](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/blob/develop/README.md#using-docker) in this project's top level README. diff --git a/examples/migraphx/migraphx_driver/README.md b/examples/migraphx/migraphx_driver/README.md new file mode 100755 index 0000000000000000000000000000000000000000..b81926bbdaba1c5bab408b1189aedb3eb3f6c87f --- /dev/null +++ b/examples/migraphx/migraphx_driver/README.md @@ -0,0 +1,605 @@ +# MIGraphX Driver + +## Description +The MIGraphX driver is a tool that allows you to utilize many of the core functions of MIGraphX without having to write your own program. + +## How to Use this Example + +The MIGraphX driver is installed with MIGraphX and can be found in `/opt/rocm/bin/migraphx-driver`, or in `AMDMIGraphX/build/bin/migraphx-driver` after building the source code. + +See below for a comprehensive list of commands and option arguments, as well as some usage examples. + +### Commands +| Command | Description | +| --- | ---| +| op | When followed by the option --list or -l, prints all operators of MIGraphX | +| params | Prints the input and output parameter shapes | +| run | Compiles, allocates parameters, evaluates, and prints input graph | +| read | Loads and prints input graph | +| compile | Compiles and prints input graph | +| verify | Runs reference and GPU implementations and checks outputs for consistency | +| perf | Compiles and runs input graph then prints performance report | + +### Options +| Option | Description | +| --- | --- | +| --help \| -h | Show help | +| --model | Loads one of the three default models | +| --onnx | Load file as onnx graph | +| --tf | Load file as a tensorflow graph | +| --migraphx | Load file as a migraphx graph | +| --migraphx-json | Load file as a migraphx JSON graph | +| --nhwc | Treat tensorflow format as nhwc | +| --nchw | Treat tensorflow format as nchw | +| --skip-unknown-operators | Skip unknown operators when parsing and continue to parse | +| --trim \| -t | Trim instructions from the end | +| --optimize \| -O | Optimize when reading | +| --graphviz \| -g | Print out a graphviz representation | +| --brief | Make the output brief | +| --cpp | Print out the program as cpp program | +| --json | Print out program as json | +| --text | Print out program in text format | +| --binary | Print out program in binary format | +| --output \| -o | Output to file | +| --fill0 | Fill parameter with 0s | +| --fill1 | Fill parameter with 1s | +| --gpu | Compile on the gpu | +| --cpu | Compile on the cpu | +| --ref | Compile on the reference implementation | +| --enable-offload-copy | Enable implicit offload copying | +| --disable-fast-math | Disable fast math optimization | +| --fp16 | Quantize for fp16 | +| --int8 | Quantize for int8 | +| --tolerance | Tolerance for errors | +| --per-instruction \| -i | Verify each instruction | +| --reduce \| -r | Reduce program and verify | +| --iterations \| -n | Number of iterations to run for perf report | +| --list \| -l | List all the operators of MIGraphX | + +## Usage Examples +The examples below supply a simple MNIST ConvNet as the input graph. Models of higher complexity will have considerably larger outputs in most cases. + +##### Example: op +``` +$ /opt/rocm/bin/migraphx-driver op --list +``` + +
+View output + +``` +@literal +@param +@return +abs +acos +acosh +add +argmax +argmin +as_shape +asin +asinh +atan +atanh +batch_norm_inference +broadcast +capture +ceil +check_context::migraphx::version_1::gpu::context +clip +concat +contiguous +convert +convolution +cos +cosh +deconvolution +div +dot +elu +equal +erf +exp +flatten +floor +gather +gpu::abs +gpu::acos +gpu::acosh +gpu::add +gpu::add_clip +gpu::add_gelu +gpu::add_gelu_new +gpu::add_relu +gpu::add_tanh +gpu::argmax +gpu::argmin +gpu::asin +gpu::asinh +gpu::atan +gpu::atanh +gpu::batch_norm_inference +gpu::ceil +gpu::clip +gpu::concat +gpu::contiguous +gpu::conv_bias +gpu::conv_bias_relu +gpu::convert +gpu::convolution +gpu::cos +gpu::cosh +gpu::deconv +gpu::div +gpu::elu +gpu::equal +gpu::erf +gpu::exp +gpu::floor +gpu::gather +gpu::gelu +gpu::gelu_new +gpu::gemm +gpu::greater +gpu::int8_conv_pack +gpu::int8_gemm_pack_a +gpu::int8_gemm_pack_b +gpu::layernorm +gpu::leaky_relu +gpu::less +gpu::log +gpu::logsoftmax +gpu::lrn +gpu::max +gpu::min +gpu::mul +gpu::mul_add +gpu::mul_add_relu +gpu::pad +gpu::pooling +gpu::pow +gpu::prelu +gpu::quant_convolution +gpu::quant_gemm +gpu::recip +gpu::record_event +gpu::reduce_max +gpu::reduce_mean +gpu::reduce_min +gpu::reduce_prod +gpu::reduce_sum +gpu::relu +gpu::rnn_var_sl_last_output +gpu::rnn_var_sl_shift_output +gpu::rnn_var_sl_shift_sequence +gpu::round +gpu::rsqrt +gpu::set_stream +gpu::sigmoid +gpu::sign +gpu::sin +gpu::sinh +gpu::softmax +gpu::sqdiff +gpu::sqrt +gpu::sub +gpu::tan +gpu::tanh +gpu::triadd +gpu::triadd_clip +gpu::triadd_relu +gpu::triadd_sigmoid +gpu::triadd_tanh +gpu::wait_event +greater +gru +hip::allocate +hip::copy +hip::copy_from_gpu +hip::copy_to_gpu +hip::hip_allocate_memory +hip::hip_copy_literal +hip::sync_device +identity +im2col +leaky_relu +less +load +log +logsoftmax +lrn +lstm +max +min +mul +multibroadcast +neg +outline +pad +pooling +pow +prelu +quant_convolution +quant_dot +recip +reduce_max +reduce_mean +reduce_min +reduce_prod +reduce_sum +ref::batch_norm_inference +ref::convolution +ref::deconvolution +ref::dot +ref::elu +ref::im2col +ref::leaky_relu +ref::logsoftmax +ref::lrn +ref::op +ref::pad +ref::pooling_average +ref::pooling_max +ref::quant_convolution +ref::rnn_var_sl_last_output +ref::softmax +relu +reshape +rnn +rnn_last_cell_output +rnn_last_hs_output +rnn_var_sl_last_output +rnn_var_sl_shift_output +rnn_var_sl_shift_sequence +round +rsqrt +scalar +sigmoid +sign +sin +sinh +slice +softmax +sqdiff +sqrt +squeeze +sub +tan +tanh +transpose +undefined +unknown: +unsqueeze +``` + +
+

+ +##### Example: params +``` +$ /opt/rocm/bin/migraphx-driver params simple_graph.pb +``` + +
+View output + +``` +Reading: simple_graph.pb +x: float_type, {1, 28, 28}, {784, 28, 1} +``` + +
+

+ +##### Example: run (onnx file input) +``` +$ /opt/rocm/bin/migraphx-driver run --onnx simple_graph.onnx +``` + +
+View output + +``` +Compiling ... +Reading: simple_graph.onnx +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {} +@1 = hip::hip_allocate_memory[shape=float_type, {256}, {1},id=scratch] -> float_type, {256}, {1} +@2 = hip::hip_copy_literal[id=@literal:1] -> float_type, {784, 128}, {128, 1} +x:0 = @param:x:0 -> float_type, {1, 28, 28}, {784, 28, 1} +@3 = reshape[dims={-1, 784}](x:0) -> float_type, {1, 784}, {784, 1} +@4 = load[offset=0,end=512](@1) -> float_type, {1, 128}, {128, 1} +@5 = gpu::gemm[alpha=1,beta=0](@3,@2,@4) -> float_type, {1, 128}, {128, 1} +@6 = hip::hip_copy_literal[id=@literal:0] -> float_type, {128}, {1} +@7 = hip::hip_copy_literal[id=@literal:2] -> float_type, {10}, {1} +@8 = hip::hip_copy_literal[id=@literal:3] -> float_type, {128, 10}, {10, 1} +@9 = multibroadcast[output_lens={1, 128}](@6) -> float_type, {1, 128}, {0, 1} +@10 = load[offset=512,end=1024](@1) -> float_type, {1, 128}, {128, 1} +@11 = gpu::add_relu(@5,@9,@10) -> float_type, {1, 128}, {128, 1} +@12 = load[offset=0,end=40](@1) -> float_type, {1, 10}, {10, 1} +@13 = gpu::gemm[alpha=1,beta=0](@11,@8,@12) -> float_type, {1, 10}, {10, 1} +@14 = multibroadcast[output_lens={1, 10}](@7) -> float_type, {1, 10}, {0, 1} +@15 = load[offset=40,end=80](@1) -> float_type, {1, 10}, {10, 1} +@16 = gpu::add(@13,@14,@15) -> float_type, {1, 10}, {10, 1} +#output_0 = @param:#output_0 -> float_type, {1, 10}, {10, 1} +@17 = gpu::softmax[axis=1](@16,#output_0) -> float_type, {1, 10}, {10, 1} +@18 = @return(@17) + +Allocating params ... +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {} +@1 = hip::hip_allocate_memory[shape=float_type, {256}, {1},id=scratch] -> float_type, {256}, {1} +@2 = hip::hip_copy_literal[id=@literal:1] -> float_type, {784, 128}, {128, 1} +x:0 = @param:x:0 -> float_type, {1, 28, 28}, {784, 28, 1} +@3 = reshape[dims={-1, 784}](x:0) -> float_type, {1, 784}, {784, 1} +@4 = load[offset=0,end=512](@1) -> float_type, {1, 128}, {128, 1} +@5 = gpu::gemm[alpha=1,beta=0](@3,@2,@4) -> float_type, {1, 128}, {128, 1} +@6 = hip::hip_copy_literal[id=@literal:0] -> float_type, {128}, {1} +@7 = hip::hip_copy_literal[id=@literal:2] -> float_type, {10}, {1} +@8 = hip::hip_copy_literal[id=@literal:3] -> float_type, {128, 10}, {10, 1} +@9 = multibroadcast[output_lens={1, 128}](@6) -> float_type, {1, 128}, {0, 1} +@10 = load[offset=512,end=1024](@1) -> float_type, {1, 128}, {128, 1} +@11 = gpu::add_relu(@5,@9,@10) -> float_type, {1, 128}, {128, 1} +@12 = load[offset=0,end=40](@1) -> float_type, {1, 10}, {10, 1} +@13 = gpu::gemm[alpha=1,beta=0](@11,@8,@12) -> float_type, {1, 10}, {10, 1} +@14 = multibroadcast[output_lens={1, 10}](@7) -> float_type, {1, 10}, {0, 1} +@15 = load[offset=40,end=80](@1) -> float_type, {1, 10}, {10, 1} +@16 = gpu::add(@13,@14,@15) -> float_type, {1, 10}, {10, 1} +#output_0 = @param:#output_0 -> float_type, {1, 10}, {10, 1} +@17 = gpu::softmax[axis=1](@16,#output_0) -> float_type, {1, 10}, {10, 1} +@18 = @return(@17) +``` + +
+

+ +##### Example: read +``` +$ /opt/rocm/bin/migraphx-driver read simple_graph.pb +``` + +
+View output + +``` +Reading: simple_graph.pb +@0 = @literal{0.0136018, -0.0839988, 0.0375392, 0.0613085, -0.125795, 0.176185, 0.0761055, 0.0093384, -0.110057, -0.170587} -> float_type, {10}, {1} +@1 = @literal{ ... } -> float_type, {128, 10}, {10, 1} +@2 = @literal{ ... } -> float_type, {128}, {1} +@3 = @literal{ ... } -> float_type, {784, 128}, {128, 1} +@4 = @literal{-1, 784} -> int32_type, {2}, {1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@5 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@6 = identity(@3) -> float_type, {784, 128}, {128, 1} +@7 = dot[alpha=1,beta=1](@5,@6) -> float_type, {1, 128}, {128, 1} +@8 = identity(@2) -> float_type, {128}, {1} +@9 = broadcast[axis=1,dims={1, 128}](@8) -> float_type, {1, 128}, {0, 1} +@10 = add(@7,@9) -> float_type, {1, 128}, {128, 1} +@11 = relu(@10) -> float_type, {1, 128}, {128, 1} +@12 = identity(@1) -> float_type, {128, 10}, {10, 1} +@13 = dot[alpha=1,beta=1](@11,@12) -> float_type, {1, 10}, {10, 1} +@14 = identity(@0) -> float_type, {10}, {1} +@15 = broadcast[axis=1,dims={1, 10}](@14) -> float_type, {1, 10}, {0, 1} +@16 = add(@13,@15) -> float_type, {1, 10}, {10, 1} +@17 = softmax[axis=1](@16) -> float_type, {1, 10}, {10, 1} +@18 = identity(@17) -> float_type, {1, 10}, {10, 1} +``` + +
+

+ +##### Example: compile (on GPU, quantized for fp16) +``` +$ /opt/rocm/bin/migraphx-driver compile --gpu --fp16 simple_graph.pb +``` + +
+View output + +``` +Compiling ... +Reading: simple_graph.pb +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {} +@1 = hip::hip_allocate_memory[shape=float_type, {456}, {1},id=scratch] -> float_type, {456}, {1} +@2 = hip::hip_copy_literal[id=@literal:0] -> half_type, {784, 128}, {128, 1} +@3 = load[offset=256,end=1824](@1) -> half_type, {1, 28, 28}, {784, 28, 1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@4 = gpu::convert[target_type=1](x,@3) -> half_type, {1, 28, 28}, {784, 28, 1} +@5 = reshape[dims={-1, 784}](@4) -> half_type, {1, 784}, {784, 1} +@6 = load[offset=0,end=256](@1) -> half_type, {1, 128}, {128, 1} +@7 = gpu::gemm[alpha=1,beta=0](@5,@2,@6) -> half_type, {1, 128}, {128, 1} +@8 = hip::hip_copy_literal[id=@literal:2] -> half_type, {128, 10}, {10, 1} +@9 = hip::hip_copy_literal[id=@literal:1] -> half_type, {128}, {1} +@10 = hip::hip_copy_literal[id=@literal:3] -> half_type, {10}, {1} +@11 = load[offset=256,end=512](@1) -> half_type, {1, 128}, {128, 1} +@12 = broadcast[axis=1,dims={1, 128}](@9) -> half_type, {1, 128}, {0, 1} +@13 = gpu::add_relu(@7,@12,@11) -> half_type, {1, 128}, {128, 1} +@14 = load[offset=0,end=20](@1) -> half_type, {1, 10}, {10, 1} +@15 = gpu::gemm[alpha=1,beta=0](@13,@8,@14) -> half_type, {1, 10}, {10, 1} +@16 = broadcast[axis=1,dims={1, 10}](@10) -> half_type, {1, 10}, {0, 1} +@17 = load[offset=20,end=40](@1) -> half_type, {1, 10}, {10, 1} +@18 = gpu::add(@15,@16,@17) -> half_type, {1, 10}, {10, 1} +@19 = load[offset=0,end=20](@1) -> half_type, {1, 10}, {10, 1} +@20 = gpu::softmax[axis=1](@18,@19) -> half_type, {1, 10}, {10, 1} +output = @param:output -> float_type, {1, 10}, {10, 1} +@21 = gpu::convert[target_type=2](@20,output) -> float_type, {1, 10}, {10, 1} +``` + +
+

+ +##### Example: verify +``` +$ /opt/rocm/bin/migraphx-driver verify simple_graph.pb +``` + +
+View output + +``` +Reading: simple_graph.pb +@0 = @literal{0.0136018, -0.0839988, 0.0375392, 0.0613085, -0.125795, 0.176185, 0.0761055, 0.0093384, -0.110057, -0.170587} -> float_type, {10}, {1} +@1 = @literal{ ... } -> float_type, {128, 10}, {10, 1} +@2 = @literal{ ... } -> float_type, {128}, {1} +@3 = @literal{ ... } -> float_type, {784, 128}, {128, 1} +@4 = @literal{-1, 784} -> int32_type, {2}, {1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@5 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@6 = identity(@3) -> float_type, {784, 128}, {128, 1} +@7 = dot[alpha=1,beta=1](@5,@6) -> float_type, {1, 128}, {128, 1} +@8 = identity(@2) -> float_type, {128}, {1} +@9 = broadcast[axis=1,dims={1, 128}](@8) -> float_type, {1, 128}, {0, 1} +@10 = add(@7,@9) -> float_type, {1, 128}, {128, 1} +@11 = relu(@10) -> float_type, {1, 128}, {128, 1} +@12 = identity(@1) -> float_type, {128, 10}, {10, 1} +@13 = dot[alpha=1,beta=1](@11,@12) -> float_type, {1, 10}, {10, 1} +@14 = identity(@0) -> float_type, {10}, {1} +@15 = broadcast[axis=1,dims={1, 10}](@14) -> float_type, {1, 10}, {0, 1} +@16 = add(@13,@15) -> float_type, {1, 10}, {10, 1} +@17 = softmax[axis=1](@16) -> float_type, {1, 10}, {10, 1} +@18 = identity(@17) -> float_type, {1, 10}, {10, 1} + +@0 = @literal{0.0136018, -0.0839988, 0.0375392, 0.0613085, -0.125795, 0.176185, 0.0761055, 0.0093384, -0.110057, -0.170587} -> float_type, {10}, {1} +@1 = @literal{ ... } -> float_type, {128, 10}, {10, 1} +@2 = @literal{ ... } -> float_type, {128}, {1} +@3 = @literal{ ... } -> float_type, {784, 128}, {128, 1} +@4 = @literal{-1, 784} -> int32_type, {2}, {1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@5 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@6 = identity(@3) -> float_type, {784, 128}, {128, 1} +@7 = dot[alpha=1,beta=1](@5,@6) -> float_type, {1, 128}, {128, 1} +@8 = identity(@2) -> float_type, {128}, {1} +@9 = broadcast[axis=1,dims={1, 128}](@8) -> float_type, {1, 128}, {0, 1} +@10 = add(@7,@9) -> float_type, {1, 128}, {128, 1} +@11 = relu(@10) -> float_type, {1, 128}, {128, 1} +@12 = identity(@1) -> float_type, {128, 10}, {10, 1} +@13 = dot[alpha=1,beta=1](@11,@12) -> float_type, {1, 10}, {10, 1} +@14 = identity(@0) -> float_type, {10}, {1} +@15 = broadcast[axis=1,dims={1, 10}](@14) -> float_type, {1, 10}, {0, 1} +@16 = add(@13,@15) -> float_type, {1, 10}, {10, 1} +@17 = softmax[axis=1](@16) -> float_type, {1, 10}, {10, 1} +@18 = identity(@17) -> float_type, {1, 10}, {10, 1} + +@0 = @literal{0.0136018, -0.0839988, 0.0375392, 0.0613085, -0.125795, 0.176185, 0.0761055, 0.0093384, -0.110057, -0.170587} -> float_type, {10}, {1} +@1 = @literal{ ... } -> float_type, {128, 10}, {10, 1} +@2 = @literal{ ... } -> float_type, {128}, {1} +@3 = @literal{ ... } -> float_type, {784, 128}, {128, 1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@4 = ref::reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@5 = ref::identity(@3) -> float_type, {784, 128}, {128, 1} +@6 = ref::dot[alpha=1,beta=1](@4,@5) -> float_type, {1, 128}, {128, 1} +@7 = ref::identity(@2) -> float_type, {128}, {1} +@8 = ref::broadcast[axis=1,dims={1, 128}](@7) -> float_type, {1, 128}, {0, 1} +@9 = ref::contiguous(@8) -> float_type, {1, 128}, {128, 1} +@10 = ref::add(@6,@9) -> float_type, {1, 128}, {128, 1} +@11 = ref::relu(@10) -> float_type, {1, 128}, {128, 1} +@12 = ref::identity(@1) -> float_type, {128, 10}, {10, 1} +@13 = ref::dot[alpha=1,beta=1](@11,@12) -> float_type, {1, 10}, {10, 1} +@14 = ref::identity(@0) -> float_type, {10}, {1} +@15 = ref::broadcast[axis=1,dims={1, 10}](@14) -> float_type, {1, 10}, {0, 1} +@16 = ref::contiguous(@15) -> float_type, {1, 10}, {10, 1} +@17 = ref::add(@13,@16) -> float_type, {1, 10}, {10, 1} +@18 = ref::softmax[axis=1](@17) -> float_type, {1, 10}, {10, 1} +@19 = ref::identity(@18) -> float_type, {1, 10}, {10, 1} + +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {} +@1 = hip::hip_allocate_memory[shape=float_type, {256}, {1},id=scratch] -> float_type, {256}, {1} +@2 = hip::hip_copy_literal[id=@literal:3] -> float_type, {784, 128}, {128, 1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@3 = load[offset=0,end=512](@1) -> float_type, {1, 128}, {128, 1} +@4 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@5 = gpu::gemm[alpha=1,beta=0](@4,@2,@3) -> float_type, {1, 128}, {128, 1} +@6 = hip::hip_copy_literal[id=@literal:1] -> float_type, {128, 10}, {10, 1} +@7 = hip::hip_copy_literal[id=@literal:2] -> float_type, {128}, {1} +@8 = hip::hip_copy_literal[id=@literal:0] -> float_type, {10}, {1} +@9 = load[offset=512,end=1024](@1) -> float_type, {1, 128}, {128, 1} +@10 = broadcast[axis=1,dims={1, 128}](@7) -> float_type, {1, 128}, {0, 1} +@11 = gpu::add_relu(@5,@10,@9) -> float_type, {1, 128}, {128, 1} +@12 = load[offset=40,end=80](@1) -> float_type, {1, 10}, {10, 1} +@13 = gpu::gemm[alpha=1,beta=0](@11,@6,@12) -> float_type, {1, 10}, {10, 1} +@14 = load[offset=0,end=40](@1) -> float_type, {1, 10}, {10, 1} +@15 = broadcast[axis=1,dims={1, 10}](@8) -> float_type, {1, 10}, {0, 1} +@16 = gpu::add(@13,@15,@14) -> float_type, {1, 10}, {10, 1} +output = @param:output -> float_type, {1, 10}, {10, 1} +@17 = gpu::softmax[axis=1](@16,output) -> float_type, {1, 10}, {10, 1} +``` + +
+

+ +##### Example: perf +``` +$ /opt/rocm/bin/migraphx-driver perf simple_graph.pb +``` + +
+View output + +``` +Compiling ... +Reading: simple_graph.pb +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {} +@1 = hip::hip_allocate_memory[shape=float_type, {256}, {1},id=scratch] -> float_type, {256}, {1} +@2 = hip::hip_copy_literal[id=@literal:3] -> float_type, {784, 128}, {128, 1} +@3 = load[offset=0,end=512](@1) -> float_type, {1, 128}, {128, 1} +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1} +@4 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1} +@5 = gpu::gemm[alpha=1,beta=0](@4,@2,@3) -> float_type, {1, 128}, {128, 1} +@6 = hip::hip_copy_literal[id=@literal:1] -> float_type, {128, 10}, {10, 1} +@7 = hip::hip_copy_literal[id=@literal:0] -> float_type, {10}, {1} +@8 = hip::hip_copy_literal[id=@literal:2] -> float_type, {128}, {1} +@9 = broadcast[axis=1,dims={1, 128}](@8) -> float_type, {1, 128}, {0, 1} +@10 = load[offset=512,end=1024](@1) -> float_type, {1, 128}, {128, 1} +@11 = gpu::add_relu(@5,@9,@10) -> float_type, {1, 128}, {128, 1} +@12 = load[offset=0,end=40](@1) -> float_type, {1, 10}, {10, 1} +@13 = gpu::gemm[alpha=1,beta=0](@11,@6,@12) -> float_type, {1, 10}, {10, 1} +@14 = broadcast[axis=1,dims={1, 10}](@7) -> float_type, {1, 10}, {0, 1} +@15 = load[offset=40,end=80](@1) -> float_type, {1, 10}, {10, 1} +@16 = gpu::add(@13,@14,@15) -> float_type, {1, 10}, {10, 1} +output = @param:output -> float_type, {1, 10}, {10, 1} +@17 = gpu::softmax[axis=1](@16,output) -> float_type, {1, 10}, {10, 1} + +Allocating params ... +Running performance report ... +@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {}: 0.00057782ms, 1% +@1 = hip::hip_allocate_memory[shape=float_type, {256}, {1},id=scratch] -> float_type, {256}, {1}: 0.000295ms, 1% +@2 = hip::hip_copy_literal[id=@literal:3] -> float_type, {784, 128}, {128, 1}: 0.00027942ms, 1% +@3 = load[offset=0,end=512](@1) -> float_type, {1, 128}, {128, 1}: 0.000232ms, 1% +x = @param:x -> float_type, {1, 28, 28}, {784, 28, 1}: 0.0003206ms, 1% +@4 = reshape[dims={-1, 784}](x) -> float_type, {1, 784}, {784, 1}: 0.00033842ms, 1% +@5 = gpu::gemm[alpha=1,beta=0](@4,@2,@3) -> float_type, {1, 128}, {128, 1}: 0.212592ms, 52% +@6 = hip::hip_copy_literal[id=@literal:1] -> float_type, {128, 10}, {10, 1}: 0.00085822ms, 1% +@7 = hip::hip_copy_literal[id=@literal:0] -> float_type, {10}, {1}: 0.000382ms, 1% +@8 = hip::hip_copy_literal[id=@literal:2] -> float_type, {128}, {1}: 0.0003486ms, 1% +@9 = broadcast[axis=1,dims={1, 128}](@8) -> float_type, {1, 128}, {0, 1}: 0.000299ms, 1% +@10 = load[offset=512,end=1024](@1) -> float_type, {1, 128}, {128, 1}: 0.000234ms, 1% +@11 = gpu::add_relu(@5,@9,@10) -> float_type, {1, 128}, {128, 1}: 0.0416597ms, 11% +@12 = load[offset=0,end=40](@1) -> float_type, {1, 10}, {10, 1}: 0.0007548ms, 1% +@13 = gpu::gemm[alpha=1,beta=0](@11,@6,@12) -> float_type, {1, 10}, {10, 1}: 0.0733071ms, 18% +@14 = broadcast[axis=1,dims={1, 10}](@7) -> float_type, {1, 10}, {0, 1}: 0.00088142ms, 1% +@15 = load[offset=40,end=80](@1) -> float_type, {1, 10}, {10, 1}: 0.000408ms, 1% +@16 = gpu::add(@13,@14,@15) -> float_type, {1, 10}, {10, 1}: 0.0410144ms, 10% +output = @param:output -> float_type, {1, 10}, {10, 1}: 0.0010222ms, 1% +@17 = gpu::softmax[axis=1](@16,output) -> float_type, {1, 10}, {10, 1}: 0.0385636ms, 10% + +Summary: +gpu::gemm: 0.285899ms, 69% +gpu::add_relu: 0.0416597ms, 11% +gpu::add: 0.0410144ms, 10% +gpu::softmax: 0.0385636ms, 10% +hip::hip_copy_literal: 0.00186824ms, 1% +load: 0.0016288ms, 1% +@param: 0.0013428ms, 1% +broadcast: 0.00118042ms, 1% +check_context::migraphx::version_1::gpu::context: 0.00057782ms, 1% +reshape: 0.00033842ms, 1% +hip::hip_allocate_memory: 0.000295ms, 1% + +Rate: 2866.1/sec +Total time: 0.348906ms +Total instructions time: 0.414369ms +Overhead time: 0.00348144ms, -0.0654627ms +Overhead: 1%, -19% +``` + +
+

diff --git a/examples/nlp/README.md b/examples/nlp/README.md new file mode 100755 index 0000000000000000000000000000000000000000..f4d52f35474c4938ff83a111dea92733e9b404c0 --- /dev/null +++ b/examples/nlp/README.md @@ -0,0 +1,3 @@ +# Natural Language Processing Inference Examples + +- [Python BERT-SQuAD](./python_bert_squad) \ No newline at end of file diff --git a/examples/nlp/python_bert_squad/BERT-Squad.ipynb b/examples/nlp/python_bert_squad/BERT-Squad.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..9c713a28d6c9806bfb4c60a488dcc4c424ebd244 --- /dev/null +++ b/examples/nlp/python_bert_squad/BERT-Squad.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT-SQuAD Inference Example with AMD MIGraphX" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial shows how to run the BERT-Squad model on ONNX-Runtime with MIGraphX backend." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Requirements " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install -r requirements_bertsquad.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import json\n", + "import time\n", + "import os.path\n", + "from os import path\n", + "import sys\n", + "\n", + "import tokenizers\n", + "from run_onnx_squad import *\n", + "\n", + "import migraphx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download BERT ONNX file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download uncased file / vocabulary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!apt-get install unzip\n", + "!wget -q -nc https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n", + "!unzip -n uncased_L-12_H-768_A-12.zip" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Input data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_file = 'inputs.json'\n", + "with open(input_file) as json_file:\n", + " test_data = json.load(json_file)\n", + " print(json.dumps(test_data, indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Configuration for inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_seq_length = 256\n", + "doc_stride = 128\n", + "max_query_length = 64\n", + "batch_size = 1\n", + "n_best_size = 20\n", + "max_answer_length = 30" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read vocabulary file and tokenize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt')\n", + "tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert the example to features to input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# preprocess input\n", + "predict_file = 'inputs.json'\n", + "\n", + "# Use read_squad_examples method from run_onnx_squad to read the input file\n", + "eval_examples = read_squad_examples(input_file=predict_file)\n", + "\n", + "# Use convert_examples_to_features method from run_onnx_squad to get parameters from the input\n", + "input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(\n", + " eval_examples, tokenizer, max_seq_length, doc_stride, max_query_length)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compile with MIGraphX for GPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"bertsquad-10.onnx\")\n", + "model.compile(migraphx.get_target(\"gpu\"))\n", + "#model.print()\n", + "\n", + "model.get_parameter_names()\n", + "model.get_parameter_shapes()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the input through the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n = len(input_ids)\n", + "bs = batch_size\n", + "all_results = []\n", + "\n", + "for idx in range(0, n):\n", + " item = eval_examples[idx]\n", + " print(item)\n", + "\n", + " result = model.run({\n", + " \"unique_ids_raw_output___9:0\":\n", + " np.array([item.qas_id], dtype=np.int64),\n", + " \"input_ids:0\":\n", + " input_ids[idx:idx + bs],\n", + " \"input_mask:0\":\n", + " input_mask[idx:idx + bs],\n", + " \"segment_ids:0\":\n", + " segment_ids[idx:idx + bs]\n", + " })\n", + "\n", + " in_batch = result[1].get_shape().lens()[0]\n", + " print(in_batch)\n", + " start_logits = [float(x) for x in result[1].tolist()]\n", + " end_logits = [float(x) for x in result[0].tolist()]\n", + " # print(start_logits)\n", + " # print(end_logits)\n", + " for i in range(0, in_batch):\n", + " unique_id = len(all_results)\n", + " all_results.append(\n", + " RawResult(unique_id=unique_id,\n", + " start_logits=start_logits,\n", + " end_logits=end_logits))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get the predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = 'predictions'\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "output_prediction_file = os.path.join(output_dir, \"predictions.json\")\n", + "output_nbest_file = os.path.join(output_dir, \"nbest_predictions.json\")\n", + "write_predictions(eval_examples, extra_data, all_results, n_best_size,\n", + " max_answer_length, True, output_prediction_file,\n", + " output_nbest_file)\n", + "\n", + "with open(output_prediction_file) as json_file:\n", + " test_data = json.load(json_file)\n", + " print(json.dumps(test_data, indent=2))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/nlp/python_bert_squad/README.md b/examples/nlp/python_bert_squad/README.md new file mode 100755 index 0000000000000000000000000000000000000000..8a4d2575375261d214378170eece02755fa8f707 --- /dev/null +++ b/examples/nlp/python_bert_squad/README.md @@ -0,0 +1,34 @@ +# BERT-SQuAD Example with MIGraphX +Question answering with BERT using MIGraphX optimizations on ROCm platform. + +There are two ways to run the example: +1) Install MIGraphX and Jupyter notebook to your system and then utilize `BERT-Squad.ipynb` notebook file. +2) Install MIGraphx to your system and follow the steps executing the python script `bert-squad-migraphx.py`. + +# Steps +1) Install MIGraphX to your environment. Please follow the steps to build MIGraphX given at https://github.com/ROCmSoftwarePlatform/AMDMIGraphX +2) Upgrade your pip3 to latest version +``` +pip3 install --upgrade pip +``` +3) Install the requirements file +``` +pip3 install -r requirements_bertsquad.txt +``` +4) Install `unzip` and fetch the uncased file (vocabulary): +``` +apt-get install unzip +wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip +unzip uncased_L-12_H-768_A-12.zip +``` +5) Get BERT ONNX model (bertsquad-10.onnx): +``` +wget https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx +``` +6) Run the inference, it will compile and run the model on three questions and small data provided in `inputs.json`: +``` +python3 bert-squad-migraphx.py +``` +## References +This example utilizes the following notebook :notebook: and applies it to MIGraphX: +https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/BERT-Squad.ipynb diff --git a/examples/nlp/python_bert_squad/bert-squad-migraphx.py b/examples/nlp/python_bert_squad/bert-squad-migraphx.py new file mode 100755 index 0000000000000000000000000000000000000000..d6c7a8bd438647f12c233338fd6ab3476f3a67a0 --- /dev/null +++ b/examples/nlp/python_bert_squad/bert-squad-migraphx.py @@ -0,0 +1,87 @@ +import numpy as np +import json +import os.path +import tokenizers +import collections +from run_onnx_squad import read_squad_examples, write_predictions, convert_examples_to_features +import migraphx + +RawResult = collections.namedtuple("RawResult", + ["unique_id", "start_logits", "end_logits"]) + +####################################### +input_file = 'inputs_amd.json' +with open(input_file) as json_file: + test_data = json.load(json_file) + print(json.dumps(test_data, indent=2)) + +# preprocess input +predict_file = 'inputs_amd.json' + +# Use read_squad_examples method from run_onnx_squad to read the input file +eval_examples = read_squad_examples(input_file=predict_file) + +max_seq_length = 256 +doc_stride = 128 +max_query_length = 64 +batch_size = 1 +n_best_size = 20 +max_answer_length = 30 + +vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt') +tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file) + +# Use convert_examples_to_features method from run_onnx_squad to get parameters from the input +input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features( + eval_examples, tokenizer, max_seq_length, doc_stride, max_query_length) + +####################################### +# Compile +print("INFO: Parsing and compiling the model...") +model = migraphx.parse_onnx("bertsquad-10.onnx") +model.compile(migraphx.get_target("gpu")) +#model.print() + +print(model.get_parameter_names()) +print(model.get_parameter_shapes()) + +n = len(input_ids) +bs = batch_size +all_results = [] + +for idx in range(0, n): + item = eval_examples[idx] + print(item) + + result = model.run({ + "unique_ids_raw_output___9:0": + np.array([item.qas_id], dtype=np.int64), + "input_ids:0": + input_ids[idx:idx + bs], + "input_mask:0": + input_mask[idx:idx + bs], + "segment_ids:0": + segment_ids[idx:idx + bs] + }) + + in_batch = result[1].get_shape().lens()[0] + start_logits = [float(x) for x in result[1].tolist()] + end_logits = [float(x) for x in result[0].tolist()] + for i in range(0, in_batch): + unique_id = len(all_results) + all_results.append( + RawResult(unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits)) + +output_dir = 'predictions' +os.makedirs(output_dir, exist_ok=True) +output_prediction_file = os.path.join(output_dir, "predictions.json") +output_nbest_file = os.path.join(output_dir, "nbest_predictions.json") +write_predictions(eval_examples, extra_data, all_results, n_best_size, + max_answer_length, True, output_prediction_file, + output_nbest_file) + +with open(output_prediction_file) as json_file: + test_data = json.load(json_file) + print(json.dumps(test_data, indent=2)) diff --git a/examples/nlp/python_bert_squad/inputs.json b/examples/nlp/python_bert_squad/inputs.json new file mode 100755 index 0000000000000000000000000000000000000000..bdc39e2fdcd2df12d79c1ddf78a901dcfdc53988 --- /dev/null +++ b/examples/nlp/python_bert_squad/inputs.json @@ -0,0 +1,27 @@ +{ + "version": "1.4", + "data": [ + { + "paragraphs": [ + { + "context": "In its early years, the new convention center failed to meet attendance and revenue expectations.[12] By 2002, many Silicon Valley businesses were choosing the much larger Moscone Center in San Francisco over the San Jose Convention Center due to the latter's limited space. A ballot measure to finance an expansion via a hotel tax failed to reach the required two-thirds majority to pass. In June 2005, Team San Jose built the South Hall, a $6.77 million, blue and white tent, adding 80,000 square feet (7,400 m2) of exhibit space", + "qas": [ + { + "question": "where is the businesses choosing to go?", + "id": "1" + }, + { + "question": "how may votes did the ballot measure need?", + "id": "2" + }, + { + "question": "By what year many Silicon Valley businesses were choosing the Moscone Center?", + "id": "3" + } + ] + } + ], + "title": "Conference Center" + } + ] +} \ No newline at end of file diff --git a/examples/nlp/python_bert_squad/inputs_amd.json b/examples/nlp/python_bert_squad/inputs_amd.json new file mode 100755 index 0000000000000000000000000000000000000000..4eb5182e302f48ce60405ad051d39320f59da6da --- /dev/null +++ b/examples/nlp/python_bert_squad/inputs_amd.json @@ -0,0 +1,26 @@ +{ + "data": [ + { + "paragraphs": [ + { + "context": "ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility.", + "qas": [ + { + "question": "What is ROCm?", + "id": "1" + }, + { + "question": "Which frameworks does ROCm support?", + "id": "2" + }, + { + "question": "What is ROCm built for?", + "id": "3" + } + ] + } + ], + "title": "AMD ROCm" + } + ] +} \ No newline at end of file diff --git a/examples/nlp/python_bert_squad/requirements_bertsquad.txt b/examples/nlp/python_bert_squad/requirements_bertsquad.txt new file mode 100644 index 0000000000000000000000000000000000000000..34bce678a5bc7070f55d04b72385be034c74a80a --- /dev/null +++ b/examples/nlp/python_bert_squad/requirements_bertsquad.txt @@ -0,0 +1,3 @@ +tensorflow==2.7.2 +onnxruntime +tokenizers \ No newline at end of file diff --git a/examples/nlp/python_bert_squad/run_onnx_squad.py b/examples/nlp/python_bert_squad/run_onnx_squad.py new file mode 100755 index 0000000000000000000000000000000000000000..1865721b4f3b84d206694d944eca522a7a33be09 --- /dev/null +++ b/examples/nlp/python_bert_squad/run_onnx_squad.py @@ -0,0 +1,619 @@ +# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Inference for squad/bert using onnx. + +This is going to do the samem as 'python run_squad.py --do_predict=True ...' using a squad/bert model +that was converted to onnx. Lots of code was taken from run_squad.py. +You run it with: + + +python onnx_squad.py --model $SQUAD_MODEL/squad.onnx \ + --vocab_file $BERT_BASE_DIR/uncased_L-12_H-768_A-12/vocab.txt + --predict_file $SQUAD_DATA/dev-v1.1.json \ + --bert_config_file $BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_config.json \ + --output /tmp/ +""" + +import argparse +import collections +import json +import math +import os +import sys +from timeit import default_timer as timer + +import numpy as np +import onnxruntime as onnxrt +import six +from tokenizers import BertWordPieceTokenizer +from tokenizers import pre_tokenizers + +RawResult = collections.namedtuple("RawResult", + ["unique_id", "start_logits", "end_logits"]) + +Feature = collections.namedtuple("Feature", [ + "unique_id", "tokens", "example_index", "token_to_orig_map", + "token_is_max_context" +]) + + +class SquadExample(object): + """A single training/test example for simple sequence classification.""" + def __init__(self, + qas_id, + question_text, + doc_tokens, + orig_answer_text=None, + start_position=None, + end_position=None): + self.qas_id = qas_id + self.question_text = question_text + self.doc_tokens = doc_tokens + self.orig_answer_text = orig_answer_text + self.start_position = start_position + self.end_position = end_position + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = [] + s.append("qas_id: %s" % (self.qas_id)) + s.append("question_text: %s" % (self.question_text)) + s.append("doc_tokens: [%s]" % (" ".join(self.doc_tokens))) + if self.start_position: + s.append("start_position: %d" % (self.start_position)) + if self.start_position: + s.append("end_position: %d" % (self.end_position)) + return ", ".join(s) + + +def _check_is_max_context(doc_spans, cur_span_index, position): + """Check if this is the 'max context' doc span for the token.""" + + # Because of the sliding window approach taken to scoring documents, a single + # token can appear in multiple documents. E.g. + # Doc: the man went to the store and bought a gallon of milk + # Span A: the man went to the + # Span B: to the store and bought + # Span C: and bought a gallon of + # ... + # + # Now the word 'bought' will have two scores from spans B and C. We only + # want to consider the score with "maximum context", which we define as + # the *minimum* of its left and right context (the *sum* of left and + # right context will always be the same, of course). + # + # In the example the maximum context for 'bought' would be span C since + # it has 1 left context and 3 right context, while span B has 4 left context + # and 0 right context. + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, + num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def convert_examples_to_features(examples, tokenizer, max_seq_length, + doc_stride, max_query_length): + """Loads a data file into a list of `InputBatch`s.""" + + res_input_ids = [] + res_input_mask = [] + res_segment_ids = [] + extra = [] + unique_id = 0 + + for (example_index, example) in enumerate(examples): + query_tokens = tokenizer.encode(example.question_text) + + if len(query_tokens) > max_query_length: + query_tokens = query_tokens[0:max_query_length] + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.encode(token, add_special_tokens=False) + for sub_token in sub_tokens.tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + # The -3 accounts for [CLS], [SEP] and [SEP] + max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 + + # We can have documents that are longer than the maximum sequence length. + # To deal with this we do a sliding window approach, where we take chunks + # of the up to our max length with a stride of `doc_stride`. + _DocSpan = collections.namedtuple("DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, doc_stride) + + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in query_tokens.tokens: + tokens.append(token) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[len( + tokens)] = tok_to_orig_index[split_token_index] + + is_max_context = _check_is_max_context(doc_spans, + doc_span_index, + split_token_index) + token_is_max_context[len(tokens)] = is_max_context + tokens.append(all_doc_tokens[split_token_index]) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + input_ids = [] + for token in tokens: + input_ids.append(tokenizer.token_to_id(token)) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + res_input_ids.append(np.array(input_ids, dtype=np.int64)) + res_input_mask.append(np.array(input_mask, dtype=np.int64)) + res_segment_ids.append(np.array(segment_ids, dtype=np.int64)) + feature = Feature(unique_id=unique_id, + tokens=tokens, + example_index=example_index, + token_to_orig_map=token_to_orig_map, + token_is_max_context=token_is_max_context) + extra.append(feature) + unique_id += 1 + return np.array(res_input_ids), np.array(res_input_mask), np.array( + res_segment_ids), extra + + +def read_squad_examples(input_file): + """Read a SQuAD json file into a list of SquadExample.""" + with open(input_file, "r") as f: + input_data = json.load(f)["data"] + + def is_whitespace(c): + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + return True + return False + + examples = [] + for idx, entry in enumerate(input_data): + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + doc_tokens = [] + char_to_word_offset = [] + prev_is_whitespace = True + for c in paragraph_text: + if is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + char_to_word_offset.append(len(doc_tokens) - 1) + + for qa in paragraph["qas"]: + qas_id = qa["id"] + question_text = qa["question"] + start_position = None + end_position = None + orig_answer_text = None + example = SquadExample(qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_position, + end_position=end_position) + examples.append(example) + return examples + + +def write_predictions(all_examples, all_features, all_results, n_best_size, + max_answer_length, do_lower_case, output_prediction_file, + output_nbest_file): + """Write final predictions to the json file.""" + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", [ + "feature_index", "start_index", "end_index", "start_logit", + "end_logit" + ]) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + prelim_predictions = [] + for (feature_index, feature) in enumerate(features): + if not feature.unique_id in unique_id_to_result: + print("feature not in unique_Id", feature.unique_id) + continue + result = unique_id_to_result[feature.unique_id] + + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get( + start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index])) + + prelim_predictions = sorted(prelim_predictions, + key=lambda x: + (x.start_logit + x.end_logit), + reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"]) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + + tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] + tok_text = " ".join(tok_tokens) + + # De-tokenize WordPieces that have been split off. + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, do_lower_case) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + nbest.append( + _NbestPrediction(text=final_text, + start_logit=pred.start_logit, + end_logit=pred.end_logit)) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append( + _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + assert len(nbest) >= 1 + + total_scores = [] + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = float(entry.start_logit) + output["end_logit"] = float(entry.end_logit) + nbest_json.append(output) + + all_predictions[example.qas_id] = nbest_json[0]["text"] + all_nbest_json[example.qas_id] = nbest_json + + with open(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + with open(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + +def get_final_text(pred_text, orig_text, do_lower_case): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heruistic between + # `pred_text` and `orig_text` to get a character-to-charcter alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = pre_tokenizers.Sequence( + [pre_tokenizers.Whitespace(), + pre_tokenizers.Punctuation()]) + + tok_text = [] + for item in tokenizer.pre_tokenize_str(orig_text): + tok_text.append(item[0]) + + tok_text = " ".join(tok_text) + + start_position = tok_text.find(pred_text) + if start_position == -1: + return orig_text + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for (i, tok_index) in six.iteritems(tok_ns_to_s_map): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + return orig_text + + output_text = orig_text[orig_start_position:(orig_end_position + 1)] + return output_text + + +def _get_best_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(enumerate(logits), + key=lambda x: x[1], + reverse=True) + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def main(): + parser = argparse.ArgumentParser(description='onnx squad') + parser.add_argument('--model', required=True, help='model') + parser.add_argument('--vocab_file', required=True, help='vocab_file') + parser.add_argument('--bert_config_file', help='vocab_file') + parser.add_argument('--predict_file', required=True, help='predict_file') + parser.add_argument('--output_dir', help='output dir') + parser.add_argument('--max_seq_length', + type=int, + default=256, + help='max_seq_length') + parser.add_argument('--max_query_length', + type=int, + default=64, + help='max_query_length') + parser.add_argument('--max_answer_length', + type=int, + default=30, + help='max_answer_length') + parser.add_argument('--n_best_size', + type=int, + default=20, + help='n_best_size') + parser.add_argument('--doc_stride', + type=int, + default=128, + help='doc_stride') + parser.add_argument('--batch_size', type=int, default=1, help='batch_size') + parser.add_argument('--profile', + action='store_true', + help='enable chrome timeline trace profiling.') + parser.add_argument('--log', type=int, help='log level.') + args = parser.parse_args() + + sess_options = None + if args.profile: + sess_options = onnxrt.SessionOptions() + sess_options.enable_profiling = True + sess_options.profile_file_prefix = os.path.basename(args.model) + if args.log: + sess_options = onnxrt.SessionOptions() + sess_options.session_log_verbosity_level = args.log + + tokenizer = BertWordPieceTokenizer(args.vocab_file) + + eval_examples = read_squad_examples(input_file=args.predict_file) + input_ids, input_mask, segment_ids, extra_data = \ + convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length, + args.doc_stride, args.max_query_length) + + sess = onnxrt.InferenceSession(args.model, sess_options) + for input_meta in sess.get_inputs(): + print(input_meta) + n = len(input_ids) + bs = args.batch_size + all_results = [] + start = timer() + for idx in range(0, n, bs): + data = { + "input_ids:0": input_ids[idx:idx + bs], + "input_mask:0": input_mask[idx:idx + bs], + "segment_ids:0": segment_ids[idx:idx + bs] + } + result = sess.run(["unstack:0", "unstack:1"], data) + in_batch = result[0].shape[1] + for i in range(0, in_batch): + unique_id = len(all_results) + all_results.append( + RawResult(unique_id=unique_id, + start_logits=result[0][0][i], + end_logits=result[1][0][i])) + if unique_id > 0 and unique_id % 100 == 0: + print("at {} {}sec per item".format( + unique_id, (timer() - start) / unique_id)) + end = timer() + + print("total time: {}sec, {}sec per item".format( + end - start, (end - start) / len(all_results))) + + if args.output_dir: + output_prediction_file = os.path.join(args.output_dir, + "predictions.json") + output_nbest_file = os.path.join(args.output_dir, + "nbest_predictions.json") + write_predictions(eval_examples, extra_data, all_results, + args.n_best_size, args.max_answer_length, True, + output_prediction_file, output_nbest_file) + if args.profile: + trace_file = sess.end_profiling() + print("trace file written to: {}".format(trace_file)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/vision/README.md b/examples/vision/README.md new file mode 100644 index 0000000000000000000000000000000000000000..47c7a53e434fe0651de2629f49daa0a92b4cfb3f --- /dev/null +++ b/examples/vision/README.md @@ -0,0 +1,8 @@ +# Vision Inference Examples + +- [CPP MNIST](./cpp_mnist) +- [Python Resnet50](./python_resnet50) +- [Python Super Resolution](./python_super_resolution) +- [Python NFNet](./python_nfnet) +- [Python U-Net](./python_unet) +- [Python 3D-UNet](./python_3dunet) \ No newline at end of file diff --git a/examples/vision/cpp_mnist/CMakeLists.txt b/examples/vision/cpp_mnist/CMakeLists.txt new file mode 100755 index 0000000000000000000000000000000000000000..8b7f2a6c058f80bce772653d3332ef1454769180 --- /dev/null +++ b/examples/vision/cpp_mnist/CMakeLists.txt @@ -0,0 +1,13 @@ +cmake_minimum_required(VERSION 3.5) +project (CAI) + +set (CMAKE_CXX_STANDARD 14) +set (EXAMPLE mnist_inference) + +list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) +find_package (migraphx) + +message("source file: " ${EXAMPLE}.cpp " ---> bin: " ${EXAMPLE}) +add_executable(${EXAMPLE} ${EXAMPLE}.cpp) + +target_link_libraries(${EXAMPLE} migraphx::c) diff --git a/examples/vision/cpp_mnist/README.md b/examples/vision/cpp_mnist/README.md new file mode 100755 index 0000000000000000000000000000000000000000..1c4cd4c01d6e051ccaeb77dff2c335bdd6571c1a --- /dev/null +++ b/examples/vision/cpp_mnist/README.md @@ -0,0 +1,181 @@ +# Performing Inference Using C++ API + +## Description +This example demonstrates how to perform inference using the MIGraphX C++ API. The model used is a convolutional network pre-trained on the MNIST dataset, and inference is performed on a random digit selected from the test set. + +## Content +- [Basic Setup](#Basic-Setup) +- [Quantization](#Quantization) +- [Compilation](#Compilation) +- [Preparing Input Data](#Preparing-Input-Data) +- [Evaluating Inputs and Handling Outputs](#Evaluating-Inputs-and-Handling-Outputs) +- [**Running this Example**](#Running-this-Example) + +## Basic Setup +Before running inference, we must first instantiate a network graph and select a compilation target. See [this example](../cpp_parse_load_save) for more information about working with MIGraphX program objects. +``` +migraphx::program prog; +migraphx::onnx_options onnx_opts; +prog = parse_onnx("../mnist-8.onnx", onnx_opts); + +std::string target_str; +if(CPU) + target_str = "cpu"; +else if(GPU) + target_str = "gpu"; +else + target_str = "ref"; +migraphx::target targ = migraphx::target(target_str.c_str()); +``` + +## Quantization +Optionally, graph programs may be quantized to fp16 or int8 precision to improve performance and memory usage. + +##### Floating Point 16-bit Precision +To quantize using fp16, we simply add the following line: +``` +migraphx::quantize_fp16(prog); +``` + +##### Integer 8-bit Precision +Int8 quantization requires calibration to accurately map ranges of floating point values onto integer values. + +To calibrate prior to inference, one or more inputs can be supplied as follows: +``` +std::vector calib_dig; +// ... read in data + +migraphx::quantize_int8_options quant_opts; +migraphx::program_parameters quant_params; +auto param_shapes = prog.get_parameter_shapes(); +for(auto&& name : param_shapes.names()) +{ + quant_params.add(name, migraphx::argument(param_shapes[name], calib_dig.data())); +} + +quant_opts.add_calibration_data(quant_params); +migraphx::quantize_int8(prog, targ, quant_opts); +``` + +## Compilation +Network graphs saved in e.g. ONNX or protobuf format are not target-specific. In order to run inference, we must compile the graph into a target-specific program. + +Two options may be turned on when compiling: +- `set_offload_copy(bool value)`: For targets with offloaded memory (such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. Default value is `false` for offload_copy. +- `set_fast_math(bool value)`: Optimize math functions to use faster approximate versions. There may be slight accuracy degredation when enabled. Default value is `true` for fast_math. + +The following snippet assumes `targ` has been set as "gpu", and will compile the program without the fast_math optimization. +``` +migraphx::compile_options comp_opts; +comp_opts.set_offload_copy(); +prog.compile(targ, comp_opts); +``` + +To compile a program with the default options, we simply call: +``` +prog.compile(targ); +``` + +The targets "ref" and "cpu" both compile the program to run on the CPU. The target "ref" is primarily used for correctness checking. The target "cpu" is under ongoing development and has more optimizations enabled. Additionally, the "cpu" target requires MIGraphX to be built with the `-DMIGRAPHX_ENABLE_CPU=On` flag. Specifically, +``` +CXX=/opt/rocm/llvm/bin/clang++ cmake -DMIGRAPHX_ENABLE_CPU=On .. +``` + +## Preparing Input Data +Now that we have a compiled program, the last step to perform infernce is to prepare the input data as program parameters. +The first step is to read in the data and store it in a `std::vector` we will in this case call `digit`. +Next, we create a program parameter containing the data stored in `digit`: +``` +migraphx::program_parameters prog_params; +auto param_shapes = prog.get_parameter_shapes(); +for(auto&& name : param_shapes.names()) +{ + prog_params.add(name, migraphx::argument(param_shapes[name], digit.data())); +} +``` + +## Evaluating Inputs and Handling Outputs +Now that everything is in place, the final step to run inference is to call: +``` +auto outputs = prog.eval(prog_params); +``` + +The output layer(s) will be returned and stored in `outputs`. Our network for this example returns a single output layer with the shape (1, 10). The index of the largest value in this output layer corresponds to the digit that the model has predicted. +``` +auto shape = outputs[0].get_shape(); +auto lengths = shape.lengths(); +auto num_results = std::accumulate(lengths.begin(), lengths.end(), 1, std::multiplies(); +float* results = reinterpret_cast(outputs[0].data()); +float* max = std::max_element(results, results + num_results); +int answer = max - results; +``` + +Other networks may require alternative processing of outputs. + + +## Running this Example +This directory contains everything that is needed to perform inference on an MNIST digit. To create the executable: +``` +$ mkdir build +$ cd build +$ CXX=/opt/rocm/llvm/bin/clang++ cmake .. +$ make +``` +There will now be an executable named `mnist_inference` in the `build` directory. This can be run with or without options. Executing without any options will produce the following output: +``` +Usage: ./mnist_inference [options] +options: + -c, --cpu Compile for CPU + -g, --gpu Compile for GPU + -f, --fp16 FP16 Quantization + -i, --int8 Int8 Quantization + --cal Int8 Calibration ON + -p, --print Print Graph at Each Stage + + +Parsing ONNX model... + +Compiling program for ref... + +Model input: +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@%=@@@@@@@@@ +@@@@@@@@@@@@@0+. +@@@@@@@@ +@@@@@@@@@@@0+ .. 0@@@@@@@ +@@@@@@@@@@+ .00 #@@@@@@@ +@@@@@@@@@% .0@0 #@@@@@@@ +@@@@@@@@@- .*0@@% #@@@@@@@ +@@@@@@@@@0+#@@@@@% #@@@@@@@ +@@@@@@@@@@@@@@@@@* #@@@@@@@ +@@@@@@@@@@@@@====- -@@@@@@@@ +@@@@@@@@@@@#- .0@@@@@@@@ +@@@@@@@@@#. .* =@@@@@@@@ +@@@@@@@@% =#@@. %@@@@@@@ +@@@@@@@+ -@@@- +* -#00@@@ +@@@@@@+ =@@#- .#@@#* .@@@ +@@@@@= %@#* =0@@@@@%--0@@@ +@@@@@ .. =@@@@@@@@@@@@@@ +@@@@@. *=0@@@@@@@@@@@@@@@ +@@@@@@%+=@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + +Model evaluating input... +Inference complete +Inference time: 0.022ms + +Randomly chosen digit: 2 +Result from inference: 2 + +CORRECT + +``` + +*Note: the actual digit selected and printed will not necessarily be the same as shown above. \ No newline at end of file diff --git a/examples/vision/cpp_mnist/digits.txt b/examples/vision/cpp_mnist/digits.txt new file mode 100755 index 0000000000000000000000000000000000000000..dab490e95c781aeeb2ae986a78f481a597e8fa4f Binary files /dev/null and b/examples/vision/cpp_mnist/digits.txt differ diff --git a/examples/vision/cpp_mnist/mnist-7.onnx b/examples/vision/cpp_mnist/mnist-7.onnx new file mode 100755 index 0000000000000000000000000000000000000000..bb189a52ea6ed40f3d5f94f0834f1fdab49022fe Binary files /dev/null and b/examples/vision/cpp_mnist/mnist-7.onnx differ diff --git a/examples/vision/cpp_mnist/mnist-8.onnx b/examples/vision/cpp_mnist/mnist-8.onnx new file mode 100755 index 0000000000000000000000000000000000000000..fc1a3f733c6e6243dd23dacb125b7a372de55a50 Binary files /dev/null and b/examples/vision/cpp_mnist/mnist-8.onnx differ diff --git a/examples/vision/cpp_mnist/mnist_inference.cpp b/examples/vision/cpp_mnist/mnist_inference.cpp new file mode 100755 index 0000000000000000000000000000000000000000..ecde0aa7c32b64ab90f980dd4b27a309cd593340 --- /dev/null +++ b/examples/vision/cpp_mnist/mnist_inference.cpp @@ -0,0 +1,184 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void read_nth_digit(const int, std::vector&); + +int main(int argc, char** argv) +{ + if(argc == 1) + { + std::cout << "Usage: " << argv[0] << " [options]" << std::endl + << "options:" << std::endl + << "\t -c, --cpu Compile for CPU" << std::endl + << "\t -g, --gpu Compile for GPU" << std::endl + << "\t -f, --fp16 FP16 Quantization" << std::endl + << "\t -i, --int8 Int8 Quantization" << std::endl + << "\t --cal Int8 Calibration ON" << std::endl + << "\t -p, --print Print Graph at Each Stage" << std::endl + << std::endl + << std::endl; + } + + char** begin = argv + 1; + char** end = argv + argc; + const bool CPU = (std::find(begin, end, std::string("-c")) != end) || + std::find(begin, end, std::string("--cpu")) != end; + const bool GPU = std::find(begin, end, std::string("-g")) != end || + std::find(begin, end, std::string("--gpu")) != end; + const bool FP16 = std::find(begin, end, std::string("-f")) != end || + std::find(begin, end, std::string("--fp16")) != end; + const bool INT8 = std::find(begin, end, std::string("-i")) != end || + std::find(begin, end, std::string("--int8")) != end; + const bool CALIB = std::find(begin, end, std::string("--cal")) != end; + const bool PRINT = std::find(begin, end, std::string("-p")) != end || + std::find(begin, end, std::string("--print")) != end; + + migraphx::program prog; + migraphx::onnx_options onnx_opts; + prog = parse_onnx("../mnist-8.onnx", onnx_opts); + + std::cout << "Parsing ONNX model..." << std::endl; + if(PRINT) + prog.print(); + std::cout << std::endl; + + std::string target_str; + if(CPU) + target_str = "cpu"; + else if(GPU) + target_str = "gpu"; + else + target_str = "ref"; + migraphx::target targ = migraphx::target(target_str.c_str()); + + if(FP16) + { + migraphx::quantize_fp16(prog); + + std::cout << "Quantizing program for FP16..." << std::endl; + if(PRINT) + prog.print(); + std::cout << std::endl; + } + else if(INT8) + { + if(CALIB) + { + std::cout << "Calibration data: " << std::endl; + std::vector calib_dig; + read_nth_digit(9, calib_dig); + + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + auto param_shapes = prog.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + quant_params.add(name, migraphx::argument(param_shapes[name], calib_dig.data())); + } + + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_int8(prog, targ, quant_opts); + } + else + { + migraphx::quantize_int8(prog, targ, migraphx::quantize_int8_options()); + } + + std::cout << "Quantizing program for INT8..." << std::endl; + if(PRINT) + prog.print(); + std::cout << std::endl; + } + + if(GPU) + { + migraphx::compile_options comp_opts; + comp_opts.set_offload_copy(); + prog.compile(targ, comp_opts); + } + else + { + prog.compile(targ); + } + + std::cout << "Compiling program for " << target_str << "..." << std::endl; + if(PRINT) + prog.print(); + std::cout << std::endl; + + std::vector digit; + std::random_device rd; + std::uniform_int_distribution dist(0, 9); + const int rand_digit = dist(rd); + std::cout << "Model input: " << std::endl; + read_nth_digit(rand_digit, digit); + + migraphx::program_parameters prog_params; + auto param_shapes = prog.get_parameter_shapes(); + auto input = param_shapes.names().front(); + prog_params.add(input, migraphx::argument(param_shapes[input], digit.data())); + + std::cout << "Model evaluating input..." << std::endl; + auto start = std::chrono::high_resolution_clock::now(); + auto outputs = prog.eval(prog_params); + auto stop = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(stop - start); + std::cout << "Inference complete" << std::endl; + std::cout << "Inference time: " << elapsed.count() * 1e-3 << "ms" << std::endl; + + auto shape = outputs[0].get_shape(); + auto lengths = shape.lengths(); + auto num_results = + std::accumulate(lengths.begin(), lengths.end(), 1, std::multiplies()); + float* results = reinterpret_cast(outputs[0].data()); + float* max = std::max_element(results, results + num_results); + int answer = max - results; + + std::cout << std::endl + << "Randomly chosen digit: " << rand_digit << std::endl + << "Result from inference: " << answer << std::endl + << std::endl + << (answer == rand_digit ? "CORRECT" : "INCORRECT") << std::endl + << std::endl; + + return 0; +} + +void read_nth_digit(const int n, std::vector& digit) +{ + const std::string SYMBOLS = "@0#%=+*-. "; + std::ifstream file("../digits.txt"); + const int DIGITS = 10; + const int HEIGHT = 28; + const int WIDTH = 28; + + if(!file.is_open()) + { + return; + } + + for(int d = 0; d < DIGITS; ++d) + { + for(int i = 0; i < HEIGHT * WIDTH; ++i) + { + unsigned char temp = 0; + file.read((char*)&temp, sizeof(temp)); + if(d == n) + { + float data = temp / 255.0; + digit.push_back(data); + std::cout << SYMBOLS[(int)(data * 10) % 11]; + if((i + 1) % WIDTH == 0) + std::cout << std::endl; + } + } + } + std::cout << std::endl; +} diff --git a/examples/vision/python_3dunet/3dunet_inference.ipynb b/examples/vision/python_3dunet/3dunet_inference.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..a01f679ad1b541f483a14c2477ecfeb3cf999ece --- /dev/null +++ b/examples/vision/python_3dunet/3dunet_inference.ipynb @@ -0,0 +1,643 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fee8cfa5", + "metadata": {}, + "source": [ + "# 3D-UNet Example with MIGraphX\n", + "References:
\n", + "https://github.com/naomifridman/Unet_Brain_tumor_segmentation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09ceec31", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install SimpleITK matplotlib scikit-image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb22bcc4", + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx\n", + "from PIL import Image\n", + "import numpy as np\n", + "import os\n", + "import SimpleITK as sitk" + ] + }, + { + "cell_type": "markdown", + "id": "cb973c63", + "metadata": {}, + "source": [ + "## Fetch U-NET ONNX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1928662c", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx" + ] + }, + { + "cell_type": "markdown", + "id": "1a64a616", + "metadata": {}, + "source": [ + "## Load ONNX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53928a98", + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"224_224_160.onnx\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27e8587f", + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(migraphx.get_target(\"gpu\"))" + ] + }, + { + "cell_type": "markdown", + "id": "2f6014a4", + "metadata": {}, + "source": [ + "## Print model parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e73728c", + "metadata": {}, + "outputs": [], + "source": [ + "print(model.get_parameter_names())\n", + "print(model.get_parameter_shapes())\n", + "print(model.get_output_shapes())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4cac52e", + "metadata": {}, + "outputs": [], + "source": [ + "img_type=['FLAIR', 'T1','T1CE', 'T2']\n", + "label_type_shrt = ['background', 'necrotic',\n", + " 'edema', 'enhancing']\n", + "label_type = ['background', 'necrotic and non-enhancing tumor', 'edema', 'enhancing tumor']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b65f9297", + "metadata": {}, + "outputs": [], + "source": [ + "red_multiplier = [1, 0.2, 0.2]\n", + "green_multiplier = [0.35,0.75,0.25]\n", + "blue_multiplier = [0,0.5,1.]#[0,0.25,0.9]\n", + "yellow_multiplier = [1,1,0.25]\n", + "brown_miltiplier = [40./255, 26./255, 13./255]\n", + "my_colors=[blue_multiplier, yellow_multiplier, brown_miltiplier]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e175ac5", + "metadata": {}, + "outputs": [], + "source": [ + "from importlib import reload # Python 3.4+ only." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "530e4f97", + "metadata": {}, + "outputs": [], + "source": [ + "import visualization_utils as vu\n", + "from visualization_utils import show_label_on_image4\n", + "reload(vu)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "865c46a2", + "metadata": {}, + "outputs": [], + "source": [ + "def show_img_label(img, lbl, modality = 0):\n", + " \n", + " if (len(lbl.shape)> 2):\n", + " lbl[0,0,3]=1 # for uniqe colors in plot\n", + " lbl = lbl_from_cat(lbl)\n", + " vu.show_n_images([img[:,:,modality],lbl, show_label_on_image4(img[:,:,modality],lbl)],\n", + " titles = [img_type[modality], 'Label', 'Label on '+ img_type[modality]]);\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e926482", + "metadata": {}, + "outputs": [], + "source": [ + "def read_img_sitk(img):\n", + " inputImage = sitk.ReadImage( img )\n", + " inputImage = sitk.Cast( inputImage, sitk.sitkFloat32 )\n", + " image = sitk.GetArrayFromImage(inputImage)\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b620138", + "metadata": {}, + "outputs": [], + "source": [ + "# ima files are of the form\n", + "# BraTS19_TCIA04_192_1_flair.nii.gz \n", + "# BraTS19_TCIA04_192_1_t1.nii.gz \n", + "# BraTS19_TCIA04_192_1_t2.nii.gz\n", + "# BraTS19_TCIA04_192_1_seg.nii.gz \n", + "# BraTS19_TCIA04_192_1_t1ce.nii.gz\n", + "\n", + "def read_image_into_numpy(dirpath):\n", + " \n", + " img_id = os.path.basename(dirpath)\n", + " np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n", + " \n", + " ## Flair\n", + " flair_img = os.path.join(dirpath, img_id+'_flair.nii.gz')\n", + " if (not os.path.isfile(flair_img)):\n", + " print(flair_img,' not found aborting')\n", + " return None\n", + " np_image[0] = read_img_sitk(flair_img)\n", + " \n", + " ## T1\n", + " t1_nb4_img = os.path.join(dirpath, img_id+'_t1_nb4.nii.gz')\n", + " if (not os.path.isfile(t1_nb4_img)):\n", + " #print(t1_nb4_img,' not found')\n", + " t1_img = os.path.join(dirpath, img_id+'_t1.nii.gz')\n", + " if (not os.path.isfile(t1_img)):\n", + " print(t1_img,' not found aborting')\n", + " return None\n", + " np_image[1] = read_img_sitk(t1_img)\n", + " else:\n", + " np_image[1] = read_img_sitk(t1_nb4_img) \n", + " \n", + " ## T1CE\n", + " t1ce_nb4_img = os.path.join(dirpath, img_id+'_t1ce_nb4.nii.gz')\n", + " if (not os.path.isfile(t1ce_nb4_img)):\n", + " #print(t1ce_nb4_img,' not found')\n", + " t1ce_img = os.path.join(dirpath, img_id+'_t1ce.nii.gz')\n", + " if (not os.path.isfile(t1ce_img)):\n", + " print(t1ce_img,' not found aborting')\n", + " return None\n", + " np_image[2] = read_img_sitk(t1ce_img)\n", + " else:\n", + " np_image[2] = read_img_sitk(t1ce_nb4_img) \n", + " \n", + " \n", + " ## T2\n", + " t2_img = os.path.join(dirpath, img_id+'_t2.nii.gz')\n", + " if (not os.path.isfile(t2_img)):\n", + " print(t2_img,' not found aborting')\n", + " return None\n", + " np_image[3] = read_img_sitk(t2_img)\n", + "\n", + " return np_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2fb66f17", + "metadata": {}, + "outputs": [], + "source": [ + "def read_label_into_numpy(dirpath):\n", + " \n", + " img_id = os.path.basename(dirpath)\n", + " np_image=np.zeros((160, 224, 224), dtype=np.int)\n", + " \n", + " ## label\n", + " label_img = os.path.join(dirpath, img_id+'_seg.nii.gz')\n", + " if (not os.path.isfile(label_img)):\n", + " print(label_img,' not found aborting')\n", + " return None\n", + " np_image = read_img_sitk(label_img).astype(int)\n", + "\n", + " return np_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "558d47b9", + "metadata": {}, + "outputs": [], + "source": [ + "def bbox2_3D(img):\n", + "\n", + " r = np.any(img, axis=(1, 2))\n", + " c = np.any(img, axis=(0, 2))\n", + " z = np.any(img, axis=(0, 1))\n", + "\n", + " rmin, rmax = np.where(r)[0][[0, -1]]\n", + " cmin, cmax = np.where(c)[0][[0, -1]]\n", + " zmin, zmax = np.where(z)[0][[0, -1]]\n", + "\n", + " return [rmin, rmax, cmin, cmax, zmin, zmax]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1405e186", + "metadata": {}, + "outputs": [], + "source": [ + "def lbl_from_cat(cat_lbl):\n", + " \n", + " lbl=0\n", + " if (len(cat_lbl.shape)==3):\n", + " for i in range(1,4):\n", + " lbl = lbl + cat_lbl[:,:,i]*i\n", + " elif (len(cat_lbl.shape)==4):\n", + " for i in range(1,4):\n", + " lbl = lbl + cat_lbl[:,:,:,i]*i\n", + " else:\n", + " print('Error in lbl_from_cat', cat_lbl.shape)\n", + " return None\n", + " return lbl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24eb472f", + "metadata": {}, + "outputs": [], + "source": [ + "def show_label(lbl):\n", + " vu.show_n_images([lbl[:,:,k] for k in range(4)]+[lbl_from_cat(lbl)],\n", + " titles = label_type_shrt + ['Label'])\n", + "\n", + "def show_pred_im_label(im, lb, pred):\n", + " \n", + " vu.show_n_images([im[:,:,1], lb[:,:], \n", + " show_label_on_image4(im[:,:,1], lb[:,:]),\n", + " show_label_on_image4(im[:,:,1], pred[:,:])],\n", + " titles=['Flair', 'Label', 'Label on T1', 'Prediction on Flair'])\n", + "\n", + "def show_pred_im(im, pred):\n", + " \n", + " vu.show_n_images([im[:,:,1], \n", + " im[:,:,0],pred,\n", + " show_label_on_image4(im[:,:,1], pred[:,:])],\n", + " titles=['Flair','T1', 'Pred', 'Prediction on Flair'])" + ] + }, + { + "cell_type": "markdown", + "id": "d15f788b", + "metadata": {}, + "source": [ + "Multiple image inputs:\n", + "- Native (T1)\n", + "- Post-contrast T1-weighted (T1Gd)\n", + "- T2-weighted (T2)\n", + "- T2 Fluid Attenuated Inversion Recovery (T2-FLAIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a7aad87", + "metadata": {}, + "outputs": [], + "source": [ + "# Resize input images\n", + "from scipy.ndimage import zoom\n", + "\n", + "def resize(img, shape, mode='constant', orig_shape=(155, 240, 240)):\n", + " \"\"\"\n", + " Wrapper for scipy.ndimage.zoom suited for MRI images.\n", + " \"\"\"\n", + " assert len(shape) == 3, \"Can not have more than 3 dimensions\"\n", + " factors = (\n", + " shape[0]/orig_shape[0],\n", + " shape[1]/orig_shape[1], \n", + " shape[2]/orig_shape[2]\n", + " )\n", + " \n", + " # Resize to the given shape\n", + " return zoom(img, factors, mode=mode)\n", + "\n", + "def preprocess_label(img, out_shape=None, mode='nearest'):\n", + " \"\"\"\n", + " Separates out the 3 labels from the segmentation provided, namely:\n", + " GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))\n", + " and the necrotic and non-enhancing tumor core (NCR/NET — label 1)\n", + " \"\"\"\n", + " ncr = img == 1 # Necrotic and Non-Enhancing Tumor (NCR/NET)\n", + " \n", + " ed = img == 2 # Peritumoral Edema (ED)\n", + " et = img == 4 # GD-enhancing Tumor (ET)\n", + " \n", + " if out_shape is not None:\n", + " ncr = resize(ncr, out_shape, mode=mode)\n", + " ed = resize(ed, out_shape, mode=mode)\n", + " et = resize(et, out_shape, mode=mode)\n", + " return np.array([ncr, ed, et], dtype=np.uint8)\n", + "\n", + "hgg_path = \"/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG\"\n", + "np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n", + "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_flair.nii.gz'%hgg_path)\n", + "tmp = resize(tmp, [160,224,224])\n", + "mean = tmp.mean()\n", + "std = tmp.std()\n", + "np_image[0] = (tmp - mean) / std\n", + "\n", + "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1.nii.gz'%hgg_path)\n", + "tmp = resize(tmp, [160,224,224])\n", + "mean = tmp.mean()\n", + "std = tmp.std()\n", + "np_image[1] = (tmp - mean) / std\n", + "\n", + "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1ce.nii.gz'%hgg_path)\n", + "tmp = resize(tmp, [160,224,224])\n", + "mean = tmp.mean()\n", + "std = tmp.std()\n", + "np_image[2] = (tmp - mean) / std\n", + "\n", + "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t2.nii.gz'%hgg_path)\n", + "tmp = resize(tmp, [160,224,224])\n", + "mean = tmp.mean()\n", + "std = tmp.std()\n", + "np_image[3] = (tmp - mean) / std\n", + "\n", + "print(np_image.shape)\n", + "np_image_tmp = np_image.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7e5b3c6", + "metadata": {}, + "outputs": [], + "source": [ + "vu.show_n_images(np_image[:,100,:,:], titles=img_type)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19117da5", + "metadata": {}, + "outputs": [], + "source": [ + "np_lbl=np.zeros((160, 224, 224), dtype=np.int)\n", + "tmp = read_img_sitk('/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_seg.nii.gz').astype(int)\n", + "tmp = resize(tmp, [160,224,224])\n", + "print(tmp.shape)\n", + "np_lbl = tmp.astype(int)\n", + "print(np_lbl.shape)\n", + "\n", + "print(np_image.shape)\n", + "\n", + "img1 = vu.show_label_on_image4(np_image[1,100,:,:], np_lbl[100])\n", + "img2 = vu.show_label_on_image(np_image[1,100,:,:], np_lbl[100])\n", + "vu.show_n_images([img1,img2,np_image[0,100]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "facdea15", + "metadata": {}, + "outputs": [], + "source": [ + "def get_pred(img, threshold=0.5):\n", + " out_img=img.copy()\n", + " out_img=np.where(out_img>threshold, 1,0)\n", + " return out_img\n", + "\n", + "def prediction_from_probabily_3D(img):\n", + " \n", + " int_image = get_pred(img)\n", + " return lbl_from_cat(int_image)\n", + "\n", + "def get_prediction_for_batch(pred_batch, threshold=0.5):\n", + " \n", + " out_batch = np.zeros((pred_batch.shape[0], 224, 224),dtype=np.int)\n", + " \n", + " for j in range(pred_batch.shape[0]):\n", + " pred = get_prediction(pred_batch[j])\n", + " if (pred.sum()>0):\n", + " print(j, np.unique(pred , return_counts=True))\n", + " out_batch[j] = lbl_from_cat(get_prediction(pred_batch[j]))\n", + " return out_batch\n", + "\n", + "def get_label_from_pred_batch(labels_batch):\n", + " \n", + " batch = np.zeros((labels_batch.shape[0], 224, 224), np.uint8)\n", + " \n", + " for j in range(labels_batch.shape[0]):\n", + " batch[j]=get_pred(labels_batch[j,:,:,0])+\\\n", + " get_pred(labels_batch[j,:,:,1])*2+\\\n", + " get_pred(labels_batch[j,:,:,2])*4\n", + "\n", + " return batch\n", + "\n", + "def predict_3D_img_prob(np_file):\n", + " \n", + " np_img = np.load(np_file)\n", + " for_pred_img = np.zeros((160, 224, 224, 4), np.float32)\n", + "\n", + " # Normalize image\n", + " for_pred_img = normalize_3D_image(np_img)\n", + "\n", + " mdl_pred_img = model.predict(for_pred_img)\n", + "\n", + " #pred_label = prediction_from_probabily_3D(mdl_pred_img)\n", + "\n", + " return mdl_pred_img\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f7fe7ee", + "metadata": {}, + "outputs": [], + "source": [ + "#Remember the MIGraphX model inputs\n", + "print(model.get_parameter_names())\n", + "print(model.get_parameter_shapes())\n", + "\n", + "np_image = np_image.transpose((0,2,3,1))\n", + "\n", + "print(np_image.shape)\n", + "print(np_image.strides)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfc47b53", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_3D_image(img):\n", + " for z in range(img.shape[0]):\n", + " for k in range(4):\n", + " if (img[z,:,:,k].max()>0):\n", + " img[z,:,:,k] /= img[z,:,:,k].max()\n", + " return img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f990cb50", + "metadata": {}, + "outputs": [], + "source": [ + "print(np_image_tmp.shape)\n", + "np_image_tmp = np_image_tmp.transpose((1,2,3,0))\n", + "print(np_image_tmp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24c3736d", + "metadata": {}, + "outputs": [], + "source": [ + "np_image = np.expand_dims(np_image, 0)\n", + "print(np_image.shape)\n", + "print(np_image.strides)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1aac6285", + "metadata": {}, + "outputs": [], + "source": [ + "input_im = np.zeros((1,4,224,224,160),dtype='float32')\n", + "np.lib.stride_tricks.as_strided(input_im, shape=np_image.shape, strides=input_im.strides)[:] = np_image #getting correct stride\n", + "print(input_im.strides)\n", + "print(input_im.shape)\n", + "\n", + "#input_im = normalize_3D_image(input_im)\n", + "\n", + "print(input_im.strides)\n", + "print(input_im.shape)\n", + "\n", + "result = model.run({\n", + " \"input\": input_im\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5848b63d", + "metadata": {}, + "outputs": [], + "source": [ + "output = np.array(result[0])\n", + "print(output.shape)\n", + "output = output[0]\n", + "print(output.shape)\n", + "output = output.transpose((3,1,2,0))\n", + "print(output.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab77f7e9", + "metadata": {}, + "outputs": [], + "source": [ + "out = prediction_from_probabily_3D(output)\n", + "print(np_image_tmp.shape)\n", + "print(np_lbl.shape)\n", + "print(out.shape)\n", + "print(np.unique(out))\n", + "ind=[100]\n", + "for i in ind:\n", + " show_label(output[i])\n", + " show_label(get_pred(output[i]))\n", + " show_pred_im_label(np_image_tmp[i], np_lbl[i], out[i])" + ] + }, + { + "cell_type": "markdown", + "id": "d2862d81", + "metadata": {}, + "source": [ + "The possible prediction discrepancy is due to the not-perfect resizing 3D input image, as BRATS dataset has 3D images of size 160x240x240, meanwhile the ONNX model utilized here requires 155x224x224. This example is representative for how to utilize MIGraphX for such an application. All data processing should follow and match the model requirements otherwise. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/vision/python_3dunet/README.md b/examples/vision/python_3dunet/README.md new file mode 100755 index 0000000000000000000000000000000000000000..2155aef5663b31d1f43b82522d9e906a654d2a39 --- /dev/null +++ b/examples/vision/python_3dunet/README.md @@ -0,0 +1,7 @@ +# 3D-Unet Inference with AMD MIGraphX + +This example applies image segmentation to 3D images using AMD MIGraphX on a given AMD GPU. + +## How to: +1) User will need to have access to the BRATS dataset. Please follow https://www.med.upenn.edu/cbica/brats2019/data.html for how to get access to the dataset. +2) Follow the provided notebook `3dunet_inference.ipynb`. diff --git a/examples/vision/python_3dunet/visualization_utils.py b/examples/vision/python_3dunet/visualization_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..8372c6b44c1d3bb94adc8a8e1355ba87ec734450 --- /dev/null +++ b/examples/vision/python_3dunet/visualization_utils.py @@ -0,0 +1,118 @@ +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.pylab as pylab +import numpy as np + +params = { + 'legend.fontsize': 'x-large', + 'figure.figsize': (6, 5), + 'axes.labelsize': 'x-large', + 'axes.titlesize': 'x-large', + 'xtick.labelsize': 'x-large', + 'ytick.labelsize': 'x-large' +} +pylab.rcParams.update(params) + + +#----------------------------------------------------------- +def show_n_images(imgs, titles=None, enlarge=20, cmap='jet'): + + plt.set_cmap(cmap) + n = len(imgs) + gs1 = gridspec.GridSpec(1, n) + + fig1 = plt.figure() + # create a figure with the default size + fig1.set_size_inches(enlarge, 2 * enlarge) + + for i in range(n): + + ax1 = fig1.add_subplot(gs1[i]) + + ax1.imshow(imgs[i], interpolation='none') + if (titles is not None): + ax1.set_title(titles[i]) + ax1.set_ylim(ax1.get_ylim()[::-1]) + + plt.show() + + +#-------------------------------------------------------------- +from skimage import color, img_as_float +from skimage.exposure import adjust_gamma + + +# Creates an image of original brain with segmentation overlay +def show_label_on_image(test_img, test_lbl): + + label_im = test_lbl + + ones = np.argwhere(label_im == 1) + twos = np.argwhere(label_im == 2) + threes = np.argwhere(label_im == 3) + fours = np.argwhere(label_im == 4) + + gray_img = img_as_float(test_img / test_img.max()) + + # adjust gamma of image + # print(color.gray2rgb(gray_img)) + image = adjust_gamma(np.abs(color.gray2rgb(gray_img)), 0.45) + #sliced_image = image.copy() + + green_multiplier = [0.35, 0.75, 0.25] + blue_multiplier = [0, 0.5, 1.] #[0,0.25,0.9] + yellow_multiplier = [1, 1, 0.25] + brown_miltiplier = [40. / 255, 26. / 255, 13. / 255] + + # change colors of segmented classes + for i in range(len(ones)): + image[ones[i][0]][ones[i][1]] = blue_multiplier + for i in range(len(twos)): + image[twos[i][0]][twos[i][1]] = yellow_multiplier + for i in range(len(threes)): + image[threes[i][0]][threes[i][1]] = brown_miltiplier #blue_multiplier + for i in range(len(fours)): + image[fours[i][0]][fours[i][1]] = green_multiplier #yellow_multiplier + + return image + + +#------------------------------------------------------------------------------------- +def show_label_on_image4(test_img, label_im): + + alpha = 0.8 + + img = img_as_float(test_img / test_img.max()) + rows, cols = img.shape + + # Construct a colour image to superimpose + color_mask = np.zeros((rows, cols, 3)) + green_multiplier = [0.35, 0.75, 0.25] + blue_multiplier = [0, 0.25, 0.9] + yellow_multiplier = [1, 1, 0.25] + brown_miltiplier = [40. / 255, 26. / 255, 13. / 255] + + color_mask[label_im == 1] = blue_multiplier #[1, 0, 0] # Red block + color_mask[label_im == 2] = yellow_multiplier #[0, 1, 0] # Green block + color_mask[label_im == 3] = brown_miltiplier #[0, 0, 1] # Blue block + color_mask[label_im == 4] = green_multiplier #[0, 1, 1] # Blue block + + # Construct RGB version of grey-level image + img_color = np.dstack((img, img, img)) + + # Convert the input image and color mask to Hue Saturation Value (HSV) + # colorspace + img_hsv = color.rgb2hsv(img_color) + color_mask_hsv = color.rgb2hsv(color_mask) + + # Replace the hue and saturation of the original image + # with that of the color mask + img_hsv[..., 0] = color_mask_hsv[..., 0] + img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha + + img_masked = color.hsv2rgb(img_hsv) + + return img_masked + + +#------------------------------------------------------------------------------ diff --git a/examples/vision/python_nfnet/README.md b/examples/vision/python_nfnet/README.md new file mode 100755 index 0000000000000000000000000000000000000000..89dc85aee4e35e49eaee02bbfd94a34cd35e3bf9 --- /dev/null +++ b/examples/vision/python_nfnet/README.md @@ -0,0 +1,55 @@ +# NFNet Inference with MIGraphX + +## NFNet +NFNet: Normalizer-Free Nets. An image recognition model that can be trained without batch normalization layers. It instead uses gradient clipping algorithm to provide same affects of BatchNorm. + +**Summary:** +- SOTA on ImageNet (86.5% top-1 w/o extra data) +- Up to 8.7x faster to train than EfficientNets to a given accuracy +- Normalizer-free (no BatchNorm) + +**Paper**: https://arxiv.org/pdf/2102.06171.pdf + +**Colab notebook**: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +### Why not batch norm? + +Batch normalization has three significant practical disadvantages: +1. It is an expensive computational primitive, which incurs memory overhead and significantly increases the time required to evaluate the gradient in some networks. +2. It introduces a discrepancy between the behavior of the model during training and at inference time, introducing hidden hyper-parameters that have to be tuned. +3. Last and most important point, batch normalization breaks the independence between training examples in the minibatch (batch size matters with batch norm, distributed training becomes extremely cumbersome). + +Instead: + +- Authors provide Adaptive Gradient Clipping (AGC), which clips gradients based on the unit-wise ratio of gradient norms to parameter norms, and they demonstrate that AGC allows them to train normalizer-free networks with larger batch sizes and stronger data augmentations. +- They design a family of Normalizer-Free ResNets, called NFNets, which set new state-of-the-art validation accuracies on ImageNet for a range of training latencies. Their NFNet-F1 model achieves similar accuracy to EfficientNet-B7 while being 8.7× faster to train, and their largest model sets a new overall state of the art without extra data of 86.5% top-1 accuracy. +- They show that NFNets achieve substantially higher validation accuracies than batch-normalized networks when fine-tuning on ImageNet after pre-training on a large private dataset of 300 million labelled images. Their best model achieves 89.2% top-1 accuracy after fine-tuning. + +## Inference with MIGraphX using NFNet ONNX Model + +There is no ONNX model released for NFNet, as of June 2021, however a PyTorch model is available at: +https://github.com/rwightman/pytorch-image-models. +We provide an in-house produced and optimized ONNX model, which can be parsed and compiled using MIGraphX for AMD GPUs. The ONNX model file can be fetched using the Jupyter notebook we provide. + +### Requirements: +1) AMD GPU system with ROCm installed. +2) Jupyter notebook library. + +### How to use NFNet for image recognition: +Please utilize the notebook example provided: +1) Install jupyter notebook to your environment if not already installed: +``` +https://jupyter.org/install +``` +2) Connect to your jupyter server and utilize `nfnet_inference.ipynb` notebook file. + +### How to compare MIGraphX to ONNX Runtime for NFNet ONNX model: +First install requirements: +``` +pip3 install -r requirements_nfnet.txt +``` + +On your terminal, invoke: +``` +python3 ort_comparison.py +```` diff --git a/examples/vision/python_nfnet/nfnet_inference.ipynb b/examples/vision/python_nfnet/nfnet_inference.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..39ed7f055196561a2b2434c7a0c35d50382c5603 --- /dev/null +++ b/examples/vision/python_nfnet/nfnet_inference.ipynb @@ -0,0 +1,260 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NFNet Inference with AMD MIGraphX\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Normalizer-Free ResNet is a new residual convolutional network providing new state-of-the-art Top-1 accuracy of 86.5% at ImageNet dataset. The most important feature of the model is removing batch normalization. Instead of batch normalization, it uses adaptive gradient clipping to provide same regularization effect of BatchNorm.
Details of this network: https://arxiv.org/abs/2102.06171\n", + "\n", + "In this notebook, we are showing:
\n", + "- How to optimize NFNet ONNX model with AMD MIGraphX.\n", + "- How to run inference on AMD GPU with the optimized ONNX model.\n", + "\n", + "The NFNet utilized in this example is the smallest NFNet version, F0: 71.5M parameters (83.6% top-1 accuracy on ImageNet)\n", + "\n", + "Please make sure MIGraphX Python API is installed following the instructions at Github page." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Requirements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!apt-get update\n", + "!apt-get install ffmpeg libsm6 libxext6 -y " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install --upgrade pip\n", + "!pip3 install -r requirements_nfnet.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import cv2\n", + "import json\n", + "from PIL import Image\n", + "import time\n", + "from os import path " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Importing AMD MIGraphX Python Module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create NFNet ONNX file\n", + "Following repository provides functionality to create NFNet ONNX file from PyTorch model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://www.dropbox.com/s/u4ga8zyxtppfzxc/dm_nfnet_f0.onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load ImageNet labels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('../python_resnet50/imagenet_simple_labels.json') as json_data:\n", + " labels = json.load(json_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Load ONNX model using MIGraphX" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"dm_nfnet_f0.onnx\")\n", + "model.compile(migraphx.get_target(\"gpu\"))\n", + "\n", + "print(model.get_parameter_names())\n", + "print(model.get_parameter_shapes())\n", + "print(model.get_output_shapes())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Functions for image processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_nxn(image, n):\n", + " height, width = image.shape[:2] \n", + " if height > width:\n", + " dif = height - width\n", + " bar = dif // 2 \n", + " square = image[(bar + (dif % 2)):(height - bar),:]\n", + " return cv2.resize(square, (n, n))\n", + " elif width > height:\n", + " dif = width - height\n", + " bar = dif // 2\n", + " square = image[:,(bar + (dif % 2)):(width - bar)]\n", + " return cv2.resize(square, (n, n))\n", + " else:\n", + " return cv2.resize(image, (n, n))\n", + " \n", + "def preprocess(img_data):\n", + " mean_vec = np.array([0.485, 0.456, 0.406])\n", + " stddev_vec = np.array([0.229, 0.224, 0.225])\n", + " norm_img_data = np.zeros(img_data.shape).astype('float32')\n", + " for i in range(img_data.shape[0]): \n", + " norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]\n", + " return norm_img_data\n", + "\n", + "def input_process(frame, dim):\n", + " # Crop and resize original image\n", + " cropped = make_nxn(frame, dim)\n", + " # Convert from HWC to CHW\n", + " chw = cropped.transpose(2,0,1)\n", + " # Apply normalization\n", + " pp = preprocess(chw)\n", + " # Add singleton dimension (CHW to NCHW)\n", + " data = np.expand_dims(pp.astype('float32'),0)\n", + " return data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download example image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch example image: traffic light\n", + "!wget -nc http://farm5.static.flickr.com/4072/4462811418_8bc2bd42ca_z_d.jpg -O traffic_light.jpg\n", + "# Read the image\n", + "im = cv2.imread('traffic_light.jpg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Process the read image to conform input requirements\n", + "data_input = input_process(im, 192)\n", + "\n", + "# Run the model\n", + "start = time.time()\n", + "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", + "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n", + "# Extract the index of the top prediction\n", + "res_npa = np.array(results[0])\n", + "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the model again, first one would take long\n", + "start = time.time()\n", + "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", + "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n", + "# Extract the index of the top prediction\n", + "res_npa = np.array(results[0])\n", + "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/vision/python_nfnet/ort_comparison.py b/examples/vision/python_nfnet/ort_comparison.py new file mode 100755 index 0000000000000000000000000000000000000000..5bdeea2ee932ecee98b4196bb76b07ef1ff423d4 --- /dev/null +++ b/examples/vision/python_nfnet/ort_comparison.py @@ -0,0 +1,48 @@ +import numpy +import onnxruntime as rt + +sess = rt.InferenceSession("dm_nfnet_f0.onnx") + +input_name = sess.get_inputs()[0].name +print("input name", input_name) +input_shape = sess.get_inputs()[0].shape +print("input shape", input_shape) +input_type = sess.get_inputs()[0].type +print("input type", input_type) + +output_name = sess.get_outputs()[0].name +print("output name", output_name) +output_shape = sess.get_outputs()[0].shape +print("output shape", output_shape) +output_type = sess.get_outputs()[0].type +print("output type", output_type) + +x = numpy.random.random((1, 3, 192, 192)) +x = x.astype(numpy.float32) + +import migraphx +model = migraphx.parse_onnx("dm_nfnet_f0.onnx") +model.compile(migraphx.get_target("gpu")) +print(model.get_parameter_names()) +print(model.get_parameter_shapes()) +print(model.get_output_shapes()) + +result_migraphx = model.run({"inputs": x}) +result_ort = sess.run([output_name], {input_name: x}) + +result_migraphx = result_migraphx[0].tolist() + +for i in range(10): + x = numpy.random.random((1, 3, 192, 192)) + x = x.astype(numpy.float32) + + result_migraphx = model.run({"inputs": x}) + result_ort = sess.run([output_name], {input_name: x}) + + try: + numpy.testing.assert_allclose(result_migraphx[0].tolist(), + result_ort[0][0], + rtol=1e-02) + print(f"Test #{i} completed.") + except AssertionError as e: + print(e) diff --git a/examples/vision/python_nfnet/requirements_nfnet.txt b/examples/vision/python_nfnet/requirements_nfnet.txt new file mode 100755 index 0000000000000000000000000000000000000000..8497704fae42eec83970ed0aa8d5718d69c83185 --- /dev/null +++ b/examples/vision/python_nfnet/requirements_nfnet.txt @@ -0,0 +1,3 @@ +opencv-python +onnxruntime +image \ No newline at end of file diff --git a/examples/vision/python_resnet50/README.md b/examples/vision/python_resnet50/README.md new file mode 100755 index 0000000000000000000000000000000000000000..6acfc44644187db31bc9f81123b078499a62c14d --- /dev/null +++ b/examples/vision/python_resnet50/README.md @@ -0,0 +1,16 @@ +# Performing Inference using MIGraphX Python Library + +## Description +This example uses a pre-trained Resnet50 V2 model to demonstrate how inference can be run in Python using the MIGraphX library. + +## How to Use this Example +If you do not already have Jupyter Notebooks installed, please refer to this [page](https://jupyter.org/install) for instructions. + +Once Jupyter Notebooks is installed, you can navigate to this directory and issue the command: + +``` +$ jupyter notebook +``` + +From the browser window that is launched, click on `resnet50_inference.ipynb` +You should now be able to run the notebook from your browser. diff --git a/examples/vision/python_resnet50/imagenet_simple_labels.json b/examples/vision/python_resnet50/imagenet_simple_labels.json new file mode 100755 index 0000000000000000000000000000000000000000..3e987688db0ca8b0ae3a608fe199fff446d6c1b7 --- /dev/null +++ b/examples/vision/python_resnet50/imagenet_simple_labels.json @@ -0,0 +1,1001 @@ +["tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peacock", + "quail", + "partridge", + "grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern", + "crane (bird)", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniels", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland", + "Pyrenean Mountain Dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "Griffon Bruxellois", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog", + "grey wolf", + "Alaskan tundra wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "stick insect", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral", + "ringlet", + "monarch butterfly", + "small white", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala", + "gazelle", + "dromedary", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek", + "eel", + "coho salmon", + "rock beauty", + "clownfish", + "sturgeon", + "garfish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "waste container", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military cap", + "beer bottle", + "beer glass", + "bell-cot", + "bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "bow", + "bow tie", + "brass", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "coil", + "combination lock", + "computer keyboard", + "confectionery store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane (machine)", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire engine", + "fire screen sheet", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "barrette", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "jack-o'-lantern", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "pulled rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "paper knife", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "speaker", + "loupe", + "sawmill", + "magnetic compass", + "mail bag", + "mailbox", + "tights", + "tank suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "match", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "square academic cap", + "mosque", + "mosquito net", + "scooter", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "packet", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "pitcher", + "hand plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "billiard table", + "soda bottle", + "pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "projectile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler", + "running shoe", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swimsuit", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "yawl", + "yurt", + "website", + "comic book", + "crossword", + "traffic sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice pop", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potato", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "shoal", + "seashore", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star", + "hen-of-the-woods", + "bolete", + "ear", + "toilet paper"] + \ No newline at end of file diff --git a/examples/vision/python_resnet50/resnet50_inference.ipynb b/examples/vision/python_resnet50/resnet50_inference.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..03d36c7f4a683c186d062d4e39d8c127010ec1a3 --- /dev/null +++ b/examples/vision/python_resnet50/resnet50_inference.ipynb @@ -0,0 +1,347 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Resnet50 Inference\n", + "\n", + "## Description\n", + "This example performs inference on a short wildlife video using a Resnet50 V2 model that has been pre-trained on imagenet data. The labels used for each class are simplified for readability, but still reflect the correct index-label pairs in official use. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --upgrade pip\n", + "!pip install opencv-python==4.1.2.30\n", + "!pip install matplotlib\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt \n", + "import cv2\n", + "import json\n", + "import time\n", + "import os.path\n", + "from os import path \n", + "import sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Importing MIGraphX Library\n", + "Sometimes the PYTHONPATH variable is not set during installation of MIGraphX. \n", + "If your receive a \"Module Not Found\" error when trying to `import migraphx` in your own application, try running:\n", + "```\n", + "$ export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH\n", + "```\n", + "For this example, the library will be added to the kernel's sys.path.\n", + "\n", + "If you receive \"cannot open shared object file: No such file or directory\" , please make sure `/opt/rocm/lib` is included in $LD_LIBRARY_PATH\n", + "```\n", + " cannot open shared object file: No such file or directory\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "migx_lib_path = \"/opt/rocm/lib\"\n", + "if migx_lib_path not in sys.path:\n", + " sys.path.append(migx_lib_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If this is your first time running this example, you will need to dowload the model and sample video.\n", + "\n", + "The following cell will ask you for your sudo password and then install/update the package `youtube-dl` if necessary. It will then use that tool to download the sample video." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not path.exists(\"./sample_vid.mp4\"):\n", + " import getpass\n", + " import os\n", + " password = getpass.getpass()\n", + " command = \"sudo -H -S pip install --upgrade youtube-dl\"\n", + " os.system('echo %s | %s' % (password, command))\n", + " !youtube-dl https://youtu.be/TkqYmvH_XVs \n", + " !mv sample_vid-TkqYmvH_XVs.mp4 sample_vid.mp4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following will download the resnet50 v2 model from ONNX's model zoo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not path.exists(\"./resnet50.onnx\"):\n", + " !wget https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx", + " !mv resnet50-v2-7.onnx resnet50.onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the simplified imagenet labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open('imagenet_simple_labels.json') as json_data:\n", + " labels = json.load(json_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model and Video Capture Setup\n", + "\n", + "The ONNX graph that is loaded by `parse_onnx()` is a generalized representation that must be compiled for a specific target before it can be executed. For this example, using the target \"gpu\" is recommended. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"resnet50.onnx\")\n", + "model.compile(migraphx.get_target(\"gpu\"))\n", + "model.print() # Printed in terminal \n", + "cap = cv2.VideoCapture(\"sample_vid.mp4\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pre-Processing Video Frames\n", + "Resnet50 requires some preprocessing of video frames before it can run inference. \n", + "\n", + "The model will expect an NCHW tensor with the shape {1, 3, 224, 224} and the values loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first step is to square up the dimensions of the original image by cropping the longer of the two to the size of the shorter dimension. This will help to avoid any stretching or compressing of the input image.\n", + "Then the image can be scaled up or down to the desired resolution of 224x224." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_nxn(image, n):\n", + " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", + " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", + " if height > width:\n", + " dif = height - width\n", + " bar = dif // 2 \n", + " square = image[(bar + (dif % 2)):(height - bar),:]\n", + " return cv2.resize(square, (n, n))\n", + " elif width > height:\n", + " dif = width - height\n", + " bar = dif // 2\n", + " square = image[:,(bar + (dif % 2)):(width - bar)]\n", + " return cv2.resize(square, (n, n))\n", + " else:\n", + " return cv2.resize(image, (n, n))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the image data has the correct dimensions, the values can be normalized as described above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(img_data):\n", + " mean_vec = np.array([0.485, 0.456, 0.406])\n", + " stddev_vec = np.array([0.229, 0.224, 0.225])\n", + " norm_img_data = np.zeros(img_data.shape).astype('float32')\n", + " for i in range(img_data.shape[0]): \n", + " norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]\n", + " return norm_img_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Inference on Single Frame\n", + "\n", + "The above pre-processing functions can now be applied to individual video frames and the data can be passed to the model for evaluation. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def predict_class(frame) -> int:\n", + " # Crop and resize original image\n", + " cropped = make_nxn(frame, 224)\n", + " # Convert from HWC to CHW\n", + " chw = cropped.transpose(2,0,1)\n", + " # Apply normalization\n", + " pp = preprocess(chw)\n", + " # Add singleton dimension (CHW to NCHW)\n", + " data = np.expand_dims(pp.astype('float32'),0)\n", + " # Run the model\n", + " results = model.run({'data':data})\n", + " # Extract the index of the top prediction\n", + " res_npa = np.array(results[0])\n", + " return np.argmax(res_npa)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference Loop over Full Video\n", + "\n", + "Now everything is in place so that we can run inference on each frame of the input video. The video will be played and the predicted label will be displayed on top of each frame. If you are working on headless server, please execute the following cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "while (cap.isOpened()):\n", + " start = time.perf_counter()\n", + " ret, frame = cap.read()\n", + " if not ret: break\n", + " \n", + " top_prediction = predict_class(frame)\n", + " \n", + " end = time.perf_counter()\n", + " fps = 1 / (end - start)\n", + " fps_str = f\"Frames per second: {fps:0.1f}\"\n", + " label_str = \"Top prediction: {}\".format(labels[top_prediction])\n", + "\n", + " labeled = cv2.putText(frame, \n", + " label_str, \n", + " (50, 50), \n", + " cv2.FONT_HERSHEY_SIMPLEX, \n", + " 2, \n", + " (255, 255, 255), \n", + " 3, \n", + " cv2.LINE_AA)\n", + " labeled = cv2.putText(labeled, \n", + " fps_str, \n", + " (50, 1060), \n", + " cv2.FONT_HERSHEY_SIMPLEX, \n", + " 2, \n", + " (255, 255, 255), \n", + " 3, \n", + " cv2.LINE_AA)\n", + " cv2.imshow(\"Resnet50 Inference\", labeled)\n", + "\n", + " if cv2.waitKey(1) & 0xFF == ord('q'): # 'q' to quit\n", + " break\n", + "\n", + "cap.release()\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If script is run on a headless server where .imshow() experiences problems, the following cell for histogram can be run to verify functionalty:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_labels = []\n", + "while (cap.isOpened()):\n", + " start = time.perf_counter()\n", + " ret, frame = cap.read()\n", + " if not ret: break\n", + " \n", + " top_prediction = predict_class(frame)\n", + " output_labels.append(labels[top_prediction])\n", + "\n", + "cap.release()\n", + "output_labels = np.array(output_labels)\n", + "plt.hist(output_labels) \n", + "plt.xticks(rotation = 90)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/vision/python_super_resolution/README.md b/examples/vision/python_super_resolution/README.md new file mode 100755 index 0000000000000000000000000000000000000000..9a65682f9897cad4a7a769772320497ece0f57b9 --- /dev/null +++ b/examples/vision/python_super_resolution/README.md @@ -0,0 +1,113 @@ +# Super Resolution with AMD MIGraphX + +This example is based on [ONNX run_super_resolution_model notebook](https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/dependencies/Run_Super_Resolution_Model.ipynb) and modified for MIGraphX. + +## Description +Given an input image, this application resizes the image to 224x224 and then scales it to 672x672, thus it is useful for upscaling low-resolution images. + +### Model Utilized +> "Super Resolution uses efficient [Sub-pixel convolutional layer](https://arxiv.org/abs/1609.05158) described for increasing spatial resolution within network tasks. By increasing pixel count, images are then clarified, sharpened, and upscaled without losing the input image’s content and characteristics." [[Reference]](https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/README.md) + +Model in PyTorch definitions: +``` +self.relu = nn.ReLU(inplace=inplace) +self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) +self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) +self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) +self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) +self.pixel_shuffle = nn.PixelShuffle(upscale_factor) +``` +## How-to +If you have jupyter installed, you can simply use the notebook given. Otherwise please follow the step-by-step guide. +### Jupyter Notebook +Run Jupyter notebook server on a ROCm and MIGraphX installed system, and run `Run_Super_Resolution_Model.ipynb` + +### Step by Step +1) Upgrade pip3. You may skip this stage if you already have latest pip3. This step is needed for OpenCV installation. +``` +pip3 install --upgrade pip +``` +2) Install requirements. +``` +pip3 install -r requirements.txt +``` +3) Import required libraries. +``` +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image, ImageDraw, ImageFont +from resizeimage import resizeimage +``` +4) Download ONNX model. +``` +wget -nc https://github.com/onnx/models/raw/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx +``` +5) Preprocess the sample image `cat.jpg`. +``` +orig_img = Image.open("./cat.jpg") +print(orig_img.size) +img = resizeimage.resize_cover(orig_img, [224,224], validate=False) +img_ycbcr = img.convert('YCbCr') +img_y_0, img_cb, img_cr = img_ycbcr.split() +img_ndarray = np.asarray(img_y_0) + +img_4 = np.expand_dims(np.expand_dims(img_ndarray, axis=0), axis=0) +img_5 = img_4.astype(np.float32) / 255.0 +``` +6) Import MIGraphX, parse & compile the ONNX model with MIGraphX. Print the model. +``` +model = migraphx.parse_onnx("super-resolution-10.onnx") +model.compile(migraphx.get_target("gpu")) +model.print() +``` +7) You can check the model inputs and outputs with the following functions. +``` +print(model.get_parameter_names()) +print(model.get_parameter_shapes()) +print(model.get_output_shapes()) +``` +8) Run the image throgh model and get the output data. +``` +result = model.run({ + "input": img_5 + }) + +data = np.array(result[0])[0] +``` +9) Post processing image. If matplotlib is installed correctly, it should show up the image. The output image will be stored with filename `output.jpg`. +``` +img_out_y = Image.fromarray(np.uint8((data* 255.0).clip(0, 255)[0]), mode='L') +# get the output image follow post-processing step from PyTorch implementation +final_img = Image.merge( + "YCbCr", [ + img_out_y, + img_cb.resize(img_out_y.size, Image.BICUBIC), + img_cr.resize(img_out_y.size, Image.BICUBIC), + ]).convert("RGB") +final_img.save("output.jpg") +print(final_img.size) +``` +10) Measure the improvement in terms of PSNR and show the both input and super-resolution image: +``` +import cv2 + +imgIN = cv2.imread('cat.jpg') +imgOUT = cv2.imread('output.jpg') +imgIN = cv2.cvtColor(imgIN, cv2.COLOR_BGR2RGB) #BGR to RGB +imgOUT = cv2.cvtColor(imgOUT, cv2.COLOR_BGR2RGB) + +imgIN_resized = cv2.resize(imgIN, (672,672)) #Resizing input to 672 + +psnr = cv2.PSNR(imgIN_resized, imgOUT) #dimensions need to be same +print("PSNR Value = %.3f db"%psnr) + +fig = plt.figure(figsize=(16, 16)) +sp1 = fig.add_subplot(1, 2, 1) +sp1.title.set_text('Output Super Resolution Image (%sx%s)'%(imgOUT.shape[0], imgOUT.shape[1])) +plt.imshow(imgOUT) + +sp2 = fig.add_subplot(1, 2, 2) +sp2.title.set_text('Input Image (%sx%s)'%(imgIN.shape[0], imgIN.shape[1])) +plt.imshow(imgIN) +plt.show() +``` \ No newline at end of file diff --git a/examples/vision/python_super_resolution/Run_Super_Resolution_Model.ipynb b/examples/vision/python_super_resolution/Run_Super_Resolution_Model.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..11a5e4a835d568649eba7b9c35ab9d322ac18f52 --- /dev/null +++ b/examples/vision/python_super_resolution/Run_Super_Resolution_Model.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Super Resolution Inference with AMD MIGraphX\n", + "This notebook is inspired from: https://github.com/onnx/models/blob/master/vision/super_resolution/sub_pixel_cnn_2016/dependencies/Run_Super_Resolution_Model.ipynb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install --upgrade pip #needed for opencv-python installation\n", + "!pip3 install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image, ImageDraw, ImageFont\n", + "from resizeimage import resizeimage\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download ONNX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://github.com/onnx/models/raw/master/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import MIGraphX Python Module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Preprocessing Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "orig_img = Image.open(\"./cat.jpg\")\n", + "print(orig_img.size)\n", + "img = resizeimage.resize_cover(orig_img, [224,224], validate=False)\n", + "img_ycbcr = img.convert('YCbCr')\n", + "img_y_0, img_cb, img_cr = img_ycbcr.split()\n", + "img_ndarray = np.asarray(img_y_0)\n", + "\n", + "img_4 = np.expand_dims(np.expand_dims(img_ndarray, axis=0), axis=0)\n", + "img_5 = img_4.astype(np.float32) / 255.0\n", + "img_5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"super-resolution-10.onnx\")\n", + "model.compile(migraphx.get_target(\"gpu\"))\n", + "#model.print()\n", + "\n", + "print(model.get_parameter_names())\n", + "print(model.get_parameter_shapes())\n", + "print(model.get_output_shapes())\n", + "\n", + "\n", + "result = model.run({\n", + " \"input\": img_5\n", + " })\n", + "\n", + "data = np.array(result[0])[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Postprocessing Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_out_y = Image.fromarray(np.uint8((data* 255.0).clip(0, 255)[0]), mode='L')\n", + "# get the output image follow post-processing step from PyTorch implementation\n", + "final_img = Image.merge(\n", + " \"YCbCr\", [\n", + " img_out_y,\n", + " img_cb.resize(img_out_y.size, Image.BICUBIC),\n", + " img_cr.resize(img_out_y.size, Image.BICUBIC),\n", + " ]).convert(\"RGB\")\n", + "final_img.save(\"output.jpg\")\n", + "print(final_img.size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## PSNR Comparison Output vs Input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import cv2\n", + "\n", + "imgIN = cv2.imread('cat.jpg')\n", + "imgOUT = cv2.imread('output.jpg')\n", + "imgIN = cv2.cvtColor(imgIN, cv2.COLOR_BGR2RGB) #BGR to RGB\n", + "imgOUT = cv2.cvtColor(imgOUT, cv2.COLOR_BGR2RGB)\n", + "\n", + "imgIN_resized = cv2.resize(imgIN, (672,672)) #Resizing input to 672\n", + "\n", + "psnr = cv2.PSNR(imgIN_resized, imgOUT) #dimensions need to be same\n", + "print(\"PSNR Value = %.3f db\"%psnr)\n", + "\n", + "fig = plt.figure(figsize=(16, 16))\n", + "sp1 = fig.add_subplot(1, 2, 1)\n", + "sp1.title.set_text('Output Super Resolution Image (%sx%s)'%(imgOUT.shape[0], imgOUT.shape[1]))\n", + "plt.imshow(imgOUT)\n", + "\n", + "sp2 = fig.add_subplot(1, 2, 2)\n", + "sp2.title.set_text('Input Image (%sx%s)'%(imgIN.shape[0], imgIN.shape[1]))\n", + "plt.imshow(imgIN)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/vision/python_super_resolution/cat.jpg b/examples/vision/python_super_resolution/cat.jpg new file mode 100755 index 0000000000000000000000000000000000000000..18562b191b2efeb22e964bc4fb42ef397b4efffc Binary files /dev/null and b/examples/vision/python_super_resolution/cat.jpg differ diff --git a/examples/vision/python_super_resolution/requirements.txt b/examples/vision/python_super_resolution/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..0a31e96fbf6b157440ad45f7a6b88664ac47d9b0 --- /dev/null +++ b/examples/vision/python_super_resolution/requirements.txt @@ -0,0 +1,3 @@ +matplotlib +python-resize-image +opencv-python \ No newline at end of file diff --git a/examples/vision/python_unet/README.md b/examples/vision/python_unet/README.md new file mode 100755 index 0000000000000000000000000000000000000000..17c500084e9233f665d7a83cf66d4fc03b3c23b9 --- /dev/null +++ b/examples/vision/python_unet/README.md @@ -0,0 +1,9 @@ +# U-Net Image Segmentation Inference with AMD MIGraphX + +This examples provides a simple example for utilizing U-Net ONNX model for image segmentation, using AMD MIGraphX graph optimization engine for fast inference. + +## How-to +Please utilize the notebook given `unet_inference.ipynb`. + +## Model Details +ONNX model utilized in this example can be found [here](https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx). \ No newline at end of file diff --git a/examples/vision/python_unet/car1.jpeg b/examples/vision/python_unet/car1.jpeg new file mode 100755 index 0000000000000000000000000000000000000000..5d81af863f77e2edb745e25a3691aa720aa1889f Binary files /dev/null and b/examples/vision/python_unet/car1.jpeg differ diff --git a/examples/vision/python_unet/requirements.txt b/examples/vision/python_unet/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..806f22116faa6610345f0899927f952ae5dfedcd --- /dev/null +++ b/examples/vision/python_unet/requirements.txt @@ -0,0 +1,2 @@ +numpy +matplotlib \ No newline at end of file diff --git a/examples/vision/python_unet/unet_inference.ipynb b/examples/vision/python_unet/unet_inference.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..88bc850f7f47a6bc622e464ed91e402fb9e4fb70 --- /dev/null +++ b/examples/vision/python_unet/unet_inference.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cd7a3990", + "metadata": {}, + "source": [ + "## Import MIGraphX Python Library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3930d7b8", + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx\n", + "from PIL import Image\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "b350c333", + "metadata": {}, + "source": [ + "## Fetch U-NET ONNX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02a7b7de", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx" + ] + }, + { + "cell_type": "markdown", + "id": "a6cfe6e9", + "metadata": {}, + "source": [ + "## Load ONNX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e05a13dc", + "metadata": {}, + "outputs": [], + "source": [ + "model = migraphx.parse_onnx(\"unet_13_256.onnx\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52c67023", + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(migraphx.get_target(\"gpu\"))" + ] + }, + { + "cell_type": "markdown", + "id": "80edb6f1", + "metadata": {}, + "source": [ + "## Print model parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd5c3269", + "metadata": {}, + "outputs": [], + "source": [ + "print(model.get_parameter_names())\n", + "print(model.get_parameter_shapes())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47f956c7", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(pil_img, newW, newH):\n", + " w, h = pil_img.size\n", + " assert newW > 0 and newH > 0, 'Scale is too small'\n", + " pil_img = pil_img.resize((newW, newH))\n", + "\n", + " img_nd = np.array(pil_img)\n", + "\n", + " if len(img_nd.shape) == 2:\n", + " img_nd = np.expand_dims(img_nd, axis=2)\n", + "\n", + " # HWC to CHW\n", + " img_print = pil_img\n", + " img_trans = img_nd.transpose((2, 0, 1))\n", + " if img_trans.max() > 1:\n", + " img_trans = img_trans / 255\n", + " \n", + " img_trans = np.expand_dims(img_trans, 0)\n", + "\n", + " return img_trans, img_print\n", + "\n", + "def plot_img_and_mask(img, mask):\n", + " classes = mask.shape[0] if len(mask.shape) > 3 else 1\n", + " print(classes)\n", + " fig, ax = plt.subplots(1, classes + 1)\n", + " ax[0].set_title('Input image')\n", + " ax[0].imshow(img)\n", + " if classes > 1:\n", + " for i in range(classes):\n", + " ax[i+1].set_title(f'Output mask (class {i+1})')\n", + " ax[i+1].imshow(mask[:, :, i])\n", + " else:\n", + " ax[1].set_title(f'Output mask')\n", + " ax[1].imshow(mask[0,0])\n", + " plt.xticks([]), plt.yticks([])\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "389ddc4d", + "metadata": {}, + "outputs": [], + "source": [ + "img = Image.open(\"./car1.jpeg\")\n", + "img, imPrint = preprocess(img, 256, 256)\n", + "input_im = np.zeros((1,3,256,256),dtype='float32') \n", + "np.lib.stride_tricks.as_strided(input_im, shape=img.shape, strides=input_im.strides)[:] = img #getting correct stride\n", + "print(input_im.strides)\n", + "print(input_im.shape)\n", + "imPrint.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9de6f2a7", + "metadata": {}, + "outputs": [], + "source": [ + "mask = model.run({'inputs':input_im}) # Your first inference would take longer than the following ones.\n", + "output_mask = np.array(mask[0])\n", + "print(output_mask.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acbd68e3", + "metadata": {}, + "outputs": [], + "source": [ + "def sigmoid(x):\n", + " return 1 / (1 + np.exp(-x))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58e3062c", + "metadata": {}, + "outputs": [], + "source": [ + "probs = sigmoid(output_mask)\n", + "full_mask = probs > 0.996\n", + "plot_img_and_mask(imPrint, full_mask)" + ] + }, + { + "cell_type": "markdown", + "id": "6126df0b", + "metadata": {}, + "source": [ + "NOTE: The model weights utilized here are trained by using car images with plain backgrounds. The imperfect result on a \"real-world\" image as shown above is expected. To get a better result fine-tuning the model on a dataset of real-world examples is recommended. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/vision/python_yolov4/README.md b/examples/vision/python_yolov4/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6500d6c5a906025c5466616a18b8cd782b6b0e28 --- /dev/null +++ b/examples/vision/python_yolov4/README.md @@ -0,0 +1,8 @@ +# YoloV4 Object Detection +The notebook [yolov4_inference.ipynb](./yolov4_inference.ipynb) is intended to be an example of how to use MIGraphX to perform object detection. The model used within is a pre-trained yolov4 from the ONNX model zoo. + +## Run the Notebook +To run the example notebook, simply issue the following command from this directory: +``` +$ jupyter notebook yolov4_inference.ipynb +``` diff --git a/examples/vision/python_yolov4/image_processing.py b/examples/vision/python_yolov4/image_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9d0fc071ab858362508d459fffcda0a807ec04 --- /dev/null +++ b/examples/vision/python_yolov4/image_processing.py @@ -0,0 +1,234 @@ +# All pre- and post-processing methods used below are borrowed from the ONNX MOdel Zoo +# https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/yolov4 + +import numpy as np +import cv2 +from scipy import special +import colorsys +import random + + +# this function is from tensorflow-yolov4-tflite/core/utils.py +def image_preprocess(image, target_size, gt_boxes=None): + + ih, iw = target_size + h, w, _ = image.shape + + scale = min(iw / w, ih / h) + nw, nh = int(scale * w), int(scale * h) + image_resized = cv2.resize(image, (nw, nh)) + + image_padded = np.full(shape=[ih, iw, 3], fill_value=128.0) + dw, dh = (iw - nw) // 2, (ih - nh) // 2 + image_padded[dh:nh + dh, dw:nw + dw, :] = image_resized + image_padded = image_padded / 255. + + if gt_boxes is None: + return image_padded + + else: + gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] * scale + dw + gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh + return image_padded, gt_boxes + + +def get_anchors(anchors_path, tiny=False): + '''loads the anchors from a file''' + with open(anchors_path) as f: + anchors = f.readline() + anchors = np.array(anchors.split(','), dtype=np.float32) + return anchors.reshape(3, 3, 2) + + +def postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE=[1, 1, 1]): + '''define anchor boxes''' + for i, pred in enumerate(pred_bbox): + conv_shape = pred.shape + output_size = conv_shape[1] + conv_raw_dxdy = pred[:, :, :, :, 0:2] + conv_raw_dwdh = pred[:, :, :, :, 2:4] + xy_grid = np.meshgrid(np.arange(output_size), np.arange(output_size)) + xy_grid = np.expand_dims(np.stack(xy_grid, axis=-1), axis=2) + + xy_grid = np.tile(np.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1]) + xy_grid = xy_grid.astype(np.float) + + pred_xy = ((special.expit(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * + (XYSCALE[i] - 1) + xy_grid) * STRIDES[i] + pred_wh = (np.exp(conv_raw_dwdh) * ANCHORS[i]) + pred[:, :, :, :, 0:4] = np.concatenate([pred_xy, pred_wh], axis=-1) + + pred_bbox = [np.reshape(x, (-1, np.shape(x)[-1])) for x in pred_bbox] + pred_bbox = np.concatenate(pred_bbox, axis=0) + return pred_bbox + + +def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold): + '''remove boundary boxs with a low detection probability''' + valid_scale = [0, np.inf] + pred_bbox = np.array(pred_bbox) + + pred_xywh = pred_bbox[:, 0:4] + pred_conf = pred_bbox[:, 4] + pred_prob = pred_bbox[:, 5:] + + # (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax) + pred_coor = np.concatenate([ + pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5, + pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5 + ], + axis=-1) + # (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org) + org_h, org_w = org_img_shape + resize_ratio = min(input_size / org_w, input_size / org_h) + + dw = (input_size - resize_ratio * org_w) / 2 + dh = (input_size - resize_ratio * org_h) / 2 + + pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratio + pred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratio + + # (3) clip some boxes that are out of range + pred_coor = np.concatenate([ + np.maximum(pred_coor[:, :2], [0, 0]), + np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1]) + ], + axis=-1) + invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]), + (pred_coor[:, 1] > pred_coor[:, 3])) + pred_coor[invalid_mask] = 0 + + # (4) discard some invalid boxes + bboxes_scale = np.sqrt( + np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1)) + scale_mask = np.logical_and((valid_scale[0] < bboxes_scale), + (bboxes_scale < valid_scale[1])) + + # (5) discard some boxes with low scores + classes = np.argmax(pred_prob, axis=-1) + scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes] + score_mask = scores > score_threshold + mask = np.logical_and(scale_mask, score_mask) + coors, scores, classes = pred_coor[mask], scores[mask], classes[mask] + + return np.concatenate( + [coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1) + + +def bboxes_iou(boxes1, boxes2): + '''calculate the Intersection Over Union value''' + boxes1 = np.array(boxes1) + boxes2 = np.array(boxes2) + + boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - + boxes1[..., 1]) + boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - + boxes2[..., 1]) + + left_up = np.maximum(boxes1[..., :2], boxes2[..., :2]) + right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:]) + + inter_section = np.maximum(right_down - left_up, 0.0) + inter_area = inter_section[..., 0] * inter_section[..., 1] + union_area = boxes1_area + boxes2_area - inter_area + ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps) + + return ious + + +def nms(bboxes, iou_threshold, sigma=0.3, method='nms'): + """ + :param bboxes: (xmin, ymin, xmax, ymax, score, class) + + Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf + https://github.com/bharatsingh430/soft-nms + """ + classes_in_img = list(set(bboxes[:, 5])) + best_bboxes = [] + + for cls in classes_in_img: + cls_mask = (bboxes[:, 5] == cls) + cls_bboxes = bboxes[cls_mask] + + while len(cls_bboxes) > 0: + max_ind = np.argmax(cls_bboxes[:, 4]) + best_bbox = cls_bboxes[max_ind] + best_bboxes.append(best_bbox) + cls_bboxes = np.concatenate( + [cls_bboxes[:max_ind], cls_bboxes[max_ind + 1:]]) + iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4]) + weight = np.ones((len(iou), ), dtype=np.float32) + + assert method in ['nms', 'soft-nms'] + + if method == 'nms': + iou_mask = iou > iou_threshold + weight[iou_mask] = 0.0 + + if method == 'soft-nms': + weight = np.exp(-(1.0 * iou**2 / sigma)) + + cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight + score_mask = cls_bboxes[:, 4] > 0. + cls_bboxes = cls_bboxes[score_mask] + + return best_bboxes + + +def read_class_names(class_file_name): + '''loads class name from a file''' + names = {} + with open(class_file_name, 'r') as data: + for ID, name in enumerate(data): + names[ID] = name.strip('\n') + return names + + +def draw_bbox(image, + bboxes, + classes=read_class_names("./utilities/coco.names"), + show_label=True): + """ + bboxes: [x_min, y_min, x_max, y_max, probability, cls_id] format coordinates. + """ + + num_classes = len(classes) + image_h, image_w, _ = image.shape + hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)] + colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) + colors = list( + map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), + colors)) + + random.seed(0) + random.shuffle(colors) + random.seed(None) + + for i, bbox in enumerate(bboxes): + coor = np.array(bbox[:4], dtype=np.int32) + fontScale = 0.5 + score = bbox[4] + class_ind = int(bbox[5]) + bbox_color = colors[class_ind] + bbox_thick = int(0.6 * (image_h + image_w) / 600) + c1, c2 = (coor[0], coor[1]), (coor[2], coor[3]) + cv2.rectangle(image, c1, c2, bbox_color, bbox_thick) + + if show_label: + bbox_mess = '%s: %.2f' % (classes[class_ind], score) + t_size = cv2.getTextSize(bbox_mess, + 0, + fontScale, + thickness=bbox_thick // 2)[0] + cv2.rectangle(image, c1, + (c1[0] + t_size[0], c1[1] - t_size[1] - 3), + bbox_color, -1) + + cv2.putText(image, + bbox_mess, (c1[0], c1[1] - 2), + cv2.FONT_HERSHEY_SIMPLEX, + fontScale, (0, 0, 0), + bbox_thick // 2, + lineType=cv2.LINE_AA) + + return image diff --git a/examples/vision/python_yolov4/yolov4_inference.ipynb b/examples/vision/python_yolov4/yolov4_inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..475c07d8a4aa5229e4e1a1a2e1df3c8508ac76f4 --- /dev/null +++ b/examples/vision/python_yolov4/yolov4_inference.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Object Detection with YoloV4\n", + "This notebook is intended to be an example of how to use MIGraphX to perform object detection. The model used below is a pre-trained yolov4 from the ONNX model zoo. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os.path\n", + "\n", + "if not os.path.exists(\"./utilities/coco.names\"):\n", + " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/dependencies/coco.names -P ./utilities/\n", + "if not os.path.exists(\"./utilities/yolov4_anchors.txt\"):\n", + " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/dependencies/yolov4_anchors.txt -P ./utilities/\n", + "if not os.path.exists(\"./utilities/input.jpg\"):\n", + " # The image used is from the COCO dataset (https://cocodataset.org/#explore)\n", + " # Other images can be tested by replacing the link below\n", + " image_link = \"https://farm3.staticflickr.com/2009/2306189268_88cc86b30f_z.jpg\"\n", + " !wget -O ./utilities/input.jpg $image_link\n", + "if not os.path.exists(\"./utilities/yolov4.onnx\"):\n", + " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/model/yolov4.onnx -P ./utilities/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Serialize model using MIGraphX Driver\n", + "Please refer to the [MIGraphX Driver example](../../migraphx/migraphx_driver) if you would like more information about this tool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists(\"yolov4_fp16.mxr\"):\n", + " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n", + "if not os.path.exists(\"yolov4.mxr\"):\n", + " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import libraries \n", + "Please refer to [this section](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX#using-migraphx-python-module) of the main README if the migraphx module is not found. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import migraphx\n", + "import cv2\n", + "import time\n", + "import numpy as np\n", + "import image_processing as ip\n", + "from PIL import Image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Read and pre-process image data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_size = 416\n", + "\n", + "original_image = cv2.imread(\"./utilities/input.jpg\")\n", + "original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)\n", + "original_image_size = original_image.shape[:2]\n", + "\n", + "image_data = ip.image_preprocess(np.copy(original_image), [input_size, input_size])\n", + "image_data = image_data[np.newaxis, ...].astype(np.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load and run model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load serialized model (either single- or half-precision)\n", + "model = migraphx.load(\"yolov4.mxr\", format=\"msgpack\")\n", + "#model = migraphx.load(\"yolov4_fp16.mxr\", format=\"msgpack\")\n", + "\n", + "# Get the name of the input parameter and convert image data to an MIGraphX argument\n", + "input_name = next(iter(model.get_parameter_shapes()))\n", + "input_argument = migraphx.argument(image_data)\n", + "\n", + "# Evaluate the model and convert the outputs for post-processing\n", + "outputs = model.run({input_name: input_argument})\n", + "detections = [np.ndarray(shape=out.get_shape().lens(), buffer=np.array(out.tolist()), dtype=float) for out in outputs]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Post-process the model outputs and display image with detection bounding boxes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ANCHORS = \"./utilities/yolov4_anchors.txt\"\n", + "STRIDES = [8, 16, 32]\n", + "XYSCALE = [1.2, 1.1, 1.05]\n", + "\n", + "ANCHORS = ip.get_anchors(ANCHORS)\n", + "STRIDES = np.array(STRIDES)\n", + "\n", + "pred_bbox = ip.postprocess_bbbox(detections, ANCHORS, STRIDES, XYSCALE)\n", + "bboxes = ip.postprocess_boxes(pred_bbox, original_image_size, input_size, 0.25)\n", + "bboxes = ip.nms(bboxes, 0.213, method='nms')\n", + "image = ip.draw_bbox(original_image, bboxes)\n", + "\n", + "image = Image.fromarray(image)\n", + "image.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3.8.3 64-bit ('base': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "metadata": { + "interpreter": { + "hash": "d7283edef085bb46d38a3069bce96b3de1793019cb5bd7b1e86bf9785b67f304" + } + }, + "interpreter": { + "hash": "d7283edef085bb46d38a3069bce96b3de1793019cb5bd7b1e86bf9785b67f304" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/hip-clang.docker b/hip-clang.docker new file mode 100755 index 0000000000000000000000000000000000000000..2984d4f11bc8301abc0daae612a4cf33851e94de --- /dev/null +++ b/hip-clang.docker @@ -0,0 +1,60 @@ +FROM ubuntu:20.04 + +ARG PREFIX=/usr/local + +# Support multiarch +RUN dpkg --add-architecture i386 + +# Add rocm repository +RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.0.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' + +# Install dependencies +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + build-essential \ + clang-format-10 \ + cmake \ + curl \ + doxygen \ + gdb \ + git \ + lcov \ + pkg-config \ + python3 \ + python3-dev \ + python3-pip \ + software-properties-common \ + wget \ + rocm-device-libs \ + hip-base \ + libnuma-dev \ + miopen-hip \ + rocblas \ + zlib1g-dev && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Workaround broken rocm packages +RUN ln -s /opt/rocm-* /opt/rocm +RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf +RUN echo "/opt/rocm/llvm/lib" > /etc/ld.so.conf.d/rocm-llvm.conf +RUN ldconfig + +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 + +# Install yapf +RUN pip3 install yapf==0.28.0 + +# Install doc requirements +ADD doc/requirements.txt /doc-requirements.txt +RUN pip3 install -r /doc-requirements.txt + +# Install dependencies +ADD dev-requirements.txt /dev-requirements.txt +ADD requirements.txt /requirements.txt +ADD rbuild.ini /rbuild.ini + +COPY ./tools/install_prereqs.sh / +RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh + diff --git a/rbuild.ini b/rbuild.ini new file mode 100755 index 0000000000000000000000000000000000000000..0749b843af8e1f516ac91efc72113b99766aea49 --- /dev/null +++ b/rbuild.ini @@ -0,0 +1,27 @@ +[main] +cxx = ${rocm_path}/llvm/bin/clang++ +cc = ${rocm_path}/llvm/bin/clang +deps = + pfultz2/rocm-recipes + -f requirements.txt + +[gh] +ignore = danmar/cppcheck +deps = + -f dev-requirements.txt + oneapi-src/oneDNN@v1.7 +define = + CMAKE_C_COMPILER_LAUNCHER=${deps_dir}/bin/ccache + CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache + MIGRAPHX_ENABLE_CPU=On + +[develop] +cxx = ${rocm_path}/llvm/bin/clang++ +cc = ${rocm_path}/llvm/bin/clang +deps = + -f dev-requirements.txt + oneapi-src/oneDNN@v1.7 +define = + CMAKE_C_COMPILER_LAUNCHER=${deps_dir}/bin/ccache + CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache + MIGRAPHX_ENABLE_CPU=On \ No newline at end of file diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index 586f8790b1f94227f86dc154e964bd2f1e6d5b3b..95bec385292fb6583b1ab8d0c571219eabafa644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -google/protobuf@v3.8.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off -RadeonOpenCompute/rocm-cmake@b29ff83 --build -ROCmSoftwarePlatform/rocBLAS@7197df74e5a1ba64ff967065872e5f86a3516637 -ROCmSoftwarePlatform/MIOpen@2.0.0 +google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off +nlohmann/json@v3.8.0 blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -DHEADER_DIR=blaze half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969 -pybind/pybind11@v2.2.4 -DPYBIND11_TEST=Off --build +pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build +msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt old mode 100644 new mode 100755 index 18bf7894c1642b6f499544cc00825c383e1532ba..87c0b88070aa70c1aeb042d68187883073bfc70b --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,62 +1,258 @@ include(ROCMInstallTargets) include(ROCMPackageConfigHelpers) +include(RegisterOp) +include(CheckCXXLinkerFlag) add_library(migraphx + adjust_allocation.cpp + analyze_streams.cpp + apply_alpha_beta.cpp + argument.cpp auto_contiguous.cpp - eliminate_common_subexpression.cpp - propagate_constant.cpp + common.cpp + compile_src.cpp + convert_to_json.cpp + cpp_generator.cpp dead_code_elimination.cpp + dom_info.cpp + dynamic_loader.cpp eliminate_allocation.cpp - eliminate_contiguous.cpp + eliminate_common_subexpression.cpp eliminate_concat.cpp + eliminate_contiguous.cpp + eliminate_data_type.cpp eliminate_identity.cpp eliminate_pad.cpp - rewrite_batchnorm.cpp - rewrite_rnn.cpp - rewrite_pooling.cpp env.cpp + file_buffer.cpp + fuse_pointwise.cpp generate.cpp + inline_module.cpp + insert_pad.cpp instruction.cpp + json.cpp + load_save.cpp + make_op.cpp + module.cpp + msgpack.cpp + normalize_attributes.cpp + normalize_ops.cpp + op_enums.cpp + operation.cpp + opt/memory_coloring.cpp + opt/memory_coloring_impl.cpp + pass_manager.cpp + permutation.cpp + preallocate_param.cpp + process.cpp program.cpp + propagate_constant.cpp quantization.cpp - shape.cpp + quantize_fp16.cpp + quantize_int8.cpp + reduce_dims.cpp + register_op.cpp + register_target.cpp + replace_allocate.cpp + simplify_qdq.cpp + rewrite_batchnorm.cpp + rewrite_pooling.cpp + rewrite_quantization.cpp + rewrite_rnn.cpp schedule.cpp - pass_manager.cpp + serialize.cpp + shape.cpp simplify_algebra.cpp simplify_reshapes.cpp - opt/memory_coloring.cpp - opt/memory_coloring_impl.cpp + tmp_dir.cpp + value.cpp + verify_args.cpp ) +configure_file(version.h.in include/migraphx/version.h) rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION}) +function(register_migraphx_ops) + foreach(OP ${ARGN}) + register_op(migraphx HEADER migraphx/op/${OP}.hpp OPERATORS op::${OP}) + endforeach() +endfunction() +register_migraphx_ops( + abs + acosh + acos + add + allocate + argmax + argmin + asinh + asin + as_shape + atanh + atan + batch_norm_inference + broadcast + capture + ceil + clip + concat + contiguous + convert + convolution + cosh + cos + deconvolution + dequantizelinear + div + dot + elu + equal + erf + exp + flatten + floor + gather + gathernd + get_tuple_elem + greater + gru + identity + if_op + im2col + isnan + leaky_relu + less + load + log + logical_and + logical_or + logical_xor + logsoftmax + loop + lrn + lstm + max + min + mul + multibroadcast + multinomial + neg + nonmaxsuppression + nonzero + outline + pad + pointwise + pooling + pow + prefix_scan_sum + prelu + quant_convolution + quant_dot + quantizelinear + recip + reduce_max + reduce_mean + reduce_min + reduce_prod + reduce_sum + relu + reshape + reverse + rnn + rnn_last_cell_output + rnn_last_hs_output + rnn_var_sl_last_output + roialign + round + rsqrt + scalar + scatter_add + scatter_mul + scatter_none + scatternd_add + scatternd_mul + scatternd_none + sigmoid + sign + sinh + sin + slice + softmax + sqdiff + sqrt + squeeze + step + sub + tanh + tan + topk + transpose + unary_not + undefined + unknown + unsqueeze + where +) +register_op(migraphx HEADER migraphx/op/rnn_variable_seq_lens.hpp OPERATORS op::rnn_var_sl_shift_output op::rnn_var_sl_shift_sequence) +register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal builtin::param builtin::returns) rocm_clang_tidy_check(migraphx) rocm_install_targets( TARGETS migraphx INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include ) -find_path(HALF_INCLUDE_DIR half.hpp) -# TODO: Fix the incorrect path + +check_cxx_linker_flag(-lstdc++fs HAS_LIB_STD_FILESYSTEM) +if(HAS_LIB_STD_FILESYSTEM) +target_link_libraries(migraphx PRIVATE -lstdc++fs) +endif() + +target_link_libraries(migraphx PRIVATE -ldl) + target_include_directories(migraphx SYSTEM PUBLIC $) +find_package(Threads) +target_link_libraries(migraphx PUBLIC Threads::Threads) + +find_package(msgpack REQUIRED) +target_link_libraries(migraphx PRIVATE msgpackc-cxx) +# Make this available to the tests +target_link_libraries(migraphx INTERFACE $) + +add_library(migraphx_all_targets INTERFACE) + set(PACKAGE_DEPENDS) +add_subdirectory(api) add_subdirectory(driver) add_subdirectory(onnx) add_subdirectory(tf) add_subdirectory(py) +add_subdirectory(targets/ref) +target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref) +if(MIGRAPHX_ENABLE_CPU) add_subdirectory(targets/cpu) +target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu) +target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_CPU) +endif() if(MIGRAPHX_ENABLE_GPU) -list(APPEND PACKAGE_DEPENDS MIOpen rocblas) +list(APPEND PACKAGE_DEPENDS PACKAGE MIOpen PACKAGE rocblas) add_subdirectory(targets/gpu) +target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu) +target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU) +endif() + +if(HAVE_HALF_EXPR) + target_compile_definitions(migraphx PUBLIC -DHAS_HALF_V1) endif() rocm_export_targets( - TARGETS migraphx::migraphx + TARGETS migraphx::migraphx_c NAMESPACE migraphx:: DEPENDS + Threads ${PACKAGE_DEPENDS} ) diff --git a/src/adjust_allocation.cpp b/src/adjust_allocation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1bf55eedba65e337dea8b3436d15f46ab62a2d0f --- /dev/null +++ b/src/adjust_allocation.cpp @@ -0,0 +1,47 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void adjust_allocation::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + // skip instruction with no input + if(ins->inputs().empty()) + continue; + + // Skip target-independent operators + if(ins->get_operator().is_context_free()) + continue; + + auto alias_ins = instruction::get_output_alias(ins, true); + if(alias_ins->name() != model.name() and alias_ins->name() != "@param") + continue; + // shape allocated is different from actual shape + // of the instruction, reallocate and replace the previous one + if(alias_ins->get_shape() == ins->get_shape()) + continue; + auto alloc_ins = m.insert_instruction(ins, model.allocate(ins->get_shape())); + m.replace_instruction(alias_ins, alloc_ins); + // If the memory is an output parameter then copy the memory to the parameter + if(alias_ins->name() == "@param") + { + auto copy = m.insert_instruction(std::next(ins), make_op(model.copy()), ins, alias_ins); + auto tail = range(std::next(copy), m.end()); + for(auto i : iterator_for(tail)) + { + if(contains(i->inputs(), ins)) + instruction::replace_argument(i, ins, copy); + } + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/analyze_streams.cpp b/src/analyze_streams.cpp new file mode 100644 index 0000000000000000000000000000000000000000..939a72ea094254e236b8f6b152d8c42203f09a96 --- /dev/null +++ b/src/analyze_streams.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +bool happens_before(const std::vector& e1, const std::vector& e2) +{ + return std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::less_equal<>{}) and + not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{}); +} + +std::vector analyze_streams(const module& m, const stream_model& strmm) +{ + using vector_clock = std::vector; + std::vector races; + auto nstream = strmm.get_nstream(); + std::vector vclock(nstream, vector_clock(nstream)); + std::unordered_map timestamp; + std::unordered_map events; + for(auto ins : iterator_for(m)) + { + if(not strmm.has_stream(ins)) + continue; + std::size_t s = strmm.get_stream(ins); + assert(s < nstream); + assert(vclock.size() == nstream); + assert(vclock[s].size() == nstream); + if(strmm.is_record(ins)) + { + vclock[s][s]++; + auto event = strmm.get_event_id(ins); + events[event] = vclock[s]; + } + else if(strmm.is_wait(ins)) + { + auto event = strmm.get_event_id(ins); + if(not contains(events, event)) + MIGRAPHX_THROW("Event is waited on before being recorded: " + + std::to_string(event)); + auto payload = events.at(event); + assert(vclock[s].size() == payload.size()); + std::transform(vclock[s].begin(), + vclock[s].end(), + payload.begin(), + vclock[s].begin(), + [&](auto x, auto y) { return std::max(x, y); }); + vclock[s][s]++; + } + else + { + vclock[s][s]++; + } + timestamp[ins] = vclock[s]; + } + for(auto ins : iterator_for(m)) + { + if(not strmm.has_stream(ins)) + continue; + if(ins->inputs().empty()) + continue; + std::size_t s = strmm.get_stream(ins); + // Find inputs from different streams + std::vector inputs; + fix([&](auto self, auto start) { + for(auto input : start->inputs()) + { + if(not strmm.has_stream(input)) + self(input); + else if(strmm.get_stream(input) != s) + inputs.push_back(input); + } + })(ins); + auto it = std::find_if(inputs.begin(), inputs.end(), [&](auto input) { + return not happens_before(timestamp.at(input), timestamp.at(ins)); + }); + if(it != inputs.end()) + { + races.push_back({ins, *it}); + } + } + + return races; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/api/CMakeLists.txt b/src/api/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6eace3d2154d4790d7d6438ca574273867b9afa --- /dev/null +++ b/src/api/CMakeLists.txt @@ -0,0 +1,15 @@ + +add_library(migraphx_c + api.cpp +) +set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) +rocm_set_soversion(migraphx_c 3.0) + +rocm_clang_tidy_check(migraphx_c) +target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) + +rocm_install_targets( + TARGETS migraphx_c + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include +) diff --git a/src/api/api.cpp b/src/api/api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95849e3f905b307cba11c05c616de13f66801bea --- /dev/null +++ b/src/api/api.cpp @@ -0,0 +1,1769 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { + +template +migraphx_status try_(F f, bool output = true) // NOLINT +{ + try + { + f(); + } + catch(const migraphx::exception& ex) + { + if(output) + std::cerr << "MIGraphX Error: " << ex.what() << std::endl; + if(ex.error > 0) + return migraphx_status(ex.error); + else + return migraphx_status_unknown_error; + } + catch(const std::exception& ex) + { + if(output) + std::cerr << "MIGraphX Error: " << ex.what() << std::endl; + return migraphx_status_unknown_error; + } + catch(...) + { + return migraphx_status_unknown_error; + } + return migraphx_status_success; +} + +shape::type_t to_shape_type(migraphx_shape_datatype_t t) +{ + switch(t) + { + case migraphx_shape_tuple_type: return shape::tuple_type; +#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ + case migraphx_shape_##x: return shape::x; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) +#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT + } + MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); +} + +migraphx_shape_datatype_t to_shape_type(shape::type_t t) +{ + switch(t) + { + case shape::tuple_type: return migraphx_shape_tuple_type; +#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ + case shape::x: return migraphx_shape_##x; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) +#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT + } + MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); +} + +template +auto to_obj_vector(const T* x, std::size_t n) +{ + std::vectorobject)> result; + std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; }); + return result; +} + +template +auto to_objptr_vector(const U* x, std::size_t n) +{ + std::vector result; + std::transform( + x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); }); + return result; +} + +target get_target(const std::string& name) { return make_target(name); } + +void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; } + +void set_fast_math(compile_options& options, bool value) { options.fast_math = value; } + +void set_file_format(file_options& options, const char* format) { options.format = format; } + +void set_default_dim_value(onnx_options& options, size_t value) +{ + options.default_dim_value = value; +} + +void set_default_loop_iterations(onnx_options& options, int64_t value) +{ + options.max_loop_iterations = value; +} + +void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; } + +void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; } + +void set_input_parameter_shape(onnx_options& options, + const char* name, + std::vector dims) +{ + options.map_input_dims[std::string(name)] = std::move(dims); +} + +void set_input_parameter_shape(tf_options& options, const char* name, std::vector dims) +{ + options.map_input_dims[std::string(name)] = std::move(dims); +} + +void set_output_names(tf_options& options, std::vector names) +{ + options.output_node_names = std::vector(names.begin(), names.end()); +} + +template +std::vector get_names(const std::unordered_map& m) +{ + std::vector result; + std::transform( + m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.first.c_str(); }); + return result; +} + +void quantize_fp16_with_op_names(program& prog, std::vector& names) +{ + if(names.empty()) + { + names = {"all"}; + } + + migraphx::quantize_fp16(prog, names); +} + +struct quantize_int8_options +{ + std::vector calibration = {}; + std::vector op_names = {}; +}; + +void add_op_name(quantize_int8_options& options, const char* name) +{ + options.op_names.push_back(name); +} + +void add_calibration_data(quantize_int8_options& options, parameter_map& data) +{ + options.calibration.push_back(data); +} + +void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& options) +{ + if(options.op_names.empty()) + { + options.op_names = {"dot", "convolution"}; + } + + migraphx::quantize_int8(prog, t, options.calibration, options.op_names); +} + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + +operation create_op(const char* name, const char* attributes, va_list vlist) +{ + std::string sattributes = attributes == nullptr ? "" : attributes; + std::vector buffer(sattributes.size() * 2); + std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist); + value v = value::object{}; + if(attributes != nullptr) + { + v = from_json_string(convert_to_json(std::string(buffer.data()))); + } + auto op = make_op(name, v); + + return op; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +template +bool equal(const T& x, const T& y) +{ + return x == y; +} + +std::vector run(program& p, const parameter_map& params) { return p.eval(params); } + +std::vector get_output_shapes(program& p) { return p.get_output_shapes(); } + +void print_program(const program& p) { std::cout << p << std::endl; } + +void print_module(const module& m) { std::cout << m << std::endl; } + +struct experimental_custom_op +{ + std::string name; + experimental_custom_op() = default; + + experimental_custom_op(std::string pname) : name(std::move(pname)) {} +}; + +template +struct custom_operation +{ + template + static auto reflect(Self&, F) + { + return pack(); + } + CustomOp op; + std::string name() const { return op.xobject.name; } + + shape compute_shape(std::vector inputs) const + { + return op.compute_shape(std::move(inputs)); + } + + argument compute(const std::vector&) const { MIGRAPHX_THROW("Not computable"); } +}; + +template +void register_custom_op(const CustomOp& op) +{ + register_op(custom_operation{op}); +} + +migraphx::context get_context(const program& p) { return p.get_context(); } + +} // namespace migraphx + +template > +Target* object_cast(U* x) +{ + return reinterpret_cast(x); +} +template > +const Target* object_cast(const U* x) +{ + return reinterpret_cast(x); +} + +template > +Target* allocate(Ts&&... xs) +{ + return new Target(std::forward(xs)...); // NOLINT +} + +template +void destroy(T* x) +{ + delete x; // NOLINT +} +// TODO: Move to interface preamble +template +struct manage_generic_ptr +{ + manage_generic_ptr() = default; + + manage_generic_ptr(std::nullptr_t) {} + + manage_generic_ptr(void* pdata, C pcopier, D pdeleter) + : data(nullptr), copier(pcopier), deleter(pdeleter) + { + copier(&data, pdata); + } + + manage_generic_ptr(const manage_generic_ptr& rhs) + : data(nullptr), copier(rhs.copier), deleter(rhs.deleter) + { + if(copier) + copier(&data, rhs.data); + } + + manage_generic_ptr(manage_generic_ptr&& other) noexcept + : data(other.data), copier(other.copier), deleter(other.deleter) + { + other.data = nullptr; + other.copier = nullptr; + other.deleter = nullptr; + } + + manage_generic_ptr& operator=(manage_generic_ptr rhs) + { + std::swap(data, rhs.data); + std::swap(copier, rhs.copier); + std::swap(deleter, rhs.deleter); + return *this; + } + + ~manage_generic_ptr() + { + if(data != nullptr) + deleter(data); + } + + void* data = nullptr; + C copier = nullptr; + D deleter = nullptr; +}; + +extern "C" struct migraphx_shape; +struct migraphx_shape +{ + template + migraphx_shape(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::shape object; +}; + +extern "C" struct migraphx_argument; +struct migraphx_argument +{ + template + migraphx_argument(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::argument object; +}; + +extern "C" struct migraphx_target; +struct migraphx_target +{ + template + migraphx_target(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::target object; +}; + +extern "C" struct migraphx_program_parameter_shapes; +struct migraphx_program_parameter_shapes +{ + template + migraphx_program_parameter_shapes(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::unordered_map object; +}; + +extern "C" struct migraphx_program_parameters; +struct migraphx_program_parameters +{ + template + migraphx_program_parameters(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::unordered_map object; +}; + +extern "C" struct migraphx_arguments; +struct migraphx_arguments +{ + template + migraphx_arguments(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::vector object; +}; + +extern "C" struct migraphx_shapes; +struct migraphx_shapes +{ + template + migraphx_shapes(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::vector object; +}; + +extern "C" struct migraphx_instruction; +struct migraphx_instruction +{ + template + migraphx_instruction(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::instruction_ref object; +}; + +extern "C" struct migraphx_instructions; +struct migraphx_instructions +{ + template + migraphx_instructions(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::vector object; +}; + +extern "C" struct migraphx_modules; +struct migraphx_modules +{ + template + migraphx_modules(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::vector object; +}; + +extern "C" struct migraphx_module; +struct migraphx_module +{ + template + migraphx_module(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::module object; +}; + +extern "C" struct migraphx_program; +struct migraphx_program +{ + template + migraphx_program(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::program object; +}; + +extern "C" struct migraphx_operation; +struct migraphx_operation +{ + template + migraphx_operation(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::operation object; +}; + +extern "C" struct migraphx_onnx_options; +struct migraphx_onnx_options +{ + template + migraphx_onnx_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::onnx_options object; +}; + +extern "C" struct migraphx_file_options; +struct migraphx_file_options +{ + template + migraphx_file_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::file_options object; +}; + +extern "C" struct migraphx_compile_options; +struct migraphx_compile_options +{ + template + migraphx_compile_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::compile_options object; +}; + +extern "C" struct migraphx_tf_options; +struct migraphx_tf_options +{ + template + migraphx_tf_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::tf_options object; +}; + +extern "C" struct migraphx_quantize_op_names; +struct migraphx_quantize_op_names +{ + template + migraphx_quantize_op_names(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + std::vector object; +}; + +extern "C" struct migraphx_quantize_int8_options; +struct migraphx_quantize_int8_options +{ + template + migraphx_quantize_int8_options(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::quantize_int8_options object; +}; + +extern "C" struct migraphx_context; +struct migraphx_context +{ + template + migraphx_context(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + { + } + migraphx::context object; +}; + +extern "C" struct migraphx_experimental_custom_op; +struct migraphx_experimental_custom_op +{ + template + migraphx_experimental_custom_op(void* p, + migraphx_experimental_custom_op_copy c, + migraphx_experimental_custom_op_delete d, + Ts&&... xs) + : object_ptr(p, c, d), xobject(std::forward(xs)...) + { + } + manage_generic_ptr + object_ptr = nullptr; + migraphx::experimental_custom_op xobject; + migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr; + migraphx::shape compute_shape(std::vector inputs) const + { + std::remove_pointer_t out; + if(compute_shape_f == nullptr) + throw std::runtime_error("compute_shape function is missing."); + auto api_error_result = + compute_shape_f(&out, object_ptr.data, object_cast(&(inputs))); + if(api_error_result != migraphx_status_success) + throw std::runtime_error("Error in compute_shape."); + return (&out)->object; + } +}; + +extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { destroy((shape)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, + const_migraphx_shape_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_create(migraphx_shape_t* shape, + migraphx_shape_datatype_t type, + size_t* lengths, + size_t lengths_size) +{ + auto api_error_result = migraphx::try_([&] { + if(lengths == nullptr and lengths_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); + *shape = object_cast( + allocate((migraphx::to_shape_type(type)), + (std::vector(lengths, lengths + lengths_size)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, + migraphx_shape_datatype_t type, + size_t* lengths, + size_t lengths_size, + size_t* strides, + size_t strides_size) +{ + auto api_error_result = migraphx::try_([&] { + if(lengths == nullptr and lengths_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter lengths: Null pointer"); + if(strides == nullptr and strides_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter strides: Null pointer"); + *shape = object_cast( + allocate((migraphx::to_shape_type(type)), + (std::vector(lengths, lengths + lengths_size)), + (std::vector(strides, strides + strides_size)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, + migraphx_shape_datatype_t type) +{ + auto api_error_result = migraphx::try_([&] { + *shape = object_cast( + allocate((migraphx::to_shape_type(type)))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { + if(out == nullptr or out_size == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + auto&& api_result = (shape->object).lens(); + *out = api_result.data(); + *out_size = api_result.size(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { + if(out == nullptr or out_size == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + auto&& api_result = (shape->object).strides(); + *out = api_result.data(); + *out_size = api_result.size(); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, + const_migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { + if(out == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + *out = migraphx::to_shape_type((shape->object).type()); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + *out = (shape->object).bytes(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) +{ + auto api_error_result = migraphx::try_([&] { + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + if(x == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); + *out = migraphx::equal((shape->object), (x->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) +{ + auto api_error_result = migraphx::try_([&] { destroy((argument)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, + const_migraphx_argument_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer) +{ + auto api_error_result = migraphx::try_([&] { + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + *argument = object_cast( + allocate((shape->object), (buffer))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, + const_migraphx_argument_t argument) +{ + auto api_error_result = migraphx::try_([&] { + if(argument == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); + *out = object_cast(&((argument->object).get_shape())); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument) +{ + auto api_error_result = migraphx::try_([&] { + if(argument == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); + *out = (argument->object).data(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x) +{ + auto api_error_result = migraphx::try_([&] { + if(argument == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); + if(x == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); + *out = migraphx::equal((argument->object), (x->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed) +{ + auto api_error_result = migraphx::try_([&] { + if(s == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer"); + *out = allocate(migraphx::generate_argument((s->object), (seed))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_target_destroy(migraphx_target_t target) +{ + auto api_error_result = migraphx::try_([&] { destroy((target)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_target_assign_to(migraphx_target_t output, + const_migraphx_target_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name) +{ + auto api_error_result = migraphx::try_([&] { + *target = object_cast( + allocate(migraphx::get_target((name)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_parameter_shapes_destroy( + migraphx_program_parameter_shapes_t program_parameter_shapes) +{ + auto api_error_result = migraphx::try_([&] { destroy((program_parameter_shapes)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output, + const_migraphx_program_parameter_shapes_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameter_shapes_size(size_t* out, + migraphx_program_parameter_shapes_t program_parameter_shapes) +{ + auto api_error_result = migraphx::try_([&] { + if(program_parameter_shapes == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter program_parameter_shapes: Null pointer"); + *out = (program_parameter_shapes->object).size(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, + migraphx_program_parameter_shapes_t program_parameter_shapes, + const char* name) +{ + auto api_error_result = migraphx::try_([&] { + if(program_parameter_shapes == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter program_parameter_shapes: Null pointer"); + *out = + object_cast(&((program_parameter_shapes->object).at((name)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_parameter_shapes_names( + const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes) +{ + auto api_error_result = migraphx::try_([&] { + if(out == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); + if(program_parameter_shapes == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter program_parameter_shapes: Null pointer"); + auto&& api_result = migraphx::get_names((program_parameter_shapes->object)); + std::copy(api_result.begin(), api_result.end(), out); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters) +{ + auto api_error_result = migraphx::try_([&] { destroy((program_parameters)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameters_assign_to(migraphx_program_parameters_t output, + const_migraphx_program_parameters_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters) +{ + auto api_error_result = migraphx::try_([&] { + *program_parameters = object_cast( + allocate>()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters, + const char* name, + const_migraphx_argument_t argument) +{ + auto api_error_result = migraphx::try_([&] { + if(program_parameters == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter program_parameters: Null pointer"); + if(argument == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter argument: Null pointer"); + (program_parameters->object)[(name)] = (argument->object); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments) +{ + auto api_error_result = migraphx::try_([&] { destroy((arguments)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output, + const_migraphx_arguments_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments) +{ + auto api_error_result = migraphx::try_([&] { + if(arguments == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); + *out = (arguments->object).size(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx) +{ + auto api_error_result = migraphx::try_([&] { + if(arguments == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter arguments: Null pointer"); + *out = object_cast(&((arguments->object).at((idx)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes) +{ + auto api_error_result = migraphx::try_([&] { destroy((shapes)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, + const_migraphx_shapes_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes) +{ + auto api_error_result = migraphx::try_([&] { + if(shapes == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); + *out = (shapes->object).size(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx) +{ + auto api_error_result = migraphx::try_([&] { + if(shapes == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shapes: Null pointer"); + *out = object_cast(&((shapes->object).at((idx)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction) +{ + auto api_error_result = migraphx::try_([&] { destroy((instruction)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output, + const_migraphx_instruction_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions) +{ + auto api_error_result = migraphx::try_([&] { destroy((instructions)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output, + const_migraphx_instructions_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, + const_migraphx_instruction_t* ptr, + size_t size) +{ + auto api_error_result = migraphx::try_([&] { + *instructions = + object_cast(allocate>( + migraphx::to_obj_vector((ptr), (size)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_modules_destroy(migraphx_modules_t modules) +{ + auto api_error_result = migraphx::try_([&] { destroy((modules)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, + const_migraphx_modules_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size) +{ + auto api_error_result = migraphx::try_([&] { + *modules = object_cast(allocate>( + migraphx::to_objptr_vector((ptr), (size)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name) +{ + auto api_error_result = migraphx::try_([&] { + if(name == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer"); + *module = object_cast(allocate((std::string(name)))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + migraphx::print_module((module->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_operation_t op, + migraphx_instructions_t args) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(op == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer"); + if(args == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer"); + *out = allocate( + (module->object).add_instruction((op->object), (args->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_operation_t op, + migraphx_instructions_t args, + migraphx_modules_t module_refs) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(op == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer"); + if(args == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer"); + if(module_refs == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module_refs: Null pointer"); + *out = allocate( + (module->object).add_instruction((op->object), (args->object), (module_refs->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out, + migraphx_module_t module, + const_migraphx_shape_t shape, + const char* buffer) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + *out = allocate( + (module->object).add_literal((shape->object), (buffer))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, + migraphx_module_t module, + const char* name, + const_migraphx_shape_t shape) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(shape == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer"); + *out = allocate( + (module->object).add_parameter((name), (shape->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_instructions_t args) +{ + auto api_error_result = migraphx::try_([&] { + if(module == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer"); + if(args == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer"); + *out = allocate((module->object).add_return((args->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { destroy((program)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output, + const_migraphx_program_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_create(migraphx_program_t* program) +{ + auto api_error_result = migraphx::try_( + [&] { *program = object_cast(allocate()); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, + migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + *out = object_cast((program->object).get_main_module()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_create_module(migraphx_module_t* out, migraphx_program_t program, const char* name) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + *out = object_cast((program->object).create_module((name))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, + migraphx_target_t target, + migraphx_compile_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + if(target == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer"); + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + (program->object).compile((target->object), (options->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, + migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + *out = + allocate((program->object).get_parameter_shapes()); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, + migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + *out = allocate(migraphx::get_output_shapes((program->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + migraphx::print_program((program->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_sort(migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + (program->object).sort(); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, + migraphx_program_t program, + migraphx_program_parameters_t params) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + if(params == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer"); + *out = allocate(migraphx::run((program->object), (params->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + if(x == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer"); + *out = migraphx::equal((program->object), (x->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_program_experimental_get_context(migraphx_context_t* out, const_migraphx_program_t program) +{ + auto api_error_result = migraphx::try_([&] { + if(program == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); + *out = allocate(migraphx::get_context((program->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation) +{ + auto api_error_result = migraphx::try_([&] { destroy((operation)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, + const_migraphx_operation_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_operation_create(migraphx_operation_t* operation, + const char* name, + const char* attributes, + ...) +{ + va_list vlist; + va_start(vlist, attributes); + auto api_error_result = migraphx::try_([&] { + *operation = object_cast( + allocate(migraphx::create_op((name), (attributes), (vlist)))); + }); + va_end(vlist); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation) +{ + auto api_error_result = migraphx::try_([&] { + if(out == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer"); + if(operation == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter operation: Null pointer"); + auto&& api_result = (operation->object).name(); + auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out); + *it = '\0'; + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + *out = allocate(migraphx::load((name), (options->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(p == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter p: Null pointer"); + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + migraphx::save((p->object), (name), (options->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((onnx_options)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output, + const_migraphx_onnx_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options) +{ + auto api_error_result = migraphx::try_([&] { + *onnx_options = object_cast(allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( + migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size) +{ + auto api_error_result = migraphx::try_([&] { + if(onnx_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); + if(dims == nullptr and dims_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer"); + migraphx::set_input_parameter_shape( + (onnx_options->object), (name), (std::vector(dims, dims + dims_size))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) +{ + auto api_error_result = migraphx::try_([&] { + if(onnx_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); + migraphx::set_default_dim_value((onnx_options->object), (value)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, + int64_t value) +{ + auto api_error_result = migraphx::try_([&] { + if(onnx_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer"); + migraphx::set_default_loop_iterations((onnx_options->object), (value)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((file_options)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output, + const_migraphx_file_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options) +{ + auto api_error_result = migraphx::try_([&] { + *file_options = object_cast(allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_file_options_set_file_format(migraphx_file_options_t file_options, const char* format) +{ + auto api_error_result = migraphx::try_([&] { + if(file_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter file_options: Null pointer"); + migraphx::set_file_format((file_options->object), (format)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_compile_options_destroy(migraphx_compile_options_t compile_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((compile_options)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_compile_options_assign_to(migraphx_compile_options_t output, + const_migraphx_compile_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_compile_options_create(migraphx_compile_options_t* compile_options) +{ + auto api_error_result = migraphx::try_([&] { + *compile_options = + object_cast(allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value) +{ + auto api_error_result = migraphx::try_([&] { + if(compile_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter compile_options: Null pointer"); + migraphx::set_offload_copy((compile_options->object), (value)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, bool value) +{ + auto api_error_result = migraphx::try_([&] { + if(compile_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter compile_options: Null pointer"); + migraphx::set_fast_math((compile_options->object), (value)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + *out = allocate(migraphx::parse_onnx((name), (options->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, + const void* data, + size_t size, + migraphx_onnx_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + *out = allocate( + migraphx::parse_onnx_buffer((data), (size), (options->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((tf_options)); }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output, + const_migraphx_tf_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options) +{ + auto api_error_result = migraphx::try_([&] { + *tf_options = object_cast(allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, + bool is_nhwc) +{ + auto api_error_result = migraphx::try_([&] { + if(tf_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); + migraphx::set_nhwc((tf_options->object), (is_nhwc)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_set_input_parameter_shape( + migraphx_tf_options_t tf_options, const char* name, size_t* dims, size_t dims_size) +{ + auto api_error_result = migraphx::try_([&] { + if(tf_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); + if(dims == nullptr and dims_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer"); + migraphx::set_input_parameter_shape( + (tf_options->object), (name), (std::vector(dims, dims + dims_size))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, size_t value) +{ + auto api_error_result = migraphx::try_([&] { + if(tf_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); + migraphx::set_default_dim_value((tf_options->object), (value)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, + const char** names, + size_t names_size) +{ + auto api_error_result = migraphx::try_([&] { + if(tf_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter tf_options: Null pointer"); + if(names == nullptr and names_size != 0) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter names: Null pointer"); + migraphx::set_output_names((tf_options->object), + (std::vector(names, names + names_size))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + *out = allocate(migraphx::parse_tf((name), (options->object))); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names) +{ + auto api_error_result = migraphx::try_([&] { destroy((quantize_op_names)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output, + const_migraphx_quantize_op_names_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names) +{ + auto api_error_result = migraphx::try_([&] { + *quantize_op_names = + object_cast(allocate>()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name) +{ + auto api_error_result = migraphx::try_([&] { + if(quantize_op_names == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter quantize_op_names: Null pointer"); + (quantize_op_names->object).push_back((name)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, + migraphx_quantize_op_names_t name) +{ + auto api_error_result = migraphx::try_([&] { + if(prog == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); + if(name == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer"); + migraphx::quantize_fp16_with_op_names((prog->object), (name->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog) +{ + auto api_error_result = migraphx::try_([&] { + if(prog == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); + migraphx::quantize_fp16((prog->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options) +{ + auto api_error_result = migraphx::try_([&] { destroy((quantize_int8_options)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output, + const_migraphx_quantize_int8_options_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options) +{ + auto api_error_result = migraphx::try_([&] { + *quantize_int8_options = object_cast( + allocate()); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, + const char* name) +{ + auto api_error_result = migraphx::try_([&] { + if(quantize_int8_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter quantize_int8_options: Null pointer"); + migraphx::add_op_name((quantize_int8_options->object), (name)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data( + migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data) +{ + auto api_error_result = migraphx::try_([&] { + if(quantize_int8_options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter quantize_int8_options: Null pointer"); + if(data == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer"); + migraphx::add_calibration_data((quantize_int8_options->object), (data->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, + migraphx_target_t target, + migraphx_quantize_int8_options_t options) +{ + auto api_error_result = migraphx::try_([&] { + if(prog == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer"); + if(target == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer"); + if(options == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer"); + migraphx::quantize_int8_wrap((prog->object), (target->object), (options->object)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context) +{ + auto api_error_result = migraphx::try_([&] { + if(context == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer"); + (context->object).finish(); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context) +{ + auto api_error_result = migraphx::try_([&] { + if(context == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer"); + *out = (context->object).get_queue().unsafe_get(); + }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op) +{ + auto api_error_result = migraphx::try_([&] { destroy((experimental_custom_op)); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output, + const_migraphx_experimental_custom_op_t input) +{ + auto api_error_result = migraphx::try_([&] { *output = *input; }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op, + void* obj, + migraphx_experimental_custom_op_copy c, + migraphx_experimental_custom_op_delete d, + const char* name) +{ + auto api_error_result = migraphx::try_([&] { + *experimental_custom_op = + allocate((obj), (c), (d), (name)); + }); + return api_error_result; +} + +extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape( + migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input) +{ + auto api_error_result = migraphx::try_([&] { (obj)->compute_shape_f = (input); }); + return api_error_result; +} + +extern "C" migraphx_status +migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op) +{ + auto api_error_result = migraphx::try_([&] { + if(experimental_custom_op == nullptr) + MIGRAPHX_THROW(migraphx_status_bad_param, + "Bad parameter experimental_custom_op: Null pointer"); + migraphx::register_custom_op((*experimental_custom_op)); + }); + return api_error_result; +} diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h new file mode 100644 index 0000000000000000000000000000000000000000..de397fec2bf4193073ebf8b272dbf966c49951f8 --- /dev/null +++ b/src/api/include/migraphx/migraphx.h @@ -0,0 +1,467 @@ +#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H +#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H + +#include + +// Add new types here +// clang-format off +#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ + m(bool_type, bool) \ + m(half_type, half) \ + m(float_type, float) \ + m(double_type, double) \ + m(uint8_type, uint8_t) \ + m(int8_type, int8_t) \ + m(uint16_type, uint16_t) \ + m(int16_type, int16_t) \ + m(int32_type, int32_t) \ + m(int64_type, int64_t) \ + m(uint32_type, uint32_t) \ + m(uint64_type, uint64_t) +// clang-format on + +#ifdef __cplusplus +extern "C" { +#endif + +// return code, more to be added later +typedef enum +{ + migraphx_status_success = 0, + migraphx_status_bad_param = 1, + migraphx_status_unknown_target = 3, + migraphx_status_unknown_error = 4, + +} migraphx_status; + +#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x, +/// An enum to represent the different data type inputs +typedef enum +{ + migraphx_shape_tuple_type, + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) +} migraphx_shape_datatype_t; +#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES + +typedef struct migraphx_shape* migraphx_shape_t; +typedef const struct migraphx_shape* const_migraphx_shape_t; + +typedef struct migraphx_argument* migraphx_argument_t; +typedef const struct migraphx_argument* const_migraphx_argument_t; + +typedef struct migraphx_target* migraphx_target_t; +typedef const struct migraphx_target* const_migraphx_target_t; + +typedef struct migraphx_program_parameter_shapes* migraphx_program_parameter_shapes_t; +typedef const struct migraphx_program_parameter_shapes* const_migraphx_program_parameter_shapes_t; + +typedef struct migraphx_program_parameters* migraphx_program_parameters_t; +typedef const struct migraphx_program_parameters* const_migraphx_program_parameters_t; + +typedef struct migraphx_arguments* migraphx_arguments_t; +typedef const struct migraphx_arguments* const_migraphx_arguments_t; + +typedef struct migraphx_shapes* migraphx_shapes_t; +typedef const struct migraphx_shapes* const_migraphx_shapes_t; + +typedef struct migraphx_instruction* migraphx_instruction_t; +typedef const struct migraphx_instruction* const_migraphx_instruction_t; + +typedef struct migraphx_instructions* migraphx_instructions_t; +typedef const struct migraphx_instructions* const_migraphx_instructions_t; + +typedef struct migraphx_modules* migraphx_modules_t; +typedef const struct migraphx_modules* const_migraphx_modules_t; + +typedef struct migraphx_module* migraphx_module_t; +typedef const struct migraphx_module* const_migraphx_module_t; + +typedef struct migraphx_program* migraphx_program_t; +typedef const struct migraphx_program* const_migraphx_program_t; + +typedef struct migraphx_operation* migraphx_operation_t; +typedef const struct migraphx_operation* const_migraphx_operation_t; + +typedef struct migraphx_onnx_options* migraphx_onnx_options_t; +typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t; + +typedef struct migraphx_file_options* migraphx_file_options_t; +typedef const struct migraphx_file_options* const_migraphx_file_options_t; + +typedef struct migraphx_compile_options* migraphx_compile_options_t; +typedef const struct migraphx_compile_options* const_migraphx_compile_options_t; + +typedef struct migraphx_tf_options* migraphx_tf_options_t; +typedef const struct migraphx_tf_options* const_migraphx_tf_options_t; + +typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t; +typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t; + +typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t; +typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t; + +typedef struct migraphx_context* migraphx_context_t; +typedef const struct migraphx_context* const_migraphx_context_t; + +typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t; +typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t; + +typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out, + void* obj, + migraphx_shapes_t inputs); + +typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); + +typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); + +migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); + +migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); + +migraphx_status migraphx_shape_create(migraphx_shape_t* shape, + migraphx_shape_datatype_t type, + size_t* lengths, + size_t lengths_size); + +migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, + migraphx_shape_datatype_t type, + size_t* lengths, + size_t lengths_size, + size_t* strides, + size_t strides_size); + +migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, + migraphx_shape_datatype_t type); + +migraphx_status +migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); + +migraphx_status +migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); + +migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); + +migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); + +migraphx_status +migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); + +migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); + +migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, + const_migraphx_argument_t input); + +migraphx_status +migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); + +migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, + const_migraphx_argument_t argument); + +migraphx_status migraphx_argument_buffer(char** out, const_migraphx_argument_t argument); + +migraphx_status +migraphx_argument_equal(bool* out, const_migraphx_argument_t argument, const_migraphx_argument_t x); + +migraphx_status +migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, size_t seed); + +migraphx_status migraphx_target_destroy(migraphx_target_t target); + +migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input); + +migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name); + +migraphx_status migraphx_program_parameter_shapes_destroy( + migraphx_program_parameter_shapes_t program_parameter_shapes); + +migraphx_status +migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output, + const_migraphx_program_parameter_shapes_t input); + +migraphx_status migraphx_program_parameter_shapes_size( + size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes); + +migraphx_status +migraphx_program_parameter_shapes_get(const_migraphx_shape_t* out, + migraphx_program_parameter_shapes_t program_parameter_shapes, + const char* name); + +migraphx_status migraphx_program_parameter_shapes_names( + const char** out, migraphx_program_parameter_shapes_t program_parameter_shapes); + +migraphx_status +migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters); + +migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output, + const_migraphx_program_parameters_t input); + +migraphx_status +migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters); + +migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t program_parameters, + const char* name, + const_migraphx_argument_t argument); + +migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments); + +migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output, + const_migraphx_arguments_t input); + +migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments); + +migraphx_status +migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t arguments, size_t idx); + +migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes); + +migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input); + +migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); + +migraphx_status +migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); + +migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction); + +migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output, + const_migraphx_instruction_t input); + +migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions); + +migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output, + const_migraphx_instructions_t input); + +migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions, + const_migraphx_instruction_t* ptr, + size_t size); + +migraphx_status migraphx_modules_destroy(migraphx_modules_t modules); + +migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, + const_migraphx_modules_t input); + +migraphx_status +migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size); + +migraphx_status migraphx_module_create(migraphx_module_t* module, char* name); + +migraphx_status migraphx_module_print(const_migraphx_module_t module); + +migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_operation_t op, + migraphx_instructions_t args); + +migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_operation_t op, + migraphx_instructions_t args, + migraphx_modules_t module_refs); + +migraphx_status migraphx_module_add_literal(migraphx_instruction_t* out, + migraphx_module_t module, + const_migraphx_shape_t shape, + const char* buffer); + +migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out, + migraphx_module_t module, + const char* name, + const_migraphx_shape_t shape); + +migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, + migraphx_module_t module, + migraphx_instructions_t args); + +migraphx_status migraphx_program_destroy(migraphx_program_t program); + +migraphx_status migraphx_program_assign_to(migraphx_program_t output, + const_migraphx_program_t input); + +migraphx_status migraphx_program_create(migraphx_program_t* program); + +migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, + migraphx_program_t program); + +migraphx_status migraphx_program_create_module(migraphx_module_t* out, + migraphx_program_t program, + const char* name); + +migraphx_status migraphx_program_compile(migraphx_program_t program, + migraphx_target_t target, + migraphx_compile_options_t options); + +migraphx_status migraphx_program_get_parameter_shapes(migraphx_program_parameter_shapes_t* out, + migraphx_program_t program); + +migraphx_status migraphx_program_get_output_shapes(migraphx_shapes_t* out, + migraphx_program_t program); + +migraphx_status migraphx_program_print(const_migraphx_program_t program); + +migraphx_status migraphx_program_sort(migraphx_program_t program); + +migraphx_status migraphx_program_run(migraphx_arguments_t* out, + migraphx_program_t program, + migraphx_program_parameters_t params); + +migraphx_status +migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); + +migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out, + const_migraphx_program_t program); + +migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); + +migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, + const_migraphx_operation_t input); + +migraphx_status migraphx_operation_create(migraphx_operation_t* operation, + const char* name, + const char* attributes, + ...); + +migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); + +migraphx_status +migraphx_load(migraphx_program_t* out, const char* name, migraphx_file_options_t options); + +migraphx_status +migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t options); + +migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options); + +migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output, + const_migraphx_onnx_options_t input); + +migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options); + +migraphx_status migraphx_onnx_options_set_input_parameter_shape( + migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); + +migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, + size_t value); + +migraphx_status +migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, + int64_t value); + +migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options); + +migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output, + const_migraphx_file_options_t input); + +migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options); + +migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options, + const char* format); + +migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options); + +migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output, + const_migraphx_compile_options_t input); + +migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options); + +migraphx_status +migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_options, bool value); + +migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, + bool value); + +migraphx_status +migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); + +migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out, + const void* data, + size_t size, + migraphx_onnx_options_t options); + +migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options); + +migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output, + const_migraphx_tf_options_t input); + +migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options); + +migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc); + +migraphx_status migraphx_tf_options_set_input_parameter_shape(migraphx_tf_options_t tf_options, + const char* name, + size_t* dims, + size_t dims_size); + +migraphx_status migraphx_tf_options_set_default_dim_value(migraphx_tf_options_t tf_options, + size_t value); + +migraphx_status migraphx_tf_options_set_output_names(migraphx_tf_options_t tf_options, + const char** names, + size_t names_size); + +migraphx_status +migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options_t options); + +migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names); + +migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output, + const_migraphx_quantize_op_names_t input); + +migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names); + +migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, + const char* name); + +migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, + migraphx_quantize_op_names_t name); + +migraphx_status migraphx_quantize_fp16(migraphx_program_t prog); + +migraphx_status +migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options); + +migraphx_status +migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output, + const_migraphx_quantize_int8_options_t input); + +migraphx_status +migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options); + +migraphx_status +migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options, + const char* name); + +migraphx_status migraphx_quantize_int8_options_add_calibration_data( + migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data); + +migraphx_status migraphx_quantize_int8(migraphx_program_t prog, + migraphx_target_t target, + migraphx_quantize_int8_options_t options); + +migraphx_status migraphx_context_finish(const_migraphx_context_t context); + +migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context); + +migraphx_status +migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); + +migraphx_status +migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output, + const_migraphx_experimental_custom_op_t input); + +migraphx_status +migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op, + void* obj, + migraphx_experimental_custom_op_copy c, + migraphx_experimental_custom_op_delete d, + const char* name); + +migraphx_status migraphx_experimental_custom_op_set_compute_shape( + migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); + +migraphx_status +migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/api/include/migraphx/migraphx.hpp b/src/api/include/migraphx/migraphx.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31b768c0e8c24b497073c98b143665dc0aa8ce54 --- /dev/null +++ b/src/api/include/migraphx/migraphx.hpp @@ -0,0 +1,1186 @@ +#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP +#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP + +#include "migraphx.h" +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +#ifndef DOXYGEN +inline namespace api { // NOLINT +#endif + +#ifdef __has_cpp_attribute +#if __has_cpp_attribute(deprecated) +#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]] +#endif +#endif + +#ifndef MIGRAPHX_DEPRECATED +#define MIGRAPHX_DEPRECATED(...) +#endif + +template +struct rank : rank +{ +}; + +template <> +struct rank<0> +{ +}; + +template +T* make(F f, Ts&&... xs) +{ + T* result = nullptr; + auto e = f(&result, std::forward(xs)...); + if(e != migraphx_status_success) + throw std::runtime_error("Failed to call function"); + return result; +} + +template +void call(F f, Ts&&... xs) +{ + auto e = f(std::forward(xs)...); + if(e != migraphx_status_success) + throw std::runtime_error("Failed to call function"); +} + +template +struct iota_iterator +{ + Iterator index; + F f; + + using difference_type = std::ptrdiff_t; + using reference = decltype(f(std::declval())); + using value_type = typename std::remove_reference::type; + using pointer = typename std::add_pointer::type; + using iterator_category = std::input_iterator_tag; + + iota_iterator& operator+=(int n) + { + index += n; + return *this; + } + + iota_iterator& operator-=(int n) + { + index += n; + return *this; + } + + iota_iterator& operator++() + { + index++; + return *this; + } + + iota_iterator& operator--() + { + index--; + return *this; + } + + iota_iterator operator++(int) // NOLINT + { + iota_iterator it = *this; + index++; + return it; + } + + iota_iterator operator--(int) // NOLINT + { + iota_iterator it = *this; + index--; + return it; + } + // TODO: operator-> + reference operator*() const { return f(index); } + + friend iota_iterator operator+(iota_iterator x, iota_iterator y) + { + return iota_iterator(x.index + y.index, x.f); + } + + friend iota_iterator operator-(iota_iterator x, iota_iterator y) + { + return iota_iterator(x.index - y.index, x.f); + } + + friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; } + + friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; } +}; + +template +struct array_base +{ + const Derived& derived() const { return static_cast(*this); } + + template + using value_type_t = decltype(std::declval()[0]); + + struct iterator_read + { + const Derived* self; + template + value_type_t operator()(size_t pidx) const + { + return (*self)[pidx]; + } + }; + + template + using iterator_t = iota_iterator; + + bool empty() const { return derived().size() == 0; } + + template + value_type_t front() const + { + return derived()[0]; + } + + template + value_type_t back() const + { + return derived()[derived().size() - 1]; + } + + template + iterator_t begin() const + { + return {0, {&derived()}}; + } + + template + iterator_t end() const + { + return {derived().size(), {&derived()}}; + } +}; + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-template-friend" +#endif + +template +struct holder +{ + // Friend injection + friend auto migraphx_adl_handle_lookup(holder); + // Function left unimplemented since its only used in non-evaluated + // context + T get() const; +}; + +template +struct handle_lookup +{ + friend auto migraphx_adl_handle_lookup(holder) { return holder{}; } +}; + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +template +using as_handle = decltype( + migraphx_adl_handle_lookup(holder>>{}).get()); + +struct own +{ +}; +struct borrow +{ +}; + +template +struct share +{ + share(std::shared_ptr p) : ptr(std::move(p)) {} + + template + std::shared_ptr alias(U* p) const + { + return std::shared_ptr{ptr, p}; + } + + private: + std::shared_ptr ptr; +}; + +template +struct handle_base : handle_lookup> +{ + using handle_type = T; + handle_base() : m_handle(nullptr) {} + template + void make_handle(F f, Ts&&... xs) + { + using type = typename std::remove_cv::type; + set_handle(make(f, std::forward(xs)...), own{}); + } + + const std::shared_ptr& get_handle() const { return m_handle; } + + T* get_handle_ptr() const + { + assert(m_handle != nullptr); + return get_handle().get(); + } + + template + void set_handle(U* ptr, own) + { + m_handle = std::shared_ptr{ptr, Deleter}; + } + + template + void set_handle(U* ptr, borrow) + { + m_handle = std::shared_ptr{ptr, [](U*) {}}; + } + + template + void set_handle(U* ptr, share b) + { + m_handle = std::shared_ptr{ptr, [b](U*) {}}; + } + + share share_handle() const { return {m_handle}; } + + template + void assign_to_handle(U* x) + { + Assigner(x, this->get_handle_ptr()); + } + + protected: + std::shared_ptr m_handle; +}; + +// NOLINTNEXTLINE +#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \ + template {}>::type> \ + name(HandleType* p, Lifetime lifetime) \ + { \ + this->set_handle(p, std::move(lifetime)); \ + } + +template +struct interface_base : Base +{ + interface_base() : Base() {} + + protected: + template + static migraphx_status try_(F f) // NOLINT + { + try + { + f(); + return migraphx_status_success; + } + catch(...) + { + return migraphx_status_unknown_error; + } + } + + template + void make_interface(F f, T& obj, Ts&&... xs) + { + auto copy = [](void** out, void* input) { + return try_([&] { + T** y = reinterpret_cast(out); + T* x = reinterpret_cast(input); + assert(x != nullptr and y != nullptr and *y == nullptr); + // cppcheck-suppress useSmartPointer + *y = new T(*x); // NOLINT + }); + }; + auto del = [](void* input) { + return try_([&] { + T* x = reinterpret_cast(input); + delete x; // NOLINT + }); + }; + this->make_handle(f, &obj, copy, del, std::forward(xs)...); + } + + template + void set_fp(Setter setter, F pf) + { + static F f = pf; + (void)f; // avoid warning on gcc + call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status { + return try_([&] { call_cast_arg(rank<1>{}, f, xs...); }); + }); + } + + template + void set_auto_fp(Setter setter, F f) + { + return set_fp(setter, [=](T& obj, auto out, auto... xs) { + auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...); + }); + } + + struct no_out_arg + { + }; + + template {}>> + static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs) + { + f(reinterpret_cast(obj), no_out_arg{}, xs...); + } + + template {}>> + static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs) + { + f(*reinterpret_cast(obj), result, xs...); + } + + template + void auto_invoke(F f, T* out, Ts&&... xs) + { + auto_assign(rank<2>{}, out, f(std::forward(xs)...)); + } + + template + void auto_invoke(F f, no_out_arg, Ts&&... xs) + { + f(std::forward(xs)...); + } + + template {} or std::is_enum{}>> + T auto_convert_param(rank<0>, T x) + { + return x; + } + + template + auto auto_convert_param(rank<1>, T x) -> decltype(as_handle{x}) + { + return as_handle{x}; + } + + template + auto auto_convert_param(rank<2>, T x) -> decltype(as_handle{x, borrow{}}) + { + return as_handle{x, borrow{}}; + } + + template + void auto_assign(rank<0>, T* out, U x) + { + return *out = x; + } + + template + auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out)) + { + x.assign_to_handle(out); + } +}; + +// NOLINTNEXTLINE +#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \ + this->set_auto_fp(&migraphx_##prefix##_set_##name, \ + [](T& x, auto... xs) { return x.name(xs...); }) + +template +using require_interface = + std::enable_if_t{} and not std::is_same{} and + std::is_copy_constructible{} and std::is_final{}>; + +#ifdef DOXYGEN +#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> +#else +#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ + handle_base +#endif +// NOLINTNEXTLINE +#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, ) +// NOLINTNEXTLINE +#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const) + +/** + * @brief Describe shape of tensor + * @details A shape consists of a data type, lengths of multi-dimension tensor, and strides + * + */ +struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) +{ + shape() {} + + MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") + shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(shape); + + /// Construct a scalar shape + shape(migraphx_shape_datatype_t type) + { + this->make_handle(&migraphx_shape_create_scalar, type); + } + + /// Construct a shape with its type and lengths. The strides are + /// automatically computed assumming a packed layout. + shape(migraphx_shape_datatype_t type, std::vector plengths) + { + this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); + } + + shape(migraphx_shape_datatype_t type, + std::vector plengths, + std::vector pstrides) + { + this->make_handle(&migraphx_shape_create_with_strides, + type, + plengths.data(), + plengths.size(), + pstrides.data(), + pstrides.size()); + } + + std::vector lengths() const + { + const size_t* pout; + size_t pout_size; + call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr()); + return {pout, pout + pout_size}; + } + + std::vector strides() const + { + const size_t* pout; + size_t pout_size; + call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr()); + return {pout, pout + pout_size}; + } + + migraphx_shape_datatype_t type() const + { + migraphx_shape_datatype_t pout; + call(&migraphx_shape_type, &pout, this->get_handle_ptr()); + return pout; + } + + size_t bytes() const + { + size_t pout; + call(&migraphx_shape_bytes, &pout, this->get_handle_ptr()); + return pout; + } + + friend bool operator==(const shape& px, const shape& py) + { + bool pout; + call(&migraphx_shape_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr()); + return pout; + } + + friend bool operator!=(const shape& px, const shape& py) { return !(px == py); } +}; + +/** + * @brief Arguments to be passed to an migraphx arguments + * + * An `argument` represents a raw buffer of data with a shape. + * + */ +struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) +{ + argument() {} + + MIGRAPHX_HANDLE_CONSTRUCTOR(argument); + + MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") + argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } + + argument(shape pshape, void* pbuffer) + { + this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer); + } + + shape get_shape() const + { + const_migraphx_shape_t pout; + call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); + return {pout, this->share_handle()}; + } + + char* data() const + { + char* pout; + call(&migraphx_argument_buffer, &pout, this->get_handle_ptr()); + return pout; + } + + /// Generate an argument using random data + static argument generate(shape ps, size_t pseed = 0) + { + return {make(&migraphx_argument_generate, ps.get_handle_ptr(), pseed), + own{}}; + } + + friend bool operator==(const argument& px, const argument& py) + { + bool pout; + call(&migraphx_argument_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr()); + return pout; + } + + friend bool operator!=(const argument& px, const argument& py) { return !(px == py); } +}; + +/// A target for compilation +struct target : MIGRAPHX_HANDLE_BASE(target) +{ + target() {} + + MIGRAPHX_HANDLE_CONSTRUCTOR(target); + + /// Construct a target from its name + target(const char* name) { this->make_handle(&migraphx_target_create, name); } +}; + +struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) +{ + program_parameter_shapes() {} + + MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes); + + size_t size() const + { + size_t pout; + call(&migraphx_program_parameter_shapes_size, &pout, this->get_handle_ptr()); + return pout; + } + + shape operator[](const char* pname) const + { + const_migraphx_shape_t pout; + call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); + return {pout, this->share_handle()}; + } + + std::vector names() const + { + std::vector result(this->size()); + if(!result.empty()) + { + call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr()); + } + return result; + } +}; + +/// A class to construct the inputs parameters for a program +struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters); + + MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") + program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } + + program_parameters() { this->make_handle(&migraphx_program_parameters_create); } + + /// Construct the parameters from initializer_list + program_parameters(std::initializer_list> l) + { + this->make_handle(&migraphx_program_parameters_create); + for(auto&& p : l) + this->add(p.first.c_str(), p.second); + } + + /// Add a new parameter + void add(const char* pname, const argument& pargument) const + { + call(&migraphx_program_parameters_add, + this->get_handle_ptr(), + pname, + pargument.get_handle_ptr()); + } +}; + +struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(arguments); + + size_t size() const + { + size_t pout; + call(&migraphx_arguments_size, &pout, this->get_handle_ptr()); + return pout; + } + + argument operator[](size_t pidx) const + { + const_migraphx_argument_t pout; + call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); + return {pout, this->share_handle()}; + } +}; + +struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(shapes); + + size_t size() const + { + size_t pout; + call(&migraphx_shapes_size, &pout, this->get_handle_ptr()); + return pout; + } + + shape operator[](size_t pidx) const + { + const_migraphx_shape_t pout; + call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); + return {pout, this->share_handle()}; + } +}; + +struct operation : MIGRAPHX_HANDLE_BASE(operation) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(operation); + + template + operation(const char* name, const char* attributes = nullptr, Ts... xs) + { + this->make_handle(&migraphx_operation_create, name, attributes, xs...); + } + + std::string name() + { + std::array out_name; + call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr()); + return {out_name.data()}; + } +}; + +struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(instruction); +}; + +struct instructions : MIGRAPHX_HANDLE_BASE(instructions) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(instructions); + + template + instructions(Ts... xs) + { + std::array a{xs.get_handle_ptr()...}; + this->make_handle(&migraphx_instructions_create, a.data(), a.size()); + } +}; + +struct module; + +struct modules : MIGRAPHX_HANDLE_BASE(modules) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(modules); + + template + modules(Ts... xs) + { + std::array a = {xs.get_handle_ptr()...}; + this->make_handle(&migraphx_modules_create, a.data(), a.size()); + } +}; + +struct module +{ + MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.") + module(migraphx_module* m) : mm(std::shared_ptr(), m) {} + + module(migraphx_module* m, borrow) : mm(std::shared_ptr(), m) {} + + template + module(migraphx_module* m, share b) : mm(b.alias(m)) + { + } + + void print() const { call(&migraphx_module_print, mm.get()); } + + instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args) + { + migraphx_instruction_t op_ins; + call(&migraphx_module_add_instruction, + &op_ins, + mm.get(), + op.get_handle_ptr(), + args.get_handle_ptr()); + return instruction(op_ins, own{}); + } + + instruction add_instruction(const migraphx::operation& op, + const migraphx::instructions& args, + const migraphx::modules& module_args) + { + migraphx_instruction_t op_ins; + call(&migraphx_module_add_instruction_with_mod_args, + &op_ins, + mm.get(), + op.get_handle_ptr(), + args.get_handle_ptr(), + module_args.get_handle_ptr()); + return instruction(op_ins, own{}); + } + + template + instruction add_literal(const migraphx::shape& s, T* buffer) + { + migraphx_instruction_t literal_ins; + const auto* buffer_ptr = reinterpret_cast(buffer); + call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr); + return instruction(literal_ins, own{}); + } + + instruction add_parameter(const std::string& name, shape s) + { + migraphx_instruction_t param_ins; + call( + &migraphx_module_add_parameter, ¶m_ins, mm.get(), name.c_str(), s.get_handle_ptr()); + return instruction(param_ins, own{}); + } + + instruction add_return(const migraphx::instructions& args) + { + migraphx_instruction_t ret_ins; + call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr()); + return instruction(ret_ins, own{}); + } + + migraphx_module_t get_handle_ptr() const { return mm.get(); } + + private: + std::shared_ptr mm; +}; + +struct context +{ + context(migraphx_context* p, borrow) : ctx(std::shared_ptr(), p) {} + + template + context(migraphx_context* p, share b) : ctx(b.alias(p)) + { + } + + void finish() const { call(&migraphx_context_finish, ctx.get()); } + + template + T get_queue() + { + void* out; + call(&migraphx_context_get_queue, &out, ctx.get()); + // TODO: check type here + return reinterpret_cast(out); + } + + private: + std::shared_ptr ctx; +}; + +struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) +{ + compile_options() { this->make_handle(&migraphx_compile_options_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options); + + /// For targets with offloaded memory(such as the gpu), this will insert + /// instructions during compilation to copy the input parameters to the + /// offloaded memory and to copy the final result from the offloaded + /// memory back to main memory. + void set_offload_copy(bool value = true) + { + call(&migraphx_compile_options_set_offload_copy, this->get_handle_ptr(), value); + } + + /// Optimize math functions to use faster approximate versions. There may + /// be slight accuracy degredation when enabled. + void set_fast_math(bool value = true) + { + call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value); + } +}; + +/// A program represents the all computation graphs to be compiled and executed +struct program : MIGRAPHX_HANDLE_BASE(program) +{ + program() { this->make_handle(&migraphx_program_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(program); + + /// Compile the program for a specific target to be ran on + void compile(const target& ptarget, const compile_options& poptions) const + { + call(&migraphx_program_compile, + this->get_handle_ptr(), + ptarget.get_handle_ptr(), + poptions.get_handle_ptr()); + } + + /// Compile the program for a specific target to be ran on + void compile(const target& ptarget) const + { + call(&migraphx_program_compile, + this->get_handle_ptr(), + ptarget.get_handle_ptr(), + migraphx::compile_options{}.get_handle_ptr()); + } + + /// Return the shapes for the input parameters + program_parameter_shapes get_parameter_shapes() const + { + migraphx_program_parameter_shapes_t pout; + call(&migraphx_program_get_parameter_shapes, &pout, this->get_handle_ptr()); + return program_parameter_shapes(pout, own{}); + } + + /// Get the shapes of all the outputs returned by this program + shapes get_output_shapes() const + { + migraphx_shapes_t pout; + call(&migraphx_program_get_output_shapes, &pout, this->get_handle_ptr()); + return shapes(pout, own{}); + } + + /// Run the program using the inputs passed in + arguments eval(const program_parameters& pparams) const + { + migraphx_arguments_t pout; + call(&migraphx_program_run, &pout, this->get_handle_ptr(), pparams.get_handle_ptr()); + return arguments(pout, own{}); + } + + void print() const { call(&migraphx_program_print, this->get_handle_ptr()); } + + program sort() + { + call(&migraphx_program_sort, this->get_handle_ptr()); + return *this; + } + + friend bool operator==(const program& px, const program& py) + { + bool pout; + call(&migraphx_program_equal, &pout, px.get_handle_ptr(), py.get_handle_ptr()); + return pout; + } + + module get_main_module() + { + migraphx_module_t p_modu; + call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr()); + return module{p_modu, this->share_handle()}; + } + + context experimental_get_context() + { + migraphx_context_t ctx; + call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr()); + return context{ctx, this->share_handle()}; + } + + module create_module(const std::string& name) + { + migraphx_module_t p_modu; + call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data()); + return module{p_modu, this->share_handle()}; + } + + friend bool operator!=(const program& px, const program& py) { return !(px == py); } +}; + +// options for migraphx file format options +struct file_options : MIGRAPHX_HANDLE_BASE(file_options) +{ + MIGRAPHX_HANDLE_CONSTRUCTOR(file_options); + file_options() { this->make_handle(&migraphx_file_options_create); } + + // set file format + void set_file_format(const char* format) + { + call(&migraphx_file_options_set_file_format, this->get_handle_ptr(), format); + } +}; + +/// Load a saved migraphx program from a file +inline program load(const char* filename, const file_options& options) +{ + return program(make(&migraphx_load, filename, options.get_handle_ptr()), + own{}); +} + +/// Load a saved migraphx program from a file +inline program load(const char* filename) +{ + return program( + make(&migraphx_load, filename, migraphx::file_options{}.get_handle_ptr()), + own{}); +} + +/// Save a program to a file +inline void save(const program& p, const char* filename, const file_options& options) +{ + call(&migraphx_save, p.get_handle_ptr(), filename, options.get_handle_ptr()); +} + +/// Save a program to a file +inline void save(const program& p, const char* filename) +{ + call(&migraphx_save, p.get_handle_ptr(), filename, migraphx::file_options{}.get_handle_ptr()); +} + +/// Options for parsing onnx options +struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) +{ + onnx_options() { this->make_handle(&migraphx_onnx_options_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options); + + /// Make onnx parser treat an inputs with a certain dimensions + void set_input_parameter_shape(const std::string& name, std::vector dim) + { + call(&migraphx_onnx_options_set_input_parameter_shape, + this->get_handle_ptr(), + name.c_str(), + dim.data(), + dim.size()); + } + + /// When there is a dimension parameter, then use this default value + void set_default_dim_value(unsigned int value) + { + call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value); + } + + /// Set default max iteration number for the loop operator + void set_default_loop_iterations(int64_t value) + { + call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value); + } +}; + +/// Parse an onnx file into a migraphx program +inline program parse_onnx(const char* filename, const migraphx::onnx_options& options) +{ + return program(make(&migraphx_parse_onnx, filename, options.get_handle_ptr()), + own{}); +} + +/// Parse an onnx file into a migraphx program +inline program parse_onnx(const char* filename) +{ + migraphx::onnx_options options; + return program(make(&migraphx_parse_onnx, filename, options.get_handle_ptr()), + own{}); +} + +/// Parse a buffer of memory as an onnx file +inline program +parse_onnx_buffer(const void* data, size_t size, const migraphx::onnx_options& options) +{ + return program( + make(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()), + own{}); +} + +/// Parse a buffer of memory as an onnx file +inline program parse_onnx_buffer(const void* data, size_t size) +{ + migraphx::onnx_options options; + return program( + make(&migraphx_parse_onnx_buffer, data, size, options.get_handle_ptr()), + own{}); +} + +/// Parse a buffer of memory as an onnx file +inline program parse_onnx_buffer(const std::string& buffer, const migraphx::onnx_options& options) +{ + return program( + make( + &migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()), + own{}); +} + +/// Parse a buffer of memory as an onnx file +inline program parse_onnx_buffer(const std::string& buffer) +{ + migraphx::onnx_options options; + return program( + make( + &migraphx_parse_onnx_buffer, buffer.data(), buffer.size(), options.get_handle_ptr()), + own{}); +} + +/// Options for parsing tf options +struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) +{ + tf_options() { this->make_handle(&migraphx_tf_options_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options); + + /// Make tf parser treat an inputs with a certain dimensions + void set_input_parameter_shape(const std::string& name, std::vector dim) + { + call(&migraphx_tf_options_set_input_parameter_shape, + this->get_handle_ptr(), + name.c_str(), + dim.data(), + dim.size()); + } + + /// Change data layout to NHWC (default is NCHW) + void set_nhwc(bool is_nhwc = true) + { + call(&migraphx_tf_options_set_nhwc, this->get_handle_ptr(), is_nhwc); + } + + /// When there is a dimension parameter, then use this default value + void set_default_dim_value(unsigned int value) + { + call(&migraphx_tf_options_set_default_dim_value, this->get_handle_ptr(), value); + } + + /// Set output node names to return specific outputs from graph + void set_output_names(std::vector names) + { + call(&migraphx_tf_options_set_output_names, + this->get_handle_ptr(), + names.data(), + names.size()); + } +}; + +/// Parse a tf file into a migraphx program +inline program parse_tf(const char* filename, const migraphx::tf_options& options) +{ + return program(make(&migraphx_parse_tf, filename, options.get_handle_ptr()), + own{}); +} + +/// Parse a tf file into a migraphx program +inline program parse_tf(const char* filename) +{ + migraphx::tf_options options; + return program(make(&migraphx_parse_tf, filename, options.get_handle_ptr()), + own{}); +} + +struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) +{ + quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names); + + void add(const std::string& name) + { + call(&migraphx_quantize_op_names_add, this->get_handle_ptr(), name.c_str()); + } +}; + +/// Quantize program to use fp16 +inline void quantize_fp16(const program& prog, const quantize_op_names& names) +{ + call(&migraphx_quantize_fp16_with_op_names, prog.get_handle_ptr(), names.get_handle_ptr()); +} + +/// Quantize program to use fp16 +inline void quantize_fp16(const program& prog) +{ + call(&migraphx_quantize_fp16, prog.get_handle_ptr()); +} + +/// Options to be passed when quantizing for int8 +struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) +{ + quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } + + MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options); + + /// Add an operator that should be quantized + void add_op_name(const std::string& name) + { + call(&migraphx_quantize_int8_options_add_op_name, this->get_handle_ptr(), name.c_str()); + } + + /// Add calibrartion data to be used for quantizing + void add_calibration_data(const program_parameters& pp) + { + call(&migraphx_quantize_int8_options_add_calibration_data, + this->get_handle_ptr(), + pp.get_handle_ptr()); + } +}; + +/// Quantize program to use int8 +inline void +quantize_int8(const program& prog, const target& ptarget, const quantize_int8_options& options) +{ + call(&migraphx_quantize_int8, + prog.get_handle_ptr(), + ptarget.get_handle_ptr(), + options.get_handle_ptr()); +} + +struct experimental_custom_op_base +{ + virtual std::string name() const = 0; + virtual shape compute_shape(shapes inputs) const = 0; + virtual ~experimental_custom_op_base() = default; +}; + +struct experimental_custom_op : interface_base +{ + template + experimental_custom_op(T& obj) + { + this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str()); + MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); + } + + void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); } +}; + +template > +void register_experimental_custom_op(T& obj) +{ + experimental_custom_op op{obj}; + op.register_op(); +} + +#ifndef DOXYGEN +} // namespace api +#endif +} // namespace migraphx + +#endif diff --git a/src/api/migraphx.py b/src/api/migraphx.py new file mode 100755 index 0000000000000000000000000000000000000000..98b123c037dc7bebaaa12639319d0a3bb97fc2e9 --- /dev/null +++ b/src/api/migraphx.py @@ -0,0 +1,419 @@ +import api + + +def bad_param_error(msg): + return 'MIGRAPHX_THROW(migraphx_status_bad_param, "{}")'.format(msg) + + +api.error_type = 'migraphx_status' +api.success_type = 'migraphx_status_success' +api.try_wrap = 'migraphx::try_' +api.bad_param_error = bad_param_error + + +@api.cwrap('migraphx::shape::type_t') +def shape_type_wrap(p): + if p.returns: + p.add_param('migraphx_shape_datatype_t *') + p.bad_param('${name} == nullptr', 'Null pointer') + p.write = ['*${name} = migraphx::to_shape_type(${result})'] + else: + p.add_param('migraphx_shape_datatype_t') + p.read = 'migraphx::to_shape_type(${name})' + + +@api.cwrap('migraphx::compile_options') +def compile_options_type_wrap(p): + if p.returns: + p.add_param('migraphx_compile_options *') + p.bad_param('${name} == nullptr', 'Null pointer') + p.write = ['*${name} = migraphx::to_compile_options(${result})'] + else: + p.add_param('migraphx_compile_options *') + p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})' + + +@api.cwrap('migraphx::file_options') +def file_options_type_wrap(p): + if p.returns: + p.add_param('migraphx_file_options *') + p.bad_param('${name} == nullptr', 'Null pointer') + p.write = ['*${name} = migraphx::to_file_options(${result})'] + else: + p.add_param('migraphx_file_options *') + p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})' + + +@api.cwrap('migraphx::onnx_options') +def onnx_options_type_wrap(p): + if p.returns: + p.add_param('migraphx_onnx_options *') + p.bad_param('${name} == nullptr', 'Null pointer') + p.write = ['*${name} = migraphx::to_onnx_options(${result})'] + else: + p.add_param('migraphx_onnx_options *') + p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})' + + +@api.cwrap('migraphx::tf_options') +def tf_options_type_wrap(p): + if p.returns: + p.add_param('migraphx_tf_options *') + p.bad_param('${name} == nullptr', 'Null pointer') + p.write = ['*${name} = migraphx::to_tf_options(${result})'] + else: + p.add_param('migraphx_tf_options *') + p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})' + + +def auto_handle(*args, **kwargs): + def with_handle(f): + return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__, + *args, **kwargs)(f) + + return with_handle + + +@auto_handle() +def shape(h): + h.constructor( + 'create', + api.params(type='migraphx::shape::type_t', + lengths='std::vector')) + h.constructor( + 'create_with_strides', + api.params(type='migraphx::shape::type_t', + lengths='std::vector', + strides='std::vector')) + h.constructor('create_scalar', api.params(type='migraphx::shape::type_t')) + h.method('lengths', + fname='lens', + returns='const std::vector&', + const=True) + h.method('strides', returns='const std::vector&', const=True) + h.method('type', returns='migraphx::shape::type_t', const=True) + h.method('bytes', returns='size_t', const=True) + h.method('equal', + api.params(x='const migraphx::shape&'), + invoke='migraphx::equal($@)', + returns='bool', + const=True) + + +@auto_handle() +def argument(h): + h.constructor('create', + api.params(shape='const migraphx::shape&', buffer='void*')) + h.method('shape', + fname='get_shape', + cpp_name='get_shape', + returns='const migraphx::shape&', + const=True) + h.method('buffer', + fname='data', + cpp_name='data', + returns='char*', + const=True) + h.method('equal', + api.params(x='const migraphx::argument&'), + invoke='migraphx::equal($@)', + returns='bool', + const=True) + + +api.add_function('migraphx_argument_generate', + api.params(s='const migraphx::shape&', seed='size_t'), + fname='migraphx::generate_argument', + returns='migraphx::argument') + + +@auto_handle() +def target(h): + h.constructor('create', + api.params(name='const char*'), + fname='migraphx::get_target') + + +@api.handle('migraphx_program_parameter_shapes', + 'std::unordered_map') +def program_parameter_shapes(h): + h.method('size', returns='size_t') + h.method('get', + api.params(name='const char*'), + fname='at', + cpp_name='operator[]', + returns='const migraphx::shape&') + h.method('names', + invoke='migraphx::get_names(${program_parameter_shapes})', + returns='std::vector') + + +@api.handle('migraphx_program_parameters', + 'std::unordered_map') +def program_parameters(h): + h.constructor('create') + h.method('add', + api.params(name='const char*', + argument='const migraphx::argument&'), + invoke='${program_parameters}[${name}] = ${argument}') + + +@api.handle('migraphx_arguments', 'std::vector') +def arguments(h): + h.method('size', returns='size_t') + h.method('get', + api.params(idx='size_t'), + fname='at', + cpp_name='operator[]', + returns='const migraphx::argument&') + + +@api.handle('migraphx_shapes', 'std::vector') +def shapes(h): + h.method('size', returns='size_t') + h.method('get', + api.params(idx='size_t'), + fname='at', + cpp_name='operator[]', + returns='const migraphx::shape&') + + +@api.handle('migraphx_instruction', 'migraphx::instruction_ref') +def instruction(h): + pass + + +@api.handle('migraphx_instructions', 'std::vector') +def instructions(h): + h.constructor( + 'create', + api.params(ptr='const_migraphx_instruction_t*', size='size_t'), + fname='migraphx::to_obj_vector') + + +@api.handle('migraphx_modules', 'std::vector') +def modules(h): + h.constructor('create', + api.params(ptr='migraphx_module_t*', size='size_t'), + fname='migraphx::to_objptr_vector') + + +@auto_handle(ref=True) +def module(h): + h.constructor('create', api.params(name='std::string')) + h.method('print', invoke='migraphx::print_module($@)', const=True) + h.method('add_instruction', + api.params(op='migraphx::operation', + args='std::vector'), + returns='migraphx::instruction_ref') + h.method('add_instruction_with_mod_args', + api.params(op='migraphx::operation', + args='std::vector', + module_refs='std::vector'), + fname='add_instruction', + returns='migraphx::instruction_ref') + h.method('add_literal', + api.params(shape='const migraphx::shape&', buffer='const char*'), + returns='migraphx::instruction_ref') + h.method('add_parameter', + api.params(name='const char*', shape='const migraphx::shape&'), + returns='migraphx::instruction_ref') + h.method('add_return', + api.params(args='std::vector'), + returns='migraphx::instruction_ref') + + +@auto_handle() +def program(h): + h.constructor('create') + h.method('get_main_module', returns='migraphx::module*') + h.method('create_module', + api.params(name='const char*'), + returns='migraphx::module*') + h.method( + 'compile', + api.params(target='migraphx::target', + options='migraphx::compile_options')) + h.method('get_parameter_shapes', + returns='std::unordered_map') + h.method('get_output_shapes', + invoke='migraphx::get_output_shapes($@)', + returns='std::vector') + h.method('print', invoke='migraphx::print_program($@)', const=True) + h.method('sort') + h.method('run', + api.params( + params='std::unordered_map'), + invoke='migraphx::run($@)', + returns='std::vector') + h.method('equal', + api.params(x='const migraphx::program&'), + invoke='migraphx::equal($@)', + returns='bool', + const=True) + h.method('experimental_get_context', + invoke='migraphx::get_context($@)', + const=True, + returns='migraphx::context') + + +@auto_handle() +def operation(h): + h.constructor('create', + api.params(name='const char*', + attributes='const char*', + vlist='...'), + fname='migraphx::create_op') + h.method('name', returns='std::string') + + +api.add_function('migraphx_load', + api.params(name='const char*', + options='migraphx::file_options'), + fname='migraphx::load', + returns='migraphx::program') + +api.add_function('migraphx_save', + api.params(p='migraphx::program&', + name='const char*', + options='migraphx::file_options'), + fname='migraphx::save') + + +@auto_handle() +def onnx_options(h): + h.constructor('create') + h.method( + 'set_input_parameter_shape', + api.params(name='const char*', dims='std::vector'), + invoke='migraphx::set_input_parameter_shape($@)', + ) + h.method( + 'set_default_dim_value', + api.params(value='size_t'), + invoke='migraphx::set_default_dim_value($@)', + ) + h.method( + 'set_default_loop_iterations', + api.params(value='int64_t'), + invoke='migraphx::set_default_loop_iterations($@)', + ) + + +@auto_handle() +def file_options(h): + h.constructor('create') + h.method('set_file_format', + api.params(format='const char*'), + invoke='migraphx::set_file_format($@)') + + +@auto_handle() +def compile_options(h): + h.constructor('create') + h.method('set_offload_copy', + api.params(value='bool'), + invoke='migraphx::set_offload_copy($@)') + h.method('set_fast_math', + api.params(value='bool'), + invoke='migraphx::set_fast_math($@)') + + +api.add_function('migraphx_parse_onnx', + api.params(name='const char*', + options='migraphx::onnx_options'), + fname='migraphx::parse_onnx', + returns='migraphx::program') + +api.add_function('migraphx_parse_onnx_buffer', + api.params(data='const void*', + size='size_t', + options='migraphx::onnx_options'), + fname='migraphx::parse_onnx_buffer', + returns='migraphx::program') + + +@auto_handle() +def tf_options(h): + h.constructor('create') + h.method( + 'set_nhwc', + api.params(is_nhwc='bool'), + invoke='migraphx::set_nhwc($@)', + ) + h.method( + 'set_input_parameter_shape', + api.params(name='const char*', dims='std::vector'), + invoke='migraphx::set_input_parameter_shape($@)', + ) + h.method( + 'set_default_dim_value', + api.params(value='size_t'), + invoke='migraphx::set_default_dim_value($@)', + ) + h.method( + 'set_output_names', + api.params(names='std::vector'), + invoke='migraphx::set_output_names($@)', + ) + + +api.add_function('migraphx_parse_tf', + api.params(name='const char*', + options='migraphx::tf_options'), + fname='migraphx::parse_tf', + returns='migraphx::program') + + +@api.handle('migraphx_quantize_op_names', 'std::vector') +def quantize_op_names(h): + h.constructor('create') + h.method('add', api.params(name='const char*'), fname='push_back') + + +api.add_function('migraphx_quantize_fp16_with_op_names', + api.params(prog='migraphx::program&', + name='std::vector&'), + fname='migraphx::quantize_fp16_with_op_names') + +api.add_function('migraphx_quantize_fp16', + api.params(prog='migraphx::program&'), + fname='migraphx::quantize_fp16') + + +@auto_handle() +def quantize_int8_options(h): + h.constructor('create') + h.method( + 'add_op_name', + api.params(name='const char*'), + invoke='migraphx::add_op_name($@)', + ) + h.method( + 'add_calibration_data', + api.params(data='std::unordered_map'), + invoke='migraphx::add_calibration_data($@)', + ) + + +api.add_function('migraphx_quantize_int8', + api.params(prog='migraphx::program&', + target='migraphx::target', + options='migraphx::quantize_int8_options'), + fname='migraphx::quantize_int8_wrap') + + +@auto_handle(ref=True) +def context(h): + h.method('finish', const=True) + h.method('get_queue', returns='void*', fname='get_queue().unsafe_get') + + +@api.interface('migraphx_experimental_custom_op', + 'migraphx::experimental_custom_op') +def experimental_custom_op(h): + h.constructor('create', api.params(name='const char*')) + h.virtual('compute_shape', + api.params(inputs='std::vector'), + returns='migraphx::shape') + h.method('register', invoke='migraphx::register_custom_op($@)') diff --git a/src/apply_alpha_beta.cpp b/src/apply_alpha_beta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..933a915bca851e1f2e2e102c17a40f6b4159bd3d --- /dev/null +++ b/src/apply_alpha_beta.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +instruction_ref insert_apply_alpha_beta(module& m, + instruction_ref pos, + const std::vector& args, + const operation& op, + const literal& alpha, + const literal& beta) +{ + auto a = args[0]; + auto b = args[1]; + auto input_type = a->get_shape().type(); + if(!float_equal(alpha.at(0), 1.0)) + { + auto alpha_literal = m.add_literal(alpha); + a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a}); + if(a->get_shape().type() != input_type) + { + a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a); + } + } + auto op_res = m.insert_instruction(pos, op, a, b); + if(args.size() == 3) + { + if(not float_equal(beta.at(0), 0.0) && args[2]->get_shape().elements() > 0) + { + auto out_lens = op_res->get_shape().lens(); + auto c = args[2]; + auto c_lens = c->get_shape().lens(); + input_type = c->get_shape().type(); + if(out_lens != c_lens) + { + c = m.insert_instruction( + pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]); + } + auto beta_literal = m.add_literal(beta); + auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal}); + if(beta_c->get_shape().type() != input_type) + { + beta_c = m.insert_instruction( + pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c); + } + return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c); + } + } + return op_res; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/argument.cpp b/src/argument.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a01211d7470f2642c4e8d88ed3bc091e68dab97 --- /dev/null +++ b/src/argument.cpp @@ -0,0 +1,169 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +argument::argument(const shape& s) : m_shape(s) +{ + auto buffer = make_shared_array(s.bytes()); + assign_buffer({[=]() mutable { return buffer.get(); }}); +} + +argument::argument(shape s, std::nullptr_t) + : m_shape(std::move(s)), m_data({[] { return nullptr; }}) +{ +} + +argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {} + +void argument::assign_buffer(std::function d) +{ + const shape& s = m_shape; + if(s.type() != shape::tuple_type) + { + m_data = {std::move(d)}; + return; + } + // Collect all shapes + std::unordered_map shapes; + { + std::size_t i = 0; + fix([&](auto self, auto ss) { + if(ss.sub_shapes().empty()) + { + shapes[i] = ss; + i++; + } + else + { + for(auto&& child : ss.sub_shapes()) + self(child); + } + })(s); + } + // Sort by type size + std::vector order(shapes.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) { + return shapes[i].type_size(); + })); + // Compute offsets + std::unordered_map offsets; + std::size_t offset = 0; + for(auto i : order) + { + offsets[i] = offset; + offset += shapes[i].bytes(); + } + assert(offset == s.bytes()); + + std::size_t i = 0; + m_data = fix([&](auto self, auto ss) { + data_t result; + if(ss.sub_shapes().empty()) + { + auto n = offsets[i]; + result = {[d, n]() mutable { return d() + n; }}; + i++; + return result; + } + std::vector subs; + std::transform(ss.sub_shapes().begin(), + ss.sub_shapes().end(), + std::back_inserter(subs), + [&](auto child) { return self(child); }); + result.sub = subs; + return result; + })(s); +} + +std::vector to_shapes(const std::vector& args) +{ + std::vector shapes; + std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) { + return arg.get_shape(); + }); + return shapes; +} + +argument::argument(const std::vector& args) + : m_shape(to_shapes(args)), m_data(data_t::from_args(args)) +{ +} + +char* argument::data() const +{ + assert(m_shape.type() != shape::tuple_type); + assert(not this->empty()); + return m_data.get(); +} + +bool argument::empty() const { return not m_data.get and m_data.sub.empty(); } + +const shape& argument::get_shape() const { return this->m_shape; } + +argument argument::reshape(const shape& s) const +{ + assert(s.element_space() <= this->get_shape().element_space()); + return {s, this->m_data}; +} + +argument::data_t argument::data_t::share() const +{ + data_t result; + if(this->get) + { + auto self = std::make_shared(*this); + result.get = [self]() mutable { return self->get(); }; + } + std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) { + return d.share(); + }); + return result; +} + +argument::data_t argument::data_t::from_args(const std::vector& args) +{ + data_t result; + std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) { + return arg.m_data; + }); + return result; +} + +argument argument::copy() const +{ + argument result{this->get_shape()}; + auto* src = this->data(); + std::copy(src, src + this->get_shape().bytes(), result.data()); + return result; +} + +argument argument::share() const { return {m_shape, m_data.share()}; } + +std::vector argument::get_sub_objects() const +{ + std::vector result; + assert(m_shape.sub_shapes().size() == m_data.sub.size()); + std::transform(m_shape.sub_shapes().begin(), + m_shape.sub_shapes().end(), + m_data.sub.begin(), + std::back_inserter(result), + [](auto&& s, auto&& d) { + return argument{s, d}; + }); + return result; +} + +argument argument::element(std::size_t i) const +{ + assert(this->get_shape().sub_shapes().empty()); + auto idx = this->get_shape().index(i); + auto offset = this->get_shape().type_size() * idx; + return argument{shape{this->get_shape().type()}, this->data() + offset}; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/auto_contiguous.cpp b/src/auto_contiguous.cpp index a77dde75376ca194c1b3444863a3623e22c50a42..8816b9a58093fa14085cf504b1129fb52c3b1b6f 100644 --- a/src/auto_contiguous.cpp +++ b/src/auto_contiguous.cpp @@ -1,21 +1,49 @@ #include #include #include -#include +#include + #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void auto_contiguous::apply(program& p) const +void auto_contiguous::apply(module& m) const { - for(auto ins : iterator_for(p)) + std::string key = "require_std_shape"; + for(auto ins : reverse_iterator_for(m)) + { + auto&& attr = ins->get_operator().attributes(); + if((attr.get(key, false))) + { + auto args = ins->inputs(); + auto new_args = args; + std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) { + if(in->name() == "contiguous") + { + return in; + } + return m.insert_instruction(ins, make_op("contiguous"), in); + }); + + if(new_args != args) + { + m.replace_instruction(ins, ins->get_operator(), new_args); + } + } + } + + auto last = std::prev(m.end()); + for(auto ins : iterator_for(m)) { + // for last instruction that is NOT a return + if(ins->outputs().empty() and ins != last) + continue; shape s = ins->get_shape(); if(not s.standard() and s.elements() != 0) { - auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins); - p.replace_instruction(ins, c); + auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); + m.replace_instruction(ins, c); } } } diff --git a/src/common.cpp b/src/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac9599326e854e9b2861563206b5048560598476 --- /dev/null +++ b/src/common.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// Example: +// s0 = (3,2,4,5) and s1 = (2,1,1) +// +// In this case we need to broadcast (:,1,1) portion of +// s1 plus broadcast the 1st dimension of s1 +// giving output_lens = (3,2,4,5) +// +// Another example: +// s0 = (3,2,1,5) and s1 = (2,7,5) +// In this case we need to broadcast the (:,:,1:,:) axis +// of s0 plus the 1st dimension of s1 giving +// output_lens = (3,2,7,5) +std::vector compute_broadcasted_lens(std::vector s0, + std::vector s1) +{ + if(s0 == s1) + return s0; + if(s0.size() > s1.size()) + s0.swap(s1); + + std::vector out_lens(s1); + auto offset = s1.size() - s0.size(); + std::transform( + s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) { + if(a != b and a != 1 and b != 1) + { + MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" + + to_string_range(s1) + "} mismatch!"); + } + return std::max(a, b); + }); + + return out_lens; +} + +std::vector compute_common_lens(const std::vector& shapes) +{ + assert(not shapes.empty()); + return transform_accumulate(shapes.begin() + 1, + shapes.end(), + shapes.front().lens(), + &compute_broadcasted_lens, + [](auto s) { return s.lens(); }); +} + +shape::type_t compute_common_type(shape::type_t t1, shape::type_t t2) +{ + if(t1 == t2) + return t1; + shape::type_t result; + shape::visit(t1, [&](auto x) { + shape::visit(t2, [&](auto y) { + // Workaround broken warning on gcc 5 + (void)x; + (void)y; + using type = std::common_type_t; + result = shape::get_type{}; + }); + }); + return result; +} + +shape::type_t compute_common_types(const std::vector& shapes) +{ + assert(not shapes.empty()); + return transform_accumulate( + shapes.begin() + 1, shapes.end(), shapes.front().type(), &compute_common_type, [&](auto s) { + return s.type(); + }); +} + +shape common_shape(const std::vector& shapes) +{ + if(shapes.empty()) + return {}; + return {compute_common_types(shapes), compute_common_lens(shapes)}; +} + +instruction_ref insert_common_op(module& m, + instruction_ref ins, + const operation& op, + std::vector inputs) +{ + auto common = common_shape(to_shapes(inputs)); + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { + if(input->get_shape().lens() != common.lens()) + { + input = m.insert_instruction( + ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input); + } + if(input->get_shape().type() != common.type()) + { + input = m.insert_instruction( + ins, make_op("convert", {{"target_type", common.type()}}), input); + } + return input; + }); + return m.insert_instruction(ins, op, inputs); +} + +instruction_ref add_common_op(module& m, const operation& op, std::vector inputs) +{ + return insert_common_op(m, m.end(), op, std::move(inputs)); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/compile_src.cpp b/src/compile_src.cpp new file mode 100755 index 0000000000000000000000000000000000000000..685fe48f274d6e380e55e25c8a4e37f245f6e936 --- /dev/null +++ b/src/compile_src.cpp @@ -0,0 +1,57 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::vector src_compiler::compile(const std::vector& srcs) const +{ + assert(not srcs.empty()); + tmp_dir td{"compile"}; + auto params = flags; + + params += " -I."; + + auto out = output; + + for(const auto& src : srcs) + { + fs::path full_path = td.path / src.path; + fs::path parent_path = full_path.parent_path(); + fs::create_directories(parent_path); + write_buffer(full_path.string(), src.content.first, src.len()); + if(src.path.extension().string() == ".cpp") + { + params += " " + src.path.filename().string(); + if(out.empty()) + out = src.path.stem().string() + out_ext; + } + } + + params += " -o " + out; + + if(not launcher.empty()) + { + td.execute(launcher, compiler + " " + params); + } + else + { + td.execute(compiler, params); + } + + auto out_path = td.path / out; + if(not fs::exists(out_path)) + MIGRAPHX_THROW("Output file missing: " + out); + + if(process) + out_path = process(out_path); + + return read_buffer(out_path.string()); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/convert_to_json.cpp b/src/convert_to_json.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b1b975a0d648113161bda3ef9ff38d415dad744 --- /dev/null +++ b/src/convert_to_json.cpp @@ -0,0 +1,131 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using token = std::pair; +using lexer = std::function; + +template +auto lex_while(P p) +{ + return [=](const char* start, const char* end) { + return std::find_if(start, end, [&](char c) { return not p(c); }); + }; +} + +template +auto lex_if(P p) +{ + return [=](const char* start, const char*) { + if(p(*start)) + return start + 1; + return start; + }; +} + +std::vector tokenize(const char* start, const char* end, const std::vector& lexers) +{ + std::vector result; + while(start != end) + { + bool error = true; + for(const auto& l : lexers) + { + const auto* next = l(start, end); + if(next != start) + { + result.emplace_back(start, next); + start = next; + error = false; + break; + } + } + + if(error) + { + MIGRAPHX_THROW("TOKENIZE: no token found!"); + } + } + + return result; +} + +std::vector json_tokenize(const std::string& s) +{ + std::vector lexers; + + // Quote + lexers.push_back([](const char* start, const char* end) { + if(*start != '\"') + return start; + ++start; + while((start != end) and (*start != '\"')) + { + if(*start == '\\') + start++; + start++; + } + + return ++start; + }); + + // Line comments + lexers.push_back([](const char* start, const char* end) { + if(*start == '#') + start++; + else if((start + 1) < end and start[0] == '/' and start[1] == '/') + start += 2; + else + return start; + return std::find_if(start, end, [&](char c) { return c == '\n'; }); + }); + + // Whitespace + lexers.push_back(lex_while(&isspace)); + + // Punctation + lexers.push_back(lex_if(&ispunct)); + + // Identifier/number + lexers.push_back(lex_while([](char c) { + return (isalnum(c) != 0 or contains({'_', '.', '+'}, c)); + })); + + return tokenize(s.data(), s.data() + s.length(), lexers); +} + +std::string convert_to_json(const std::string& str) +{ + auto tokens = json_tokenize(str); + std::stringstream ss; + + for(auto& token : tokens) + { + std::string s(token.first, token.second); + if(starts_with(s, "#") or starts_with(s, "//")) + continue; + if(std::isalpha(s.front()) != 0 and + not contains({"null", "nan", "true", "false", "inf"}, s)) + { + ss << "\"" << s << "\""; + } + else + { + ss << s; + } + } + + return ss.str(); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/cpp_generator.cpp b/src/cpp_generator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55b4fa00fb260efdac9d51d5b26a11d6eb8aeed4 --- /dev/null +++ b/src/cpp_generator.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +cpp_generator::function& +cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g) +{ + std::unordered_map names; + std::stringstream ss; + + auto return_ins = std::prev(m.end()); + + for(auto ins : iterator_for(m)) + { + ss << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n"; + if(ins->name() == "@param") + { + names[ins] = + migraphx::any_cast(ins->get_operator()).parameter; + } + else if(ins->name() == "@return") + { + assert(ins->inputs().size() == 1); + return_ins = ins->inputs().front(); + } + else + { + std::string n = "z" + std::to_string(names.size()); + names[ins] = n; + ss << "auto " << n << " = " << g(ins, names) << ";\n"; + } + } + ss << "return " << names.at(return_ins) << ";\n"; + body = ss.str(); + return *this; +} + +cpp_generator::function& cpp_generator::function::set_types(const module& m) +{ + return cpp_generator::function::set_types(m, [](auto s) { return shape::cpp_type(s.type()); }); +} +cpp_generator::function& +cpp_generator::function::set_types(const module& m, const std::function& parse) +{ + this->params.clear(); + auto pmap = m.get_parameter_shapes(); + std::map input_map(pmap.begin(), pmap.end()); + std::transform( + input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) { + return param{p.first, parse(p.second)}; + }); + auto output_shapes = m.get_output_shapes(); + assert(not output_shapes.empty()); + this->return_type = parse(output_shapes.front()); + return *this; +} + +cpp_generator::function& cpp_generator::function::set_generic_types(const module& m) +{ + this->params.clear(); + auto pmap = m.get_parameter_shapes(); + std::map input_map(pmap.begin(), pmap.end()); + std::transform( + input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) { + return param{p.first, "T" + p.first}; + }); + + std::transform(input_map.begin(), + input_map.end(), + std::back_inserter(this->tparams), + [&](auto&& p) { return "class T" + p.first; }); + this->return_type = "auto"; + return *this; +} + +struct cpp_generator_impl +{ + std::stringstream fs{}; + std::size_t function_count = 0; + std::function fmap = nullptr; + std::function fresult = nullptr; + std::unordered_map point_op_map = {}; +}; +cpp_generator::cpp_generator() : impl(std::make_unique()) {} + +cpp_generator::cpp_generator(cpp_generator&&) noexcept = default; + +cpp_generator& cpp_generator::operator=(cpp_generator rhs) +{ + std::swap(impl, rhs.impl); + return *this; +} + +cpp_generator::~cpp_generator() noexcept = default; + +void cpp_generator::fmap(const std::function& f) { impl->fmap = f; } + +void cpp_generator::fresult(const std::function& f) { impl->fresult = f; } + +void cpp_generator::add_point_op(const std::string& op_name, const std::string& code) +{ + impl->point_op_map[op_name] = code; +} + +std::string cpp_generator::generate_point_op(const operation& op, + const std::vector& args) +{ + auto v = op.to_value(); + std::string code; + if(contains(impl->point_op_map, op.name())) + { + code = impl->point_op_map.at(op.name()); + } + else + { + auto attributes = op.attributes(); + if(not attributes.contains("point_op")) + MIGRAPHX_THROW("op is missing point_op attribute: " + op.name()); + code = attributes["point_op"].to(); + } + return interpolate_string(code, [&](auto start, auto last) -> std::string { + auto key = trim({start, last}); + if(key.empty()) + MIGRAPHX_THROW("Empty parameter"); + std::string fselector = "function:"; + if(starts_with(key, fselector)) + { + auto fname = key.substr(fselector.size()); + if(impl->fmap == nullptr) + return fname; + else + return impl->fmap(fname); + } + else if(with_char(::isdigit)(key[0])) + { + auto i = std::stoul(key); + return args.at(i); + } + else if(v.contains(key)) + { + return v[key].template to(); + } + else + { + return key; + } + }); +} + +std::string cpp_generator::str() const { return impl->fs.str(); } + +cpp_generator::function cpp_generator::generate_module(const module& m) +{ + function f; + auto name = transform_string(m.name(), [](char c) { + if(with_char(::isalnum)(c) or c == '_') + return c; + return '_'; + }); + f.set_name(name).set_types(m).set_body( + m, [&](instruction_ref ins, const auto& names) -> std::string { + if(ins->name() == "@literal") + return shape::cpp_type(ins->get_shape().type()) + "(" + + ins->get_literal().to_string() + ")"; + std::vector args; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(args), + [&](auto i) { return names.at(i); }); + + auto s = this->generate_point_op(ins->get_operator(), args); + if(impl->fresult) + return impl->fresult(ins->get_shape()) + '(' + s + ')'; + else + return s; + }); + return f; +} + +std::string cpp_generator::create_function(const cpp_generator::function& f) +{ + impl->function_count++; + if(not f.tparams.empty()) + impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n"; + std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name; + impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name; + char delim = '('; + for(auto&& p : f.params) + { + impl->fs << delim << p.type << " " << p.name; + delim = ','; + } + impl->fs << ") {\n" << f.body << "\n}\n"; + return name; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/dead_code_elimination.cpp b/src/dead_code_elimination.cpp index e9fa5c693520d854dab3e3c09338094d248428e5..07d15e506112569f3e7588e758d239fbb64aa0cb 100644 --- a/src/dead_code_elimination.cpp +++ b/src/dead_code_elimination.cpp @@ -4,65 +4,56 @@ #include #include #include +#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -template -std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last) -{ - auto start_forward = start; - auto start_backwards = start; - std::size_t n = 0; - while(start_forward != last and start_backwards != last) - { - n++; - if(start_forward != r.end()) - start_forward++; - if(start_backwards != r.begin()) - start_backwards--; - } - if(start_forward == last) - return n; - else - return -n; -} +void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); } -void dead_code_elimination::apply(program& p) const +void dead_code_elimination::apply(module& m) const { - auto last = std::prev(p.end()); - for(auto ins : iterator_for(p)) + auto last = std::prev(m.end()); + for(auto ins : iterator_for(m)) { // Skip the first instruction, since we always process the previous // instruction - if(ins == p.begin()) + if(ins == m.begin()) continue; const auto i = std::prev(ins); // Skip the last instruction if(i == last) break; - // Skip instruction with empty shape as output unless its a builtin or undefined or identity + // Skip instruction with empty shape as output unless its a builtin, undefined, identity, or + // allocate if(i->get_shape().elements() == 0 and i->name().front() != '@' and - i->name() != "undefined" and i->name() != "identity") + not contains({"undefined", "identity", "allocate"}, i->name())) continue; - assert(bidistance(p, i, last) > 0); + assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); + std::unordered_set visited; fix([&](auto self, auto leaf) { - assert(p.has_instruction(leaf)); + if(not m.has_instruction(leaf)) + return; + if(leaf->outputs().empty()) { + // Dont visit inputs twice + if(not visited.insert(leaf).second) + return; std::unordered_set args(leaf->inputs().begin(), leaf->inputs().end()); leaf->clear_arguments(); - assert(bidistance(p, last, leaf) < 0); + assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last)); assert(leaf != ins); - p.move_instruction(leaf, p.end()); + if(leaf->name() != "@param") + m.move_instruction(leaf, m.end()); for(auto arg : args) self(arg); } })(i); } - p.remove_instructions(std::next(last), p.end()); + m.remove_instructions(std::next(last), m.end()); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/dom_info.cpp b/src/dom_info.cpp new file mode 100755 index 0000000000000000000000000000000000000000..80ac5e9830c45bd216aaf38abbc09671fb517a11 --- /dev/null +++ b/src/dom_info.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) +{ + if(ins1 == ins2) + return false; + auto iter = ins2idom.find(ins2); + while(iter != ins2idom.end()) + { + if(ins1 == iter->second) + return true; + assert(iter != ins2idom.find(iter->second)); + iter = ins2idom.find(iter->second); + } + return false; +} + +struct module_visitor +{ + module* mm; + module& get_nodes() const { return *mm; } + + const std::vector& get_children(instruction_ref ins) { return ins->inputs(); } +}; + +template +dominator_info compute_dominator_generic(Visitor v) +{ + dominator_info info; + std::unordered_map> instr2_doms; + for(instruction_ref ins : iterator_for(v.get_nodes())) + { + const std::vector& children = v.get_children(ins); + if(children.size() == 1) + { + info.ins2idom[ins] = children.front(); + instr2_doms[ins].insert(children.front()); + } + else if(children.size() > 1) + { + auto&& doms = instr2_doms[ins]; + + doms = instr2_doms[children.front()]; + std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) { + auto&& child_doms = instr2_doms[child]; + erase_if(doms, [&](auto x) { return not contains(child_doms, x); }); + }); + auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) { + return std::none_of(doms.begin(), doms.end(), [&](auto dom2) { + if(dom1 == dom2) + return false; + return info.strictly_dominate(dom1, dom2); + }); + }); + if(iter != doms.end()) + info.ins2idom[ins] = *iter; + } + instr2_doms[ins].insert(ins); + } + return info; +} + +dominator_info compute_dominator(module& m) +{ + return compute_dominator_generic(module_visitor{&m}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/driver/CMakeLists.txt b/src/driver/CMakeLists.txt old mode 100644 new mode 100755 index c199e87750e8705bf95fd802cc0a8cd739916205..a569d89774711e9d468a3cb22916873a2d6db2ab --- a/src/driver/CMakeLists.txt +++ b/src/driver/CMakeLists.txt @@ -1,8 +1,26 @@ -add_executable(driver main.cpp verify.cpp perf.cpp) +add_executable(driver + main.cpp + verify.cpp + perf.cpp + resnet50.cpp + inceptionv3.cpp + alexnet.cpp + marker_roctx.cpp +) +set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) +# Copy driver for backwards compatibility +add_custom_command( + TARGET driver + POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy + $ + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver + BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver +) +set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) rocm_clang_tidy_check(driver) -target_link_libraries(driver migraphx_cpu migraphx_onnx migraphx_tf) -if(MIGRAPHX_ENABLE_GPU) -target_link_libraries(driver migraphx_gpu) -target_compile_definitions(driver PRIVATE -DHAVE_GPU) -endif() +target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) + +rocm_install_targets( + TARGETS driver +) diff --git a/src/driver/alexnet.cpp b/src/driver/alexnet.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1e7a7ed5d61c3d01fe69fea3164d00bc38d0cdc --- /dev/null +++ b/src/driver/alexnet.cpp @@ -0,0 +1,183 @@ +#include +#include +#include +#include +#include "models.hpp" + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto m0 = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); + auto mx0 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); + auto mx1 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1)); + auto mx2 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2)); + auto mx3 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); + auto mx4 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); + auto mx5 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); + auto mx6 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6)); + auto mx7 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7)); + auto mx8 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8)); + auto mx9 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); + auto mx10 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10)); + auto mx11 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); + auto mx12 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12)); + auto mx13 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); + auto mx14 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14)); + auto mx15 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15)); + migraphx::op::convolution convolution16; + convolution16.padding = {2, 2}; + convolution16.stride = {4, 4}; + convolution16.dilation = {1, 1}; + convolution16.group = 1; + auto mx16 = mm->add_instruction(convolution16, m0, mx15); + migraphx::op::broadcast broadcast17; + broadcast17.axis = 1; + broadcast17.broadcast_lens = {batch, 64, 55, 55}; + auto mx17 = mm->add_instruction(broadcast17, mx14); + migraphx::op::add add18; + auto mx18 = mm->add_instruction(add18, mx16, mx17); + migraphx::op::relu relu19; + auto mx19 = mm->add_instruction(relu19, mx18); + migraphx::op::pooling pooling20; + pooling20.mode = migraphx::op::pooling_mode::max; + pooling20.padding = {0, 0}; + pooling20.stride = {2, 2}; + pooling20.lengths = {3, 3}; + auto mx20 = mm->add_instruction(pooling20, mx19); + migraphx::op::convolution convolution21; + convolution21.padding = {2, 2}; + convolution21.stride = {1, 1}; + convolution21.dilation = {1, 1}; + convolution21.group = 1; + auto mx21 = mm->add_instruction(convolution21, mx20, mx13); + migraphx::op::broadcast broadcast22; + broadcast22.axis = 1; + broadcast22.broadcast_lens = {batch, 192, 27, 27}; + auto mx22 = mm->add_instruction(broadcast22, mx12); + migraphx::op::add add23; + auto mx23 = mm->add_instruction(add23, mx21, mx22); + migraphx::op::relu relu24; + auto mx24 = mm->add_instruction(relu24, mx23); + migraphx::op::pooling pooling25; + pooling25.mode = migraphx::op::pooling_mode::max; + pooling25.padding = {0, 0}; + pooling25.stride = {2, 2}; + pooling25.lengths = {3, 3}; + auto mx25 = mm->add_instruction(pooling25, mx24); + migraphx::op::convolution convolution26; + convolution26.padding = {1, 1}; + convolution26.stride = {1, 1}; + convolution26.dilation = {1, 1}; + convolution26.group = 1; + auto mx26 = mm->add_instruction(convolution26, mx25, mx11); + migraphx::op::broadcast broadcast27; + broadcast27.axis = 1; + broadcast27.broadcast_lens = {batch, 384, 13, 13}; + auto mx27 = mm->add_instruction(broadcast27, mx10); + migraphx::op::add add28; + auto mx28 = mm->add_instruction(add28, mx26, mx27); + migraphx::op::relu relu29; + auto mx29 = mm->add_instruction(relu29, mx28); + migraphx::op::convolution convolution30; + convolution30.padding = {1, 1}; + convolution30.stride = {1, 1}; + convolution30.dilation = {1, 1}; + convolution30.group = 1; + auto mx30 = mm->add_instruction(convolution30, mx29, mx9); + migraphx::op::broadcast broadcast31; + broadcast31.axis = 1; + broadcast31.broadcast_lens = {batch, 256, 13, 13}; + auto mx31 = mm->add_instruction(broadcast31, mx8); + migraphx::op::add add32; + auto mx32 = mm->add_instruction(add32, mx30, mx31); + migraphx::op::relu relu33; + auto mx33 = mm->add_instruction(relu33, mx32); + migraphx::op::convolution convolution34; + convolution34.padding = {1, 1}; + convolution34.stride = {1, 1}; + convolution34.dilation = {1, 1}; + convolution34.group = 1; + auto mx34 = mm->add_instruction(convolution34, mx33, mx7); + migraphx::op::broadcast broadcast35; + broadcast35.axis = 1; + broadcast35.broadcast_lens = {batch, 256, 13, 13}; + auto mx35 = mm->add_instruction(broadcast35, mx6); + migraphx::op::add add36; + auto mx36 = mm->add_instruction(add36, mx34, mx35); + migraphx::op::relu relu37; + auto mx37 = mm->add_instruction(relu37, mx36); + migraphx::op::pooling pooling38; + pooling38.mode = migraphx::op::pooling_mode::max; + pooling38.padding = {0, 0}; + pooling38.stride = {2, 2}; + pooling38.lengths = {3, 3}; + auto mx38 = mm->add_instruction(pooling38, mx37); + migraphx::op::flatten flatten39; + flatten39.axis = 1; + auto mx39 = mm->add_instruction(flatten39, mx38); + migraphx::op::identity identity40; + auto mx40 = mm->add_instruction(identity40, mx39); + migraphx::op::transpose transpose41; + transpose41.dims = {1, 0}; + auto mx41 = mm->add_instruction(transpose41, mx5); + migraphx::op::multibroadcast multibroadcast42; + multibroadcast42.output_lens = {batch, 4096}; + auto mx42 = mm->add_instruction(multibroadcast42, mx4); + float dot43_alpha = 1; + float dot43_beta = 1; + auto mx43 = migraphx::add_apply_alpha_beta( + *mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta); + migraphx::op::relu relu44; + auto mx44 = mm->add_instruction(relu44, mx43); + migraphx::op::identity identity45; + auto mx45 = mm->add_instruction(identity45, mx44); + migraphx::op::transpose transpose46; + transpose46.dims = {1, 0}; + auto mx46 = mm->add_instruction(transpose46, mx3); + migraphx::op::multibroadcast multibroadcast47; + multibroadcast47.output_lens = {batch, 4096}; + auto mx47 = mm->add_instruction(multibroadcast47, mx2); + float dot48_alpha = 1; + float dot48_beta = 1; + auto mx48 = migraphx::add_apply_alpha_beta( + *mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta); + migraphx::op::relu relu49; + auto mx49 = mm->add_instruction(relu49, mx48); + migraphx::op::transpose transpose50; + transpose50.dims = {1, 0}; + auto mx50 = mm->add_instruction(transpose50, mx1); + migraphx::op::multibroadcast multibroadcast51; + multibroadcast51.output_lens = {batch, 1000}; + auto mx51 = mm->add_instruction(multibroadcast51, mx0); + float dot52_alpha = 1; + float dot52_beta = 1; + migraphx::add_apply_alpha_beta( + *mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta); + return p; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/argument_parser.hpp b/src/driver/argument_parser.hpp index d75e7daa4a11690d84cace5e1c42c568b81ca203..0c52a39db5a711440deeb043fc5c696142f4e483 100644 --- a/src/driver/argument_parser.hpp +++ b/src/driver/argument_parser.hpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace migraphx { namespace driver { @@ -132,10 +133,22 @@ struct argument_parser return to_string_range(x); } + template + auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x)) + { + return to_string(x); + } + + template + std::string as_string_value(rank<0>, const T&) + { + throw std::runtime_error("Can't convert to string"); + } + template {})> std::string as_string_value(const T& x) { - return to_string(x); + return as_string_value(rank<1>{}, x); } template @@ -148,10 +161,11 @@ struct argument_parser return false; }}); - argument& arg = arguments.back(); - arg.type = type_name::apply(); - arg.default_value = as_string_value(x); + argument& arg = arguments.back(); + arg.type = type_name::apply(); migraphx::each_args([&](auto f) { f(x, arg); }, fs...); + if(not arg.default_value.empty() and arg.nargs > 0) + arg.default_value = as_string_value(x); } template @@ -247,6 +261,11 @@ struct argument_parser return [=](auto&, auto& arg) { arg.metavar = metavar; }; } + MIGRAPHX_DRIVER_STATIC auto type(const std::string& type) + { + return [=](auto&, auto& arg) { arg.type = type; }; + } + template MIGRAPHX_DRIVER_STATIC auto set_value(T value) { diff --git a/src/driver/command.hpp b/src/driver/command.hpp index e9ee4df0e9207183a998fbb1a601079326126613..4a80137183ef52a81c902fbd77069d5a2744927d 100644 --- a/src/driver/command.hpp +++ b/src/driver/command.hpp @@ -17,6 +17,7 @@ inline namespace MIGRAPHX_INLINE_NS { inline auto& get_commands() { + // NOLINTNEXTLINE static std::unordered_map args)>> m; return m; } @@ -64,7 +65,7 @@ int auto_register_command() template struct command { - static int static_register; + static const int static_register; // This typedef ensures that the static member will be instantiated if // the class itself is instantiated using static_register_type = @@ -77,7 +78,7 @@ struct command #endif template -int command::static_register = auto_register_command(); // NOLINT +const int command::static_register = auto_register_command(); // NOLINT } // namespace MIGRAPHX_INLINE_NS } // namespace driver diff --git a/src/driver/inceptionv3.cpp b/src/driver/inceptionv3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c80bf53b8da5fbaaf174bedf5801a7408bfe564c --- /dev/null +++ b/src/driver/inceptionv3.cpp @@ -0,0 +1,2239 @@ +#include +#include +#include +#include +#include "models.hpp" + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-size) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto m0 = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 299, 299}}); + auto mx0 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); + auto mx1 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 1)); + auto mx2 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 2))); + auto mx3 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 3)); + auto mx4 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 4)); + auto mx5 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 5))); + auto mx6 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 2048, 1, 1}}, 6)); + auto mx7 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 7))); + auto mx8 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 8)); + auto mx9 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 9)); + auto mx10 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10))); + auto mx11 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 3, 1}}, 11)); + auto mx12 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 12))); + auto mx13 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13)); + auto mx14 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 14)); + auto mx15 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 15))); + auto mx16 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 1, 3}}, 16)); + auto mx17 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 17))); + auto mx18 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 18)); + auto mx19 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 19)); + auto mx20 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 20))); + auto mx21 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 448, 3, 3}}, 21)); + auto mx22 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 22))); + auto mx23 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 23)); + auto mx24 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 24)); + auto mx25 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 25))); + auto mx26 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {448, 2048, 1, 1}}, 26)); + auto mx27 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 27))); + auto mx28 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 28))); + auto mx29 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 29)); + auto mx30 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 30))); + auto mx31 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 3, 1}}, 31)); + auto mx32 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 32))); + auto mx33 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 33))); + auto mx34 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 34)); + auto mx35 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 35))); + auto mx36 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 1, 3}}, 36)); + auto mx37 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 37))); + auto mx38 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 38)); + auto mx39 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 39)); + auto mx40 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 40))); + auto mx41 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 2048, 1, 1}}, 41)); + auto mx42 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 42))); + auto mx43 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 43)); + auto mx44 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 44)); + auto mx45 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 45))); + auto mx46 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {320, 2048, 1, 1}}, 46)); + auto mx47 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 47))); + auto mx48 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 48)); + auto mx49 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 49)); + auto mx50 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 50))); + auto mx51 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 1280, 1, 1}}, 51)); + auto mx52 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 52))); + auto mx53 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 53)); + auto mx54 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 54)); + auto mx55 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 55))); + auto mx56 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 3, 1}}, 56)); + auto mx57 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 57))); + auto mx58 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 58)); + auto mx59 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 59)); + auto mx60 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 60))); + auto mx61 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 1, 3}}, 61)); + auto mx62 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 62))); + auto mx63 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 63)); + auto mx64 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 64)); + auto mx65 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 65))); + auto mx66 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 448, 3, 3}}, 66)); + auto mx67 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 67))); + auto mx68 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 68)); + auto mx69 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 69)); + auto mx70 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {448}}, 70))); + auto mx71 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {448, 1280, 1, 1}}, 71)); + auto mx72 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 72))); + auto mx73 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 73)); + auto mx74 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 74)); + auto mx75 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 75))); + auto mx76 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 3, 1}}, 76)); + auto mx77 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 77))); + auto mx78 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 78)); + auto mx79 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 79)); + auto mx80 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 80))); + auto mx81 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 384, 1, 3}}, 81)); + auto mx82 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 82))); + auto mx83 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 83)); + auto mx84 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 84)); + auto mx85 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 85))); + auto mx86 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 1280, 1, 1}}, 86)); + auto mx87 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 87))); + auto mx88 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 88)); + auto mx89 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 89)); + auto mx90 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 90))); + auto mx91 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {320, 1280, 1, 1}}, 91)); + auto mx92 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 92))); + auto mx93 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 93)); + auto mx94 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 94)); + auto mx95 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 95))); + auto mx96 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 3, 3}}, 96)); + auto mx97 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 97))); + auto mx98 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 98)); + auto mx99 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 99)); + auto mx100 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 100))); + auto mx101 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 7, 1}}, 101)); + auto mx102 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 102))); + auto mx103 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 103)); + auto mx104 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 104)); + auto mx105 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 105))); + auto mx106 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 1, 7}}, 106)); + auto mx107 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 107))); + auto mx108 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 108)); + auto mx109 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 109)); + auto mx110 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 110))); + auto mx111 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 111)); + auto mx112 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 112))); + auto mx113 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 113)); + auto mx114 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 114)); + auto mx115 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {320}}, 115))); + auto mx116 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {320, 192, 3, 3}}, 116)); + auto mx117 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 117))); + auto mx118 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 118)); + auto mx119 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 119)); + auto mx120 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 120))); + auto mx121 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 121)); + auto mx134 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 134))); + auto mx135 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 135)); + auto mx136 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 136)); + auto mx137 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 137))); + auto mx138 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 138)); + auto mx139 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 139))); + auto mx140 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 140)); + auto mx141 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 141)); + auto mx142 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 142))); + auto mx143 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 1, 7}}, 143)); + auto mx144 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 144))); + auto mx145 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 145)); + auto mx146 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 146)); + auto mx147 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 147))); + auto mx148 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 7, 1}}, 148)); + auto mx149 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 149))); + auto mx150 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 150)); + auto mx151 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 151)); + auto mx152 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 152))); + auto mx153 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 1, 7}}, 153)); + auto mx154 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 154))); + auto mx155 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 155)); + auto mx156 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 156)); + auto mx157 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 157))); + auto mx158 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 7, 1}}, 158)); + auto mx159 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 159))); + auto mx160 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 160)); + auto mx161 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 161)); + auto mx162 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 162))); + auto mx163 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 163)); + auto mx164 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 164))); + auto mx165 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 165)); + auto mx166 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 166)); + auto mx167 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 167))); + auto mx168 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 7, 1}}, 168)); + auto mx169 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 169))); + auto mx170 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 170)); + auto mx171 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 171)); + auto mx172 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 172))); + auto mx173 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 192, 1, 7}}, 173)); + auto mx174 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 174))); + auto mx175 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 175)); + auto mx176 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 176)); + auto mx177 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 177))); + auto mx178 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 178)); + auto mx179 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 179))); + auto mx180 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 180)); + auto mx181 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 181)); + auto mx182 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 182))); + auto mx183 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 183)); + auto mx184 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 184))); + auto mx185 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 185)); + auto mx186 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 186)); + auto mx187 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 187))); + auto mx188 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 188)); + auto mx189 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 189))); + auto mx190 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 190)); + auto mx191 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 191)); + auto mx192 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 192))); + auto mx193 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 160, 1, 7}}, 193)); + auto mx194 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 194))); + auto mx195 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 195)); + auto mx196 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 196)); + auto mx197 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 197))); + auto mx198 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 7, 1}}, 198)); + auto mx199 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 199))); + auto mx200 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 200)); + auto mx201 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 201)); + auto mx202 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 202))); + auto mx203 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 1, 7}}, 203)); + auto mx204 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 204))); + auto mx205 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 205)); + auto mx206 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 206)); + auto mx207 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 207))); + auto mx208 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 7, 1}}, 208)); + auto mx209 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 209))); + auto mx210 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 210)); + auto mx211 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 211)); + auto mx212 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 212))); + auto mx213 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 768, 1, 1}}, 213)); + auto mx214 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 214))); + auto mx215 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 215)); + auto mx216 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 216)); + auto mx217 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 217))); + auto mx218 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 160, 7, 1}}, 218)); + auto mx219 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 219))); + auto mx220 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 220)); + auto mx221 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 221)); + auto mx222 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 222))); + auto mx223 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 1, 7}}, 223)); + auto mx224 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 224))); + auto mx225 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 225)); + auto mx226 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 226)); + auto mx227 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 227))); + auto mx228 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 768, 1, 1}}, 228)); + auto mx229 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 229))); + auto mx230 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 230)); + auto mx231 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 231)); + auto mx232 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 232))); + auto mx233 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 233)); + auto mx234 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 234))); + auto mx235 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 235)); + auto mx236 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 236)); + auto mx237 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 237))); + auto mx238 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 238)); + auto mx239 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 239))); + auto mx240 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 240)); + auto mx241 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 241)); + auto mx242 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 242))); + auto mx243 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 160, 1, 7}}, 243)); + auto mx244 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 244))); + auto mx245 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 245)); + auto mx246 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 246)); + auto mx247 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 247))); + auto mx248 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 7, 1}}, 248)); + auto mx249 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 249))); + auto mx250 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 250)); + auto mx251 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 251)); + auto mx252 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 252))); + auto mx253 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 1, 7}}, 253)); + auto mx254 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 254))); + auto mx255 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 255)); + auto mx256 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 256)); + auto mx257 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 257))); + auto mx258 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 7, 1}}, 258)); + auto mx259 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 259))); + auto mx260 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 260)); + auto mx261 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 261)); + auto mx262 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 262))); + auto mx263 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 768, 1, 1}}, 263)); + auto mx264 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 264))); + auto mx265 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 265)); + auto mx266 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 266)); + auto mx267 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 267))); + auto mx268 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 160, 7, 1}}, 268)); + auto mx269 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 269))); + auto mx270 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 270)); + auto mx271 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 271)); + auto mx272 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 272))); + auto mx273 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 160, 1, 7}}, 273)); + auto mx274 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 274))); + auto mx275 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 275)); + auto mx276 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 276)); + auto mx277 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {160}}, 277))); + auto mx278 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {160, 768, 1, 1}}, 278)); + auto mx279 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 279))); + auto mx280 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 280)); + auto mx281 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 281)); + auto mx282 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 282))); + auto mx283 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 283)); + auto mx284 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 284))); + auto mx285 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 285)); + auto mx286 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 286)); + auto mx287 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 287))); + auto mx288 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 288)); + auto mx289 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 289))); + auto mx290 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 290)); + auto mx291 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 291)); + auto mx292 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 292))); + auto mx293 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 128, 1, 7}}, 293)); + auto mx294 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 294))); + auto mx295 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 295)); + auto mx296 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 296)); + auto mx297 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 297))); + auto mx298 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 7, 1}}, 298)); + auto mx299 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 299))); + auto mx300 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 300)); + auto mx301 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 301)); + auto mx302 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 302))); + auto mx303 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 1, 7}}, 303)); + auto mx304 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 304))); + auto mx305 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 305)); + auto mx306 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 306)); + auto mx307 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 307))); + auto mx308 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 7, 1}}, 308)); + auto mx309 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 309))); + auto mx310 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 310)); + auto mx311 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 311)); + auto mx312 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 312))); + auto mx313 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 768, 1, 1}}, 313)); + auto mx314 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 314))); + auto mx315 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 315)); + auto mx316 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 316)); + auto mx317 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 317))); + auto mx318 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 128, 7, 1}}, 318)); + auto mx319 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 319))); + auto mx320 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 320)); + auto mx321 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 321)); + auto mx322 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 322))); + auto mx323 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 1, 7}}, 323)); + auto mx324 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 324))); + auto mx325 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 325)); + auto mx326 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 326)); + auto mx327 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 327))); + auto mx328 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 768, 1, 1}}, 328)); + auto mx329 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 329))); + auto mx330 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 330)); + auto mx331 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 331)); + auto mx332 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 332))); + auto mx333 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 768, 1, 1}}, 333)); + auto mx334 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 334))); + auto mx335 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 335)); + auto mx336 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 336)); + auto mx337 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 337))); + auto mx338 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 96, 3, 3}}, 338)); + auto mx339 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 339))); + auto mx340 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 340)); + auto mx341 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 341)); + auto mx342 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 342))); + auto mx343 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 64, 3, 3}}, 343)); + auto mx344 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 344))); + auto mx345 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 345)); + auto mx346 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 346)); + auto mx347 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 347))); + auto mx348 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 288, 1, 1}}, 348)); + auto mx349 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 349))); + auto mx350 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 350)); + auto mx351 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 351)); + auto mx352 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 352))); + auto mx353 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {384, 288, 3, 3}}, 353)); + auto mx354 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 354))); + auto mx355 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 355)); + auto mx356 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 356)); + auto mx357 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 357))); + auto mx358 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 288, 1, 1}}, 358)); + auto mx359 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 359))); + auto mx360 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 360)); + auto mx361 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 361)); + auto mx362 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 362))); + auto mx363 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 96, 3, 3}}, 363)); + auto mx364 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 364))); + auto mx365 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 365)); + auto mx366 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 366)); + auto mx367 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 367))); + auto mx368 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 64, 3, 3}}, 368)); + auto mx369 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 369))); + auto mx370 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 370)); + auto mx371 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 371)); + auto mx372 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 372))); + auto mx373 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 288, 1, 1}}, 373)); + auto mx374 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 374))); + auto mx375 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 375)); + auto mx376 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 376)); + auto mx377 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 377))); + auto mx378 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 48, 5, 5}}, 378)); + auto mx379 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 379))); + auto mx380 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 380)); + auto mx381 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 381)); + auto mx382 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 382))); + auto mx383 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {48, 288, 1, 1}}, 383)); + auto mx384 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 384))); + auto mx385 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 385)); + auto mx386 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 386)); + auto mx387 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 387))); + auto mx388 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 288, 1, 1}}, 388)); + auto mx389 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 389))); + auto mx390 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 390)); + auto mx391 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 391)); + auto mx392 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 392))); + auto mx393 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 256, 1, 1}}, 393)); + auto mx394 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 394))); + auto mx395 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 395)); + auto mx396 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 396)); + auto mx397 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 397))); + auto mx398 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 96, 3, 3}}, 398)); + auto mx399 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 399))); + auto mx400 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 400)); + auto mx401 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 401)); + auto mx402 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 402))); + auto mx403 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 64, 3, 3}}, 403)); + auto mx404 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 404))); + auto mx405 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 405)); + auto mx406 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 406)); + auto mx407 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 407))); + auto mx408 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 256, 1, 1}}, 408)); + auto mx409 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 409))); + auto mx410 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 410)); + auto mx411 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 411)); + auto mx412 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 412))); + auto mx413 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 48, 5, 5}}, 413)); + auto mx414 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 414))); + auto mx415 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 415)); + auto mx416 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 416)); + auto mx417 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 417))); + auto mx418 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {48, 256, 1, 1}}, 418)); + auto mx419 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 419))); + auto mx420 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 420)); + auto mx421 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 421)); + auto mx422 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 422))); + auto mx423 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 256, 1, 1}}, 423)); + auto mx424 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 424))); + auto mx425 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 425)); + auto mx426 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 426)); + auto mx427 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 427))); + auto mx428 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {32, 192, 1, 1}}, 428)); + auto mx429 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 429))); + auto mx430 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 430)); + auto mx431 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 431)); + auto mx432 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 432))); + auto mx433 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 96, 3, 3}}, 433)); + auto mx434 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 434))); + auto mx435 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 435)); + auto mx436 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 436)); + auto mx437 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {96}}, 437))); + auto mx438 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {96, 64, 3, 3}}, 438)); + auto mx439 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 439))); + auto mx440 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 440)); + auto mx441 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 441)); + auto mx442 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 442))); + auto mx443 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 192, 1, 1}}, 443)); + auto mx444 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 444))); + auto mx445 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 445)); + auto mx446 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 446)); + auto mx447 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 447))); + auto mx448 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 48, 5, 5}}, 448)); + auto mx449 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 449))); + auto mx450 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 450)); + auto mx451 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 451)); + auto mx452 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {48}}, 452))); + auto mx453 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {48, 192, 1, 1}}, 453)); + auto mx454 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 454))); + auto mx455 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 455)); + auto mx456 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 456)); + auto mx457 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 457))); + auto mx458 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 192, 1, 1}}, 458)); + auto mx459 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 459))); + auto mx460 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 460)); + auto mx461 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 461)); + auto mx462 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 462))); + auto mx463 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {192, 80, 3, 3}}, 463)); + auto mx464 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {80}}, 464))); + auto mx465 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {80}}, 465)); + auto mx466 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {80}}, 466)); + auto mx467 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {80}}, 467))); + auto mx468 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {80, 64, 1, 1}}, 468)); + auto mx469 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 469))); + auto mx470 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 470)); + auto mx471 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 471)); + auto mx472 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 472))); + auto mx473 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 32, 3, 3}}, 473)); + auto mx474 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 474))); + auto mx475 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 475)); + auto mx476 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 476)); + auto mx477 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 477))); + auto mx478 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {32, 32, 3, 3}}, 478)); + auto mx479 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 479))); + auto mx480 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 480)); + auto mx481 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 481)); + auto mx482 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {32}}, 482))); + auto mx483 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {32, 3, 3, 3}}, 483)); + migraphx::op::convolution convolution484; + convolution484.padding = {0, 0}; + convolution484.stride = {2, 2}; + convolution484.dilation = {1, 1}; + convolution484.group = 1; + auto mx484 = mm->add_instruction(convolution484, m0, mx483); + migraphx::op::batch_norm_inference batch_norm_inference485; + batch_norm_inference485.epsilon = 0.001; + batch_norm_inference485.momentum = 0.9; + auto mx485 = mm->add_instruction(batch_norm_inference485, mx484, mx482, mx481, mx480, mx479); + migraphx::op::relu relu486; + auto mx486 = mm->add_instruction(relu486, mx485); + migraphx::op::convolution convolution487; + convolution487.padding = {0, 0}; + convolution487.stride = {1, 1}; + convolution487.dilation = {1, 1}; + convolution487.group = 1; + auto mx487 = mm->add_instruction(convolution487, mx486, mx478); + migraphx::op::batch_norm_inference batch_norm_inference488; + batch_norm_inference488.epsilon = 0.001; + batch_norm_inference488.momentum = 0.9; + auto mx488 = mm->add_instruction(batch_norm_inference488, mx487, mx477, mx476, mx475, mx474); + migraphx::op::relu relu489; + auto mx489 = mm->add_instruction(relu489, mx488); + migraphx::op::convolution convolution490; + convolution490.padding = {1, 1}; + convolution490.stride = {1, 1}; + convolution490.dilation = {1, 1}; + convolution490.group = 1; + auto mx490 = mm->add_instruction(convolution490, mx489, mx473); + migraphx::op::batch_norm_inference batch_norm_inference491; + batch_norm_inference491.epsilon = 0.001; + batch_norm_inference491.momentum = 0.9; + auto mx491 = mm->add_instruction(batch_norm_inference491, mx490, mx472, mx471, mx470, mx469); + migraphx::op::relu relu492; + auto mx492 = mm->add_instruction(relu492, mx491); + migraphx::op::pooling pooling493; + pooling493.mode = migraphx::op::pooling_mode::max; + pooling493.padding = {0, 0}; + pooling493.stride = {2, 2}; + pooling493.lengths = {3, 3}; + auto mx493 = mm->add_instruction(pooling493, mx492); + migraphx::op::convolution convolution494; + convolution494.padding = {0, 0}; + convolution494.stride = {1, 1}; + convolution494.dilation = {1, 1}; + convolution494.group = 1; + auto mx494 = mm->add_instruction(convolution494, mx493, mx468); + migraphx::op::batch_norm_inference batch_norm_inference495; + batch_norm_inference495.epsilon = 0.001; + batch_norm_inference495.momentum = 0.9; + auto mx495 = mm->add_instruction(batch_norm_inference495, mx494, mx467, mx466, mx465, mx464); + migraphx::op::relu relu496; + auto mx496 = mm->add_instruction(relu496, mx495); + migraphx::op::convolution convolution497; + convolution497.padding = {0, 0}; + convolution497.stride = {1, 1}; + convolution497.dilation = {1, 1}; + convolution497.group = 1; + auto mx497 = mm->add_instruction(convolution497, mx496, mx463); + migraphx::op::batch_norm_inference batch_norm_inference498; + batch_norm_inference498.epsilon = 0.001; + batch_norm_inference498.momentum = 0.9; + auto mx498 = mm->add_instruction(batch_norm_inference498, mx497, mx462, mx461, mx460, mx459); + migraphx::op::relu relu499; + auto mx499 = mm->add_instruction(relu499, mx498); + migraphx::op::pooling pooling500; + pooling500.mode = migraphx::op::pooling_mode::max; + pooling500.padding = {0, 0}; + pooling500.stride = {2, 2}; + pooling500.lengths = {3, 3}; + auto mx500 = mm->add_instruction(pooling500, mx499); + migraphx::op::convolution convolution501; + convolution501.padding = {0, 0}; + convolution501.stride = {1, 1}; + convolution501.dilation = {1, 1}; + convolution501.group = 1; + auto mx501 = mm->add_instruction(convolution501, mx500, mx458); + migraphx::op::batch_norm_inference batch_norm_inference502; + batch_norm_inference502.epsilon = 0.001; + batch_norm_inference502.momentum = 0.9; + auto mx502 = mm->add_instruction(batch_norm_inference502, mx501, mx457, mx456, mx455, mx454); + migraphx::op::relu relu503; + auto mx503 = mm->add_instruction(relu503, mx502); + migraphx::op::convolution convolution504; + convolution504.padding = {0, 0}; + convolution504.stride = {1, 1}; + convolution504.dilation = {1, 1}; + convolution504.group = 1; + auto mx504 = mm->add_instruction(convolution504, mx500, mx453); + migraphx::op::batch_norm_inference batch_norm_inference505; + batch_norm_inference505.epsilon = 0.001; + batch_norm_inference505.momentum = 0.9; + auto mx505 = mm->add_instruction(batch_norm_inference505, mx504, mx452, mx451, mx450, mx449); + migraphx::op::relu relu506; + auto mx506 = mm->add_instruction(relu506, mx505); + migraphx::op::convolution convolution507; + convolution507.padding = {2, 2}; + convolution507.stride = {1, 1}; + convolution507.dilation = {1, 1}; + convolution507.group = 1; + auto mx507 = mm->add_instruction(convolution507, mx506, mx448); + migraphx::op::batch_norm_inference batch_norm_inference508; + batch_norm_inference508.epsilon = 0.001; + batch_norm_inference508.momentum = 0.9; + auto mx508 = mm->add_instruction(batch_norm_inference508, mx507, mx447, mx446, mx445, mx444); + migraphx::op::relu relu509; + auto mx509 = mm->add_instruction(relu509, mx508); + migraphx::op::convolution convolution510; + convolution510.padding = {0, 0}; + convolution510.stride = {1, 1}; + convolution510.dilation = {1, 1}; + convolution510.group = 1; + auto mx510 = mm->add_instruction(convolution510, mx500, mx443); + migraphx::op::batch_norm_inference batch_norm_inference511; + batch_norm_inference511.epsilon = 0.001; + batch_norm_inference511.momentum = 0.9; + auto mx511 = mm->add_instruction(batch_norm_inference511, mx510, mx442, mx441, mx440, mx439); + migraphx::op::relu relu512; + auto mx512 = mm->add_instruction(relu512, mx511); + migraphx::op::convolution convolution513; + convolution513.padding = {1, 1}; + convolution513.stride = {1, 1}; + convolution513.dilation = {1, 1}; + convolution513.group = 1; + auto mx513 = mm->add_instruction(convolution513, mx512, mx438); + migraphx::op::batch_norm_inference batch_norm_inference514; + batch_norm_inference514.epsilon = 0.001; + batch_norm_inference514.momentum = 0.9; + auto mx514 = mm->add_instruction(batch_norm_inference514, mx513, mx437, mx436, mx435, mx434); + migraphx::op::relu relu515; + auto mx515 = mm->add_instruction(relu515, mx514); + migraphx::op::convolution convolution516; + convolution516.padding = {1, 1}; + convolution516.stride = {1, 1}; + convolution516.dilation = {1, 1}; + convolution516.group = 1; + auto mx516 = mm->add_instruction(convolution516, mx515, mx433); + migraphx::op::batch_norm_inference batch_norm_inference517; + batch_norm_inference517.epsilon = 0.001; + batch_norm_inference517.momentum = 0.9; + auto mx517 = mm->add_instruction(batch_norm_inference517, mx516, mx432, mx431, mx430, mx429); + migraphx::op::relu relu518; + auto mx518 = mm->add_instruction(relu518, mx517); + migraphx::op::pooling pooling519; + pooling519.mode = migraphx::op::pooling_mode::average; + pooling519.padding = {1, 1}; + pooling519.stride = {1, 1}; + pooling519.lengths = {3, 3}; + auto mx519 = mm->add_instruction(pooling519, mx500); + migraphx::op::convolution convolution520; + convolution520.padding = {0, 0}; + convolution520.stride = {1, 1}; + convolution520.dilation = {1, 1}; + convolution520.group = 1; + auto mx520 = mm->add_instruction(convolution520, mx519, mx428); + migraphx::op::batch_norm_inference batch_norm_inference521; + batch_norm_inference521.epsilon = 0.001; + batch_norm_inference521.momentum = 0.9; + auto mx521 = mm->add_instruction(batch_norm_inference521, mx520, mx427, mx426, mx425, mx424); + migraphx::op::relu relu522; + auto mx522 = mm->add_instruction(relu522, mx521); + migraphx::op::concat concat523; + concat523.axis = 1; + auto mx523 = mm->add_instruction(concat523, mx503, mx509, mx518, mx522); + migraphx::op::convolution convolution524; + convolution524.padding = {0, 0}; + convolution524.stride = {1, 1}; + convolution524.dilation = {1, 1}; + convolution524.group = 1; + auto mx524 = mm->add_instruction(convolution524, mx523, mx423); + migraphx::op::batch_norm_inference batch_norm_inference525; + batch_norm_inference525.epsilon = 0.001; + batch_norm_inference525.momentum = 0.9; + auto mx525 = mm->add_instruction(batch_norm_inference525, mx524, mx422, mx421, mx420, mx419); + migraphx::op::relu relu526; + auto mx526 = mm->add_instruction(relu526, mx525); + migraphx::op::convolution convolution527; + convolution527.padding = {0, 0}; + convolution527.stride = {1, 1}; + convolution527.dilation = {1, 1}; + convolution527.group = 1; + auto mx527 = mm->add_instruction(convolution527, mx523, mx418); + migraphx::op::batch_norm_inference batch_norm_inference528; + batch_norm_inference528.epsilon = 0.001; + batch_norm_inference528.momentum = 0.9; + auto mx528 = mm->add_instruction(batch_norm_inference528, mx527, mx417, mx416, mx415, mx414); + migraphx::op::relu relu529; + auto mx529 = mm->add_instruction(relu529, mx528); + migraphx::op::convolution convolution530; + convolution530.padding = {2, 2}; + convolution530.stride = {1, 1}; + convolution530.dilation = {1, 1}; + convolution530.group = 1; + auto mx530 = mm->add_instruction(convolution530, mx529, mx413); + migraphx::op::batch_norm_inference batch_norm_inference531; + batch_norm_inference531.epsilon = 0.001; + batch_norm_inference531.momentum = 0.9; + auto mx531 = mm->add_instruction(batch_norm_inference531, mx530, mx412, mx411, mx410, mx409); + migraphx::op::relu relu532; + auto mx532 = mm->add_instruction(relu532, mx531); + migraphx::op::convolution convolution533; + convolution533.padding = {0, 0}; + convolution533.stride = {1, 1}; + convolution533.dilation = {1, 1}; + convolution533.group = 1; + auto mx533 = mm->add_instruction(convolution533, mx523, mx408); + migraphx::op::batch_norm_inference batch_norm_inference534; + batch_norm_inference534.epsilon = 0.001; + batch_norm_inference534.momentum = 0.9; + auto mx534 = mm->add_instruction(batch_norm_inference534, mx533, mx407, mx406, mx405, mx404); + migraphx::op::relu relu535; + auto mx535 = mm->add_instruction(relu535, mx534); + migraphx::op::convolution convolution536; + convolution536.padding = {1, 1}; + convolution536.stride = {1, 1}; + convolution536.dilation = {1, 1}; + convolution536.group = 1; + auto mx536 = mm->add_instruction(convolution536, mx535, mx403); + migraphx::op::batch_norm_inference batch_norm_inference537; + batch_norm_inference537.epsilon = 0.001; + batch_norm_inference537.momentum = 0.9; + auto mx537 = mm->add_instruction(batch_norm_inference537, mx536, mx402, mx401, mx400, mx399); + migraphx::op::relu relu538; + auto mx538 = mm->add_instruction(relu538, mx537); + migraphx::op::convolution convolution539; + convolution539.padding = {1, 1}; + convolution539.stride = {1, 1}; + convolution539.dilation = {1, 1}; + convolution539.group = 1; + auto mx539 = mm->add_instruction(convolution539, mx538, mx398); + migraphx::op::batch_norm_inference batch_norm_inference540; + batch_norm_inference540.epsilon = 0.001; + batch_norm_inference540.momentum = 0.9; + auto mx540 = mm->add_instruction(batch_norm_inference540, mx539, mx397, mx396, mx395, mx394); + migraphx::op::relu relu541; + auto mx541 = mm->add_instruction(relu541, mx540); + migraphx::op::pooling pooling542; + pooling542.mode = migraphx::op::pooling_mode::average; + pooling542.padding = {1, 1}; + pooling542.stride = {1, 1}; + pooling542.lengths = {3, 3}; + auto mx542 = mm->add_instruction(pooling542, mx523); + migraphx::op::convolution convolution543; + convolution543.padding = {0, 0}; + convolution543.stride = {1, 1}; + convolution543.dilation = {1, 1}; + convolution543.group = 1; + auto mx543 = mm->add_instruction(convolution543, mx542, mx393); + migraphx::op::batch_norm_inference batch_norm_inference544; + batch_norm_inference544.epsilon = 0.001; + batch_norm_inference544.momentum = 0.9; + auto mx544 = mm->add_instruction(batch_norm_inference544, mx543, mx392, mx391, mx390, mx389); + migraphx::op::relu relu545; + auto mx545 = mm->add_instruction(relu545, mx544); + migraphx::op::concat concat546; + concat546.axis = 1; + auto mx546 = mm->add_instruction(concat546, mx526, mx532, mx541, mx545); + migraphx::op::convolution convolution547; + convolution547.padding = {0, 0}; + convolution547.stride = {1, 1}; + convolution547.dilation = {1, 1}; + convolution547.group = 1; + auto mx547 = mm->add_instruction(convolution547, mx546, mx388); + migraphx::op::batch_norm_inference batch_norm_inference548; + batch_norm_inference548.epsilon = 0.001; + batch_norm_inference548.momentum = 0.9; + auto mx548 = mm->add_instruction(batch_norm_inference548, mx547, mx387, mx386, mx385, mx384); + migraphx::op::relu relu549; + auto mx549 = mm->add_instruction(relu549, mx548); + migraphx::op::convolution convolution550; + convolution550.padding = {0, 0}; + convolution550.stride = {1, 1}; + convolution550.dilation = {1, 1}; + convolution550.group = 1; + auto mx550 = mm->add_instruction(convolution550, mx546, mx383); + migraphx::op::batch_norm_inference batch_norm_inference551; + batch_norm_inference551.epsilon = 0.001; + batch_norm_inference551.momentum = 0.9; + auto mx551 = mm->add_instruction(batch_norm_inference551, mx550, mx382, mx381, mx380, mx379); + migraphx::op::relu relu552; + auto mx552 = mm->add_instruction(relu552, mx551); + migraphx::op::convolution convolution553; + convolution553.padding = {2, 2}; + convolution553.stride = {1, 1}; + convolution553.dilation = {1, 1}; + convolution553.group = 1; + auto mx553 = mm->add_instruction(convolution553, mx552, mx378); + migraphx::op::batch_norm_inference batch_norm_inference554; + batch_norm_inference554.epsilon = 0.001; + batch_norm_inference554.momentum = 0.9; + auto mx554 = mm->add_instruction(batch_norm_inference554, mx553, mx377, mx376, mx375, mx374); + migraphx::op::relu relu555; + auto mx555 = mm->add_instruction(relu555, mx554); + migraphx::op::convolution convolution556; + convolution556.padding = {0, 0}; + convolution556.stride = {1, 1}; + convolution556.dilation = {1, 1}; + convolution556.group = 1; + auto mx556 = mm->add_instruction(convolution556, mx546, mx373); + migraphx::op::batch_norm_inference batch_norm_inference557; + batch_norm_inference557.epsilon = 0.001; + batch_norm_inference557.momentum = 0.9; + auto mx557 = mm->add_instruction(batch_norm_inference557, mx556, mx372, mx371, mx370, mx369); + migraphx::op::relu relu558; + auto mx558 = mm->add_instruction(relu558, mx557); + migraphx::op::convolution convolution559; + convolution559.padding = {1, 1}; + convolution559.stride = {1, 1}; + convolution559.dilation = {1, 1}; + convolution559.group = 1; + auto mx559 = mm->add_instruction(convolution559, mx558, mx368); + migraphx::op::batch_norm_inference batch_norm_inference560; + batch_norm_inference560.epsilon = 0.001; + batch_norm_inference560.momentum = 0.9; + auto mx560 = mm->add_instruction(batch_norm_inference560, mx559, mx367, mx366, mx365, mx364); + migraphx::op::relu relu561; + auto mx561 = mm->add_instruction(relu561, mx560); + migraphx::op::convolution convolution562; + convolution562.padding = {1, 1}; + convolution562.stride = {1, 1}; + convolution562.dilation = {1, 1}; + convolution562.group = 1; + auto mx562 = mm->add_instruction(convolution562, mx561, mx363); + migraphx::op::batch_norm_inference batch_norm_inference563; + batch_norm_inference563.epsilon = 0.001; + batch_norm_inference563.momentum = 0.9; + auto mx563 = mm->add_instruction(batch_norm_inference563, mx562, mx362, mx361, mx360, mx359); + migraphx::op::relu relu564; + auto mx564 = mm->add_instruction(relu564, mx563); + migraphx::op::pooling pooling565; + pooling565.mode = migraphx::op::pooling_mode::average; + pooling565.padding = {1, 1}; + pooling565.stride = {1, 1}; + pooling565.lengths = {3, 3}; + auto mx565 = mm->add_instruction(pooling565, mx546); + migraphx::op::convolution convolution566; + convolution566.padding = {0, 0}; + convolution566.stride = {1, 1}; + convolution566.dilation = {1, 1}; + convolution566.group = 1; + auto mx566 = mm->add_instruction(convolution566, mx565, mx358); + migraphx::op::batch_norm_inference batch_norm_inference567; + batch_norm_inference567.epsilon = 0.001; + batch_norm_inference567.momentum = 0.9; + auto mx567 = mm->add_instruction(batch_norm_inference567, mx566, mx357, mx356, mx355, mx354); + migraphx::op::relu relu568; + auto mx568 = mm->add_instruction(relu568, mx567); + migraphx::op::concat concat569; + concat569.axis = 1; + auto mx569 = mm->add_instruction(concat569, mx549, mx555, mx564, mx568); + migraphx::op::convolution convolution570; + convolution570.padding = {0, 0}; + convolution570.stride = {2, 2}; + convolution570.dilation = {1, 1}; + convolution570.group = 1; + auto mx570 = mm->add_instruction(convolution570, mx569, mx353); + migraphx::op::batch_norm_inference batch_norm_inference571; + batch_norm_inference571.epsilon = 0.001; + batch_norm_inference571.momentum = 0.9; + auto mx571 = mm->add_instruction(batch_norm_inference571, mx570, mx352, mx351, mx350, mx349); + migraphx::op::relu relu572; + auto mx572 = mm->add_instruction(relu572, mx571); + migraphx::op::convolution convolution573; + convolution573.padding = {0, 0}; + convolution573.stride = {1, 1}; + convolution573.dilation = {1, 1}; + convolution573.group = 1; + auto mx573 = mm->add_instruction(convolution573, mx569, mx348); + migraphx::op::batch_norm_inference batch_norm_inference574; + batch_norm_inference574.epsilon = 0.001; + batch_norm_inference574.momentum = 0.9; + auto mx574 = mm->add_instruction(batch_norm_inference574, mx573, mx347, mx346, mx345, mx344); + migraphx::op::relu relu575; + auto mx575 = mm->add_instruction(relu575, mx574); + migraphx::op::convolution convolution576; + convolution576.padding = {1, 1}; + convolution576.stride = {1, 1}; + convolution576.dilation = {1, 1}; + convolution576.group = 1; + auto mx576 = mm->add_instruction(convolution576, mx575, mx343); + migraphx::op::batch_norm_inference batch_norm_inference577; + batch_norm_inference577.epsilon = 0.001; + batch_norm_inference577.momentum = 0.9; + auto mx577 = mm->add_instruction(batch_norm_inference577, mx576, mx342, mx341, mx340, mx339); + migraphx::op::relu relu578; + auto mx578 = mm->add_instruction(relu578, mx577); + migraphx::op::convolution convolution579; + convolution579.padding = {0, 0}; + convolution579.stride = {2, 2}; + convolution579.dilation = {1, 1}; + convolution579.group = 1; + auto mx579 = mm->add_instruction(convolution579, mx578, mx338); + migraphx::op::batch_norm_inference batch_norm_inference580; + batch_norm_inference580.epsilon = 0.001; + batch_norm_inference580.momentum = 0.9; + auto mx580 = mm->add_instruction(batch_norm_inference580, mx579, mx337, mx336, mx335, mx334); + migraphx::op::relu relu581; + auto mx581 = mm->add_instruction(relu581, mx580); + migraphx::op::pooling pooling582; + pooling582.mode = migraphx::op::pooling_mode::max; + pooling582.padding = {0, 0}; + pooling582.stride = {2, 2}; + pooling582.lengths = {3, 3}; + auto mx582 = mm->add_instruction(pooling582, mx569); + migraphx::op::concat concat583; + concat583.axis = 1; + auto mx583 = mm->add_instruction(concat583, mx572, mx581, mx582); + migraphx::op::convolution convolution584; + convolution584.padding = {0, 0}; + convolution584.stride = {1, 1}; + convolution584.dilation = {1, 1}; + convolution584.group = 1; + auto mx584 = mm->add_instruction(convolution584, mx583, mx333); + migraphx::op::batch_norm_inference batch_norm_inference585; + batch_norm_inference585.epsilon = 0.001; + batch_norm_inference585.momentum = 0.9; + auto mx585 = mm->add_instruction(batch_norm_inference585, mx584, mx332, mx331, mx330, mx329); + migraphx::op::relu relu586; + auto mx586 = mm->add_instruction(relu586, mx585); + migraphx::op::convolution convolution587; + convolution587.padding = {0, 0}; + convolution587.stride = {1, 1}; + convolution587.dilation = {1, 1}; + convolution587.group = 1; + auto mx587 = mm->add_instruction(convolution587, mx583, mx328); + migraphx::op::batch_norm_inference batch_norm_inference588; + batch_norm_inference588.epsilon = 0.001; + batch_norm_inference588.momentum = 0.9; + auto mx588 = mm->add_instruction(batch_norm_inference588, mx587, mx327, mx326, mx325, mx324); + migraphx::op::relu relu589; + auto mx589 = mm->add_instruction(relu589, mx588); + migraphx::op::convolution convolution590; + convolution590.padding = {0, 3}; + convolution590.stride = {1, 1}; + convolution590.dilation = {1, 1}; + convolution590.group = 1; + auto mx590 = mm->add_instruction(convolution590, mx589, mx323); + migraphx::op::batch_norm_inference batch_norm_inference591; + batch_norm_inference591.epsilon = 0.001; + batch_norm_inference591.momentum = 0.9; + auto mx591 = mm->add_instruction(batch_norm_inference591, mx590, mx322, mx321, mx320, mx319); + migraphx::op::relu relu592; + auto mx592 = mm->add_instruction(relu592, mx591); + migraphx::op::convolution convolution593; + convolution593.padding = {3, 0}; + convolution593.stride = {1, 1}; + convolution593.dilation = {1, 1}; + convolution593.group = 1; + auto mx593 = mm->add_instruction(convolution593, mx592, mx318); + migraphx::op::batch_norm_inference batch_norm_inference594; + batch_norm_inference594.epsilon = 0.001; + batch_norm_inference594.momentum = 0.9; + auto mx594 = mm->add_instruction(batch_norm_inference594, mx593, mx317, mx316, mx315, mx314); + migraphx::op::relu relu595; + auto mx595 = mm->add_instruction(relu595, mx594); + migraphx::op::convolution convolution596; + convolution596.padding = {0, 0}; + convolution596.stride = {1, 1}; + convolution596.dilation = {1, 1}; + convolution596.group = 1; + auto mx596 = mm->add_instruction(convolution596, mx583, mx313); + migraphx::op::batch_norm_inference batch_norm_inference597; + batch_norm_inference597.epsilon = 0.001; + batch_norm_inference597.momentum = 0.9; + auto mx597 = mm->add_instruction(batch_norm_inference597, mx596, mx312, mx311, mx310, mx309); + migraphx::op::relu relu598; + auto mx598 = mm->add_instruction(relu598, mx597); + migraphx::op::convolution convolution599; + convolution599.padding = {3, 0}; + convolution599.stride = {1, 1}; + convolution599.dilation = {1, 1}; + convolution599.group = 1; + auto mx599 = mm->add_instruction(convolution599, mx598, mx308); + migraphx::op::batch_norm_inference batch_norm_inference600; + batch_norm_inference600.epsilon = 0.001; + batch_norm_inference600.momentum = 0.9; + auto mx600 = mm->add_instruction(batch_norm_inference600, mx599, mx307, mx306, mx305, mx304); + migraphx::op::relu relu601; + auto mx601 = mm->add_instruction(relu601, mx600); + migraphx::op::convolution convolution602; + convolution602.padding = {0, 3}; + convolution602.stride = {1, 1}; + convolution602.dilation = {1, 1}; + convolution602.group = 1; + auto mx602 = mm->add_instruction(convolution602, mx601, mx303); + migraphx::op::batch_norm_inference batch_norm_inference603; + batch_norm_inference603.epsilon = 0.001; + batch_norm_inference603.momentum = 0.9; + auto mx603 = mm->add_instruction(batch_norm_inference603, mx602, mx302, mx301, mx300, mx299); + migraphx::op::relu relu604; + auto mx604 = mm->add_instruction(relu604, mx603); + migraphx::op::convolution convolution605; + convolution605.padding = {3, 0}; + convolution605.stride = {1, 1}; + convolution605.dilation = {1, 1}; + convolution605.group = 1; + auto mx605 = mm->add_instruction(convolution605, mx604, mx298); + migraphx::op::batch_norm_inference batch_norm_inference606; + batch_norm_inference606.epsilon = 0.001; + batch_norm_inference606.momentum = 0.9; + auto mx606 = mm->add_instruction(batch_norm_inference606, mx605, mx297, mx296, mx295, mx294); + migraphx::op::relu relu607; + auto mx607 = mm->add_instruction(relu607, mx606); + migraphx::op::convolution convolution608; + convolution608.padding = {0, 3}; + convolution608.stride = {1, 1}; + convolution608.dilation = {1, 1}; + convolution608.group = 1; + auto mx608 = mm->add_instruction(convolution608, mx607, mx293); + migraphx::op::batch_norm_inference batch_norm_inference609; + batch_norm_inference609.epsilon = 0.001; + batch_norm_inference609.momentum = 0.9; + auto mx609 = mm->add_instruction(batch_norm_inference609, mx608, mx292, mx291, mx290, mx289); + migraphx::op::relu relu610; + auto mx610 = mm->add_instruction(relu610, mx609); + migraphx::op::pooling pooling611; + pooling611.mode = migraphx::op::pooling_mode::average; + pooling611.padding = {1, 1}; + pooling611.stride = {1, 1}; + pooling611.lengths = {3, 3}; + auto mx611 = mm->add_instruction(pooling611, mx583); + migraphx::op::convolution convolution612; + convolution612.padding = {0, 0}; + convolution612.stride = {1, 1}; + convolution612.dilation = {1, 1}; + convolution612.group = 1; + auto mx612 = mm->add_instruction(convolution612, mx611, mx288); + migraphx::op::batch_norm_inference batch_norm_inference613; + batch_norm_inference613.epsilon = 0.001; + batch_norm_inference613.momentum = 0.9; + auto mx613 = mm->add_instruction(batch_norm_inference613, mx612, mx287, mx286, mx285, mx284); + migraphx::op::relu relu614; + auto mx614 = mm->add_instruction(relu614, mx613); + migraphx::op::concat concat615; + concat615.axis = 1; + auto mx615 = mm->add_instruction(concat615, mx586, mx595, mx610, mx614); + migraphx::op::convolution convolution616; + convolution616.padding = {0, 0}; + convolution616.stride = {1, 1}; + convolution616.dilation = {1, 1}; + convolution616.group = 1; + auto mx616 = mm->add_instruction(convolution616, mx615, mx283); + migraphx::op::batch_norm_inference batch_norm_inference617; + batch_norm_inference617.epsilon = 0.001; + batch_norm_inference617.momentum = 0.9; + auto mx617 = mm->add_instruction(batch_norm_inference617, mx616, mx282, mx281, mx280, mx279); + migraphx::op::relu relu618; + auto mx618 = mm->add_instruction(relu618, mx617); + migraphx::op::convolution convolution619; + convolution619.padding = {0, 0}; + convolution619.stride = {1, 1}; + convolution619.dilation = {1, 1}; + convolution619.group = 1; + auto mx619 = mm->add_instruction(convolution619, mx615, mx278); + migraphx::op::batch_norm_inference batch_norm_inference620; + batch_norm_inference620.epsilon = 0.001; + batch_norm_inference620.momentum = 0.9; + auto mx620 = mm->add_instruction(batch_norm_inference620, mx619, mx277, mx276, mx275, mx274); + migraphx::op::relu relu621; + auto mx621 = mm->add_instruction(relu621, mx620); + migraphx::op::convolution convolution622; + convolution622.padding = {0, 3}; + convolution622.stride = {1, 1}; + convolution622.dilation = {1, 1}; + convolution622.group = 1; + auto mx622 = mm->add_instruction(convolution622, mx621, mx273); + migraphx::op::batch_norm_inference batch_norm_inference623; + batch_norm_inference623.epsilon = 0.001; + batch_norm_inference623.momentum = 0.9; + auto mx623 = mm->add_instruction(batch_norm_inference623, mx622, mx272, mx271, mx270, mx269); + migraphx::op::relu relu624; + auto mx624 = mm->add_instruction(relu624, mx623); + migraphx::op::convolution convolution625; + convolution625.padding = {3, 0}; + convolution625.stride = {1, 1}; + convolution625.dilation = {1, 1}; + convolution625.group = 1; + auto mx625 = mm->add_instruction(convolution625, mx624, mx268); + migraphx::op::batch_norm_inference batch_norm_inference626; + batch_norm_inference626.epsilon = 0.001; + batch_norm_inference626.momentum = 0.9; + auto mx626 = mm->add_instruction(batch_norm_inference626, mx625, mx267, mx266, mx265, mx264); + migraphx::op::relu relu627; + auto mx627 = mm->add_instruction(relu627, mx626); + migraphx::op::convolution convolution628; + convolution628.padding = {0, 0}; + convolution628.stride = {1, 1}; + convolution628.dilation = {1, 1}; + convolution628.group = 1; + auto mx628 = mm->add_instruction(convolution628, mx615, mx263); + migraphx::op::batch_norm_inference batch_norm_inference629; + batch_norm_inference629.epsilon = 0.001; + batch_norm_inference629.momentum = 0.9; + auto mx629 = mm->add_instruction(batch_norm_inference629, mx628, mx262, mx261, mx260, mx259); + migraphx::op::relu relu630; + auto mx630 = mm->add_instruction(relu630, mx629); + migraphx::op::convolution convolution631; + convolution631.padding = {3, 0}; + convolution631.stride = {1, 1}; + convolution631.dilation = {1, 1}; + convolution631.group = 1; + auto mx631 = mm->add_instruction(convolution631, mx630, mx258); + migraphx::op::batch_norm_inference batch_norm_inference632; + batch_norm_inference632.epsilon = 0.001; + batch_norm_inference632.momentum = 0.9; + auto mx632 = mm->add_instruction(batch_norm_inference632, mx631, mx257, mx256, mx255, mx254); + migraphx::op::relu relu633; + auto mx633 = mm->add_instruction(relu633, mx632); + migraphx::op::convolution convolution634; + convolution634.padding = {0, 3}; + convolution634.stride = {1, 1}; + convolution634.dilation = {1, 1}; + convolution634.group = 1; + auto mx634 = mm->add_instruction(convolution634, mx633, mx253); + migraphx::op::batch_norm_inference batch_norm_inference635; + batch_norm_inference635.epsilon = 0.001; + batch_norm_inference635.momentum = 0.9; + auto mx635 = mm->add_instruction(batch_norm_inference635, mx634, mx252, mx251, mx250, mx249); + migraphx::op::relu relu636; + auto mx636 = mm->add_instruction(relu636, mx635); + migraphx::op::convolution convolution637; + convolution637.padding = {3, 0}; + convolution637.stride = {1, 1}; + convolution637.dilation = {1, 1}; + convolution637.group = 1; + auto mx637 = mm->add_instruction(convolution637, mx636, mx248); + migraphx::op::batch_norm_inference batch_norm_inference638; + batch_norm_inference638.epsilon = 0.001; + batch_norm_inference638.momentum = 0.9; + auto mx638 = mm->add_instruction(batch_norm_inference638, mx637, mx247, mx246, mx245, mx244); + migraphx::op::relu relu639; + auto mx639 = mm->add_instruction(relu639, mx638); + migraphx::op::convolution convolution640; + convolution640.padding = {0, 3}; + convolution640.stride = {1, 1}; + convolution640.dilation = {1, 1}; + convolution640.group = 1; + auto mx640 = mm->add_instruction(convolution640, mx639, mx243); + migraphx::op::batch_norm_inference batch_norm_inference641; + batch_norm_inference641.epsilon = 0.001; + batch_norm_inference641.momentum = 0.9; + auto mx641 = mm->add_instruction(batch_norm_inference641, mx640, mx242, mx241, mx240, mx239); + migraphx::op::relu relu642; + auto mx642 = mm->add_instruction(relu642, mx641); + migraphx::op::pooling pooling643; + pooling643.mode = migraphx::op::pooling_mode::average; + pooling643.padding = {1, 1}; + pooling643.stride = {1, 1}; + pooling643.lengths = {3, 3}; + auto mx643 = mm->add_instruction(pooling643, mx615); + migraphx::op::convolution convolution644; + convolution644.padding = {0, 0}; + convolution644.stride = {1, 1}; + convolution644.dilation = {1, 1}; + convolution644.group = 1; + auto mx644 = mm->add_instruction(convolution644, mx643, mx238); + migraphx::op::batch_norm_inference batch_norm_inference645; + batch_norm_inference645.epsilon = 0.001; + batch_norm_inference645.momentum = 0.9; + auto mx645 = mm->add_instruction(batch_norm_inference645, mx644, mx237, mx236, mx235, mx234); + migraphx::op::relu relu646; + auto mx646 = mm->add_instruction(relu646, mx645); + migraphx::op::concat concat647; + concat647.axis = 1; + auto mx647 = mm->add_instruction(concat647, mx618, mx627, mx642, mx646); + migraphx::op::convolution convolution648; + convolution648.padding = {0, 0}; + convolution648.stride = {1, 1}; + convolution648.dilation = {1, 1}; + convolution648.group = 1; + auto mx648 = mm->add_instruction(convolution648, mx647, mx233); + migraphx::op::batch_norm_inference batch_norm_inference649; + batch_norm_inference649.epsilon = 0.001; + batch_norm_inference649.momentum = 0.9; + auto mx649 = mm->add_instruction(batch_norm_inference649, mx648, mx232, mx231, mx230, mx229); + migraphx::op::relu relu650; + auto mx650 = mm->add_instruction(relu650, mx649); + migraphx::op::convolution convolution651; + convolution651.padding = {0, 0}; + convolution651.stride = {1, 1}; + convolution651.dilation = {1, 1}; + convolution651.group = 1; + auto mx651 = mm->add_instruction(convolution651, mx647, mx228); + migraphx::op::batch_norm_inference batch_norm_inference652; + batch_norm_inference652.epsilon = 0.001; + batch_norm_inference652.momentum = 0.9; + auto mx652 = mm->add_instruction(batch_norm_inference652, mx651, mx227, mx226, mx225, mx224); + migraphx::op::relu relu653; + auto mx653 = mm->add_instruction(relu653, mx652); + migraphx::op::convolution convolution654; + convolution654.padding = {0, 3}; + convolution654.stride = {1, 1}; + convolution654.dilation = {1, 1}; + convolution654.group = 1; + auto mx654 = mm->add_instruction(convolution654, mx653, mx223); + migraphx::op::batch_norm_inference batch_norm_inference655; + batch_norm_inference655.epsilon = 0.001; + batch_norm_inference655.momentum = 0.9; + auto mx655 = mm->add_instruction(batch_norm_inference655, mx654, mx222, mx221, mx220, mx219); + migraphx::op::relu relu656; + auto mx656 = mm->add_instruction(relu656, mx655); + migraphx::op::convolution convolution657; + convolution657.padding = {3, 0}; + convolution657.stride = {1, 1}; + convolution657.dilation = {1, 1}; + convolution657.group = 1; + auto mx657 = mm->add_instruction(convolution657, mx656, mx218); + migraphx::op::batch_norm_inference batch_norm_inference658; + batch_norm_inference658.epsilon = 0.001; + batch_norm_inference658.momentum = 0.9; + auto mx658 = mm->add_instruction(batch_norm_inference658, mx657, mx217, mx216, mx215, mx214); + migraphx::op::relu relu659; + auto mx659 = mm->add_instruction(relu659, mx658); + migraphx::op::convolution convolution660; + convolution660.padding = {0, 0}; + convolution660.stride = {1, 1}; + convolution660.dilation = {1, 1}; + convolution660.group = 1; + auto mx660 = mm->add_instruction(convolution660, mx647, mx213); + migraphx::op::batch_norm_inference batch_norm_inference661; + batch_norm_inference661.epsilon = 0.001; + batch_norm_inference661.momentum = 0.9; + auto mx661 = mm->add_instruction(batch_norm_inference661, mx660, mx212, mx211, mx210, mx209); + migraphx::op::relu relu662; + auto mx662 = mm->add_instruction(relu662, mx661); + migraphx::op::convolution convolution663; + convolution663.padding = {3, 0}; + convolution663.stride = {1, 1}; + convolution663.dilation = {1, 1}; + convolution663.group = 1; + auto mx663 = mm->add_instruction(convolution663, mx662, mx208); + migraphx::op::batch_norm_inference batch_norm_inference664; + batch_norm_inference664.epsilon = 0.001; + batch_norm_inference664.momentum = 0.9; + auto mx664 = mm->add_instruction(batch_norm_inference664, mx663, mx207, mx206, mx205, mx204); + migraphx::op::relu relu665; + auto mx665 = mm->add_instruction(relu665, mx664); + migraphx::op::convolution convolution666; + convolution666.padding = {0, 3}; + convolution666.stride = {1, 1}; + convolution666.dilation = {1, 1}; + convolution666.group = 1; + auto mx666 = mm->add_instruction(convolution666, mx665, mx203); + migraphx::op::batch_norm_inference batch_norm_inference667; + batch_norm_inference667.epsilon = 0.001; + batch_norm_inference667.momentum = 0.9; + auto mx667 = mm->add_instruction(batch_norm_inference667, mx666, mx202, mx201, mx200, mx199); + migraphx::op::relu relu668; + auto mx668 = mm->add_instruction(relu668, mx667); + migraphx::op::convolution convolution669; + convolution669.padding = {3, 0}; + convolution669.stride = {1, 1}; + convolution669.dilation = {1, 1}; + convolution669.group = 1; + auto mx669 = mm->add_instruction(convolution669, mx668, mx198); + migraphx::op::batch_norm_inference batch_norm_inference670; + batch_norm_inference670.epsilon = 0.001; + batch_norm_inference670.momentum = 0.9; + auto mx670 = mm->add_instruction(batch_norm_inference670, mx669, mx197, mx196, mx195, mx194); + migraphx::op::relu relu671; + auto mx671 = mm->add_instruction(relu671, mx670); + migraphx::op::convolution convolution672; + convolution672.padding = {0, 3}; + convolution672.stride = {1, 1}; + convolution672.dilation = {1, 1}; + convolution672.group = 1; + auto mx672 = mm->add_instruction(convolution672, mx671, mx193); + migraphx::op::batch_norm_inference batch_norm_inference673; + batch_norm_inference673.epsilon = 0.001; + batch_norm_inference673.momentum = 0.9; + auto mx673 = mm->add_instruction(batch_norm_inference673, mx672, mx192, mx191, mx190, mx189); + migraphx::op::relu relu674; + auto mx674 = mm->add_instruction(relu674, mx673); + migraphx::op::pooling pooling675; + pooling675.mode = migraphx::op::pooling_mode::average; + pooling675.padding = {1, 1}; + pooling675.stride = {1, 1}; + pooling675.lengths = {3, 3}; + auto mx675 = mm->add_instruction(pooling675, mx647); + migraphx::op::convolution convolution676; + convolution676.padding = {0, 0}; + convolution676.stride = {1, 1}; + convolution676.dilation = {1, 1}; + convolution676.group = 1; + auto mx676 = mm->add_instruction(convolution676, mx675, mx188); + migraphx::op::batch_norm_inference batch_norm_inference677; + batch_norm_inference677.epsilon = 0.001; + batch_norm_inference677.momentum = 0.9; + auto mx677 = mm->add_instruction(batch_norm_inference677, mx676, mx187, mx186, mx185, mx184); + migraphx::op::relu relu678; + auto mx678 = mm->add_instruction(relu678, mx677); + migraphx::op::concat concat679; + concat679.axis = 1; + auto mx679 = mm->add_instruction(concat679, mx650, mx659, mx674, mx678); + migraphx::op::convolution convolution680; + convolution680.padding = {0, 0}; + convolution680.stride = {1, 1}; + convolution680.dilation = {1, 1}; + convolution680.group = 1; + auto mx680 = mm->add_instruction(convolution680, mx679, mx183); + migraphx::op::batch_norm_inference batch_norm_inference681; + batch_norm_inference681.epsilon = 0.001; + batch_norm_inference681.momentum = 0.9; + auto mx681 = mm->add_instruction(batch_norm_inference681, mx680, mx182, mx181, mx180, mx179); + migraphx::op::relu relu682; + auto mx682 = mm->add_instruction(relu682, mx681); + migraphx::op::convolution convolution683; + convolution683.padding = {0, 0}; + convolution683.stride = {1, 1}; + convolution683.dilation = {1, 1}; + convolution683.group = 1; + auto mx683 = mm->add_instruction(convolution683, mx679, mx178); + migraphx::op::batch_norm_inference batch_norm_inference684; + batch_norm_inference684.epsilon = 0.001; + batch_norm_inference684.momentum = 0.9; + auto mx684 = mm->add_instruction(batch_norm_inference684, mx683, mx177, mx176, mx175, mx174); + migraphx::op::relu relu685; + auto mx685 = mm->add_instruction(relu685, mx684); + migraphx::op::convolution convolution686; + convolution686.padding = {0, 3}; + convolution686.stride = {1, 1}; + convolution686.dilation = {1, 1}; + convolution686.group = 1; + auto mx686 = mm->add_instruction(convolution686, mx685, mx173); + migraphx::op::batch_norm_inference batch_norm_inference687; + batch_norm_inference687.epsilon = 0.001; + batch_norm_inference687.momentum = 0.9; + auto mx687 = mm->add_instruction(batch_norm_inference687, mx686, mx172, mx171, mx170, mx169); + migraphx::op::relu relu688; + auto mx688 = mm->add_instruction(relu688, mx687); + migraphx::op::convolution convolution689; + convolution689.padding = {3, 0}; + convolution689.stride = {1, 1}; + convolution689.dilation = {1, 1}; + convolution689.group = 1; + auto mx689 = mm->add_instruction(convolution689, mx688, mx168); + migraphx::op::batch_norm_inference batch_norm_inference690; + batch_norm_inference690.epsilon = 0.001; + batch_norm_inference690.momentum = 0.9; + auto mx690 = mm->add_instruction(batch_norm_inference690, mx689, mx167, mx166, mx165, mx164); + migraphx::op::relu relu691; + auto mx691 = mm->add_instruction(relu691, mx690); + migraphx::op::convolution convolution692; + convolution692.padding = {0, 0}; + convolution692.stride = {1, 1}; + convolution692.dilation = {1, 1}; + convolution692.group = 1; + auto mx692 = mm->add_instruction(convolution692, mx679, mx163); + migraphx::op::batch_norm_inference batch_norm_inference693; + batch_norm_inference693.epsilon = 0.001; + batch_norm_inference693.momentum = 0.9; + auto mx693 = mm->add_instruction(batch_norm_inference693, mx692, mx162, mx161, mx160, mx159); + migraphx::op::relu relu694; + auto mx694 = mm->add_instruction(relu694, mx693); + migraphx::op::convolution convolution695; + convolution695.padding = {3, 0}; + convolution695.stride = {1, 1}; + convolution695.dilation = {1, 1}; + convolution695.group = 1; + auto mx695 = mm->add_instruction(convolution695, mx694, mx158); + migraphx::op::batch_norm_inference batch_norm_inference696; + batch_norm_inference696.epsilon = 0.001; + batch_norm_inference696.momentum = 0.9; + auto mx696 = mm->add_instruction(batch_norm_inference696, mx695, mx157, mx156, mx155, mx154); + migraphx::op::relu relu697; + auto mx697 = mm->add_instruction(relu697, mx696); + migraphx::op::convolution convolution698; + convolution698.padding = {0, 3}; + convolution698.stride = {1, 1}; + convolution698.dilation = {1, 1}; + convolution698.group = 1; + auto mx698 = mm->add_instruction(convolution698, mx697, mx153); + migraphx::op::batch_norm_inference batch_norm_inference699; + batch_norm_inference699.epsilon = 0.001; + batch_norm_inference699.momentum = 0.9; + auto mx699 = mm->add_instruction(batch_norm_inference699, mx698, mx152, mx151, mx150, mx149); + migraphx::op::relu relu700; + auto mx700 = mm->add_instruction(relu700, mx699); + migraphx::op::convolution convolution701; + convolution701.padding = {3, 0}; + convolution701.stride = {1, 1}; + convolution701.dilation = {1, 1}; + convolution701.group = 1; + auto mx701 = mm->add_instruction(convolution701, mx700, mx148); + migraphx::op::batch_norm_inference batch_norm_inference702; + batch_norm_inference702.epsilon = 0.001; + batch_norm_inference702.momentum = 0.9; + auto mx702 = mm->add_instruction(batch_norm_inference702, mx701, mx147, mx146, mx145, mx144); + migraphx::op::relu relu703; + auto mx703 = mm->add_instruction(relu703, mx702); + migraphx::op::convolution convolution704; + convolution704.padding = {0, 3}; + convolution704.stride = {1, 1}; + convolution704.dilation = {1, 1}; + convolution704.group = 1; + auto mx704 = mm->add_instruction(convolution704, mx703, mx143); + migraphx::op::batch_norm_inference batch_norm_inference705; + batch_norm_inference705.epsilon = 0.001; + batch_norm_inference705.momentum = 0.9; + auto mx705 = mm->add_instruction(batch_norm_inference705, mx704, mx142, mx141, mx140, mx139); + migraphx::op::relu relu706; + auto mx706 = mm->add_instruction(relu706, mx705); + migraphx::op::pooling pooling707; + pooling707.mode = migraphx::op::pooling_mode::average; + pooling707.padding = {1, 1}; + pooling707.stride = {1, 1}; + pooling707.lengths = {3, 3}; + auto mx707 = mm->add_instruction(pooling707, mx679); + migraphx::op::convolution convolution708; + convolution708.padding = {0, 0}; + convolution708.stride = {1, 1}; + convolution708.dilation = {1, 1}; + convolution708.group = 1; + auto mx708 = mm->add_instruction(convolution708, mx707, mx138); + migraphx::op::batch_norm_inference batch_norm_inference709; + batch_norm_inference709.epsilon = 0.001; + batch_norm_inference709.momentum = 0.9; + auto mx709 = mm->add_instruction(batch_norm_inference709, mx708, mx137, mx136, mx135, mx134); + migraphx::op::relu relu710; + auto mx710 = mm->add_instruction(relu710, mx709); + migraphx::op::concat concat711; + concat711.axis = 1; + auto mx711 = mm->add_instruction(concat711, mx682, mx691, mx706, mx710); + migraphx::op::convolution convolution712; + convolution712.padding = {0, 0}; + convolution712.stride = {1, 1}; + convolution712.dilation = {1, 1}; + convolution712.group = 1; + auto mx712 = mm->add_instruction(convolution712, mx711, mx121); + migraphx::op::batch_norm_inference batch_norm_inference713; + batch_norm_inference713.epsilon = 0.001; + batch_norm_inference713.momentum = 0.9; + auto mx713 = mm->add_instruction(batch_norm_inference713, mx712, mx120, mx119, mx118, mx117); + migraphx::op::relu relu714; + auto mx714 = mm->add_instruction(relu714, mx713); + migraphx::op::convolution convolution715; + convolution715.padding = {0, 0}; + convolution715.stride = {2, 2}; + convolution715.dilation = {1, 1}; + convolution715.group = 1; + auto mx715 = mm->add_instruction(convolution715, mx714, mx116); + migraphx::op::batch_norm_inference batch_norm_inference716; + batch_norm_inference716.epsilon = 0.001; + batch_norm_inference716.momentum = 0.9; + auto mx716 = mm->add_instruction(batch_norm_inference716, mx715, mx115, mx114, mx113, mx112); + migraphx::op::relu relu717; + auto mx717 = mm->add_instruction(relu717, mx716); + migraphx::op::convolution convolution718; + convolution718.padding = {0, 0}; + convolution718.stride = {1, 1}; + convolution718.dilation = {1, 1}; + convolution718.group = 1; + auto mx718 = mm->add_instruction(convolution718, mx711, mx111); + migraphx::op::batch_norm_inference batch_norm_inference719; + batch_norm_inference719.epsilon = 0.001; + batch_norm_inference719.momentum = 0.9; + auto mx719 = mm->add_instruction(batch_norm_inference719, mx718, mx110, mx109, mx108, mx107); + migraphx::op::relu relu720; + auto mx720 = mm->add_instruction(relu720, mx719); + migraphx::op::convolution convolution721; + convolution721.padding = {0, 3}; + convolution721.stride = {1, 1}; + convolution721.dilation = {1, 1}; + convolution721.group = 1; + auto mx721 = mm->add_instruction(convolution721, mx720, mx106); + migraphx::op::batch_norm_inference batch_norm_inference722; + batch_norm_inference722.epsilon = 0.001; + batch_norm_inference722.momentum = 0.9; + auto mx722 = mm->add_instruction(batch_norm_inference722, mx721, mx105, mx104, mx103, mx102); + migraphx::op::relu relu723; + auto mx723 = mm->add_instruction(relu723, mx722); + migraphx::op::convolution convolution724; + convolution724.padding = {3, 0}; + convolution724.stride = {1, 1}; + convolution724.dilation = {1, 1}; + convolution724.group = 1; + auto mx724 = mm->add_instruction(convolution724, mx723, mx101); + migraphx::op::batch_norm_inference batch_norm_inference725; + batch_norm_inference725.epsilon = 0.001; + batch_norm_inference725.momentum = 0.9; + auto mx725 = mm->add_instruction(batch_norm_inference725, mx724, mx100, mx99, mx98, mx97); + migraphx::op::relu relu726; + auto mx726 = mm->add_instruction(relu726, mx725); + migraphx::op::convolution convolution727; + convolution727.padding = {0, 0}; + convolution727.stride = {2, 2}; + convolution727.dilation = {1, 1}; + convolution727.group = 1; + auto mx727 = mm->add_instruction(convolution727, mx726, mx96); + migraphx::op::batch_norm_inference batch_norm_inference728; + batch_norm_inference728.epsilon = 0.001; + batch_norm_inference728.momentum = 0.9; + auto mx728 = mm->add_instruction(batch_norm_inference728, mx727, mx95, mx94, mx93, mx92); + migraphx::op::relu relu729; + auto mx729 = mm->add_instruction(relu729, mx728); + migraphx::op::pooling pooling730; + pooling730.mode = migraphx::op::pooling_mode::max; + pooling730.padding = {0, 0}; + pooling730.stride = {2, 2}; + pooling730.lengths = {3, 3}; + auto mx730 = mm->add_instruction(pooling730, mx711); + migraphx::op::concat concat731; + concat731.axis = 1; + auto mx731 = mm->add_instruction(concat731, mx717, mx729, mx730); + migraphx::op::convolution convolution732; + convolution732.padding = {0, 0}; + convolution732.stride = {1, 1}; + convolution732.dilation = {1, 1}; + convolution732.group = 1; + auto mx732 = mm->add_instruction(convolution732, mx731, mx91); + migraphx::op::batch_norm_inference batch_norm_inference733; + batch_norm_inference733.epsilon = 0.001; + batch_norm_inference733.momentum = 0.9; + auto mx733 = mm->add_instruction(batch_norm_inference733, mx732, mx90, mx89, mx88, mx87); + migraphx::op::relu relu734; + auto mx734 = mm->add_instruction(relu734, mx733); + migraphx::op::convolution convolution735; + convolution735.padding = {0, 0}; + convolution735.stride = {1, 1}; + convolution735.dilation = {1, 1}; + convolution735.group = 1; + auto mx735 = mm->add_instruction(convolution735, mx731, mx86); + migraphx::op::batch_norm_inference batch_norm_inference736; + batch_norm_inference736.epsilon = 0.001; + batch_norm_inference736.momentum = 0.9; + auto mx736 = mm->add_instruction(batch_norm_inference736, mx735, mx85, mx84, mx83, mx82); + migraphx::op::relu relu737; + auto mx737 = mm->add_instruction(relu737, mx736); + migraphx::op::convolution convolution738; + convolution738.padding = {0, 1}; + convolution738.stride = {1, 1}; + convolution738.dilation = {1, 1}; + convolution738.group = 1; + auto mx738 = mm->add_instruction(convolution738, mx737, mx81); + migraphx::op::batch_norm_inference batch_norm_inference739; + batch_norm_inference739.epsilon = 0.001; + batch_norm_inference739.momentum = 0.9; + auto mx739 = mm->add_instruction(batch_norm_inference739, mx738, mx80, mx79, mx78, mx77); + migraphx::op::relu relu740; + auto mx740 = mm->add_instruction(relu740, mx739); + migraphx::op::convolution convolution741; + convolution741.padding = {1, 0}; + convolution741.stride = {1, 1}; + convolution741.dilation = {1, 1}; + convolution741.group = 1; + auto mx741 = mm->add_instruction(convolution741, mx737, mx76); + migraphx::op::batch_norm_inference batch_norm_inference742; + batch_norm_inference742.epsilon = 0.001; + batch_norm_inference742.momentum = 0.9; + auto mx742 = mm->add_instruction(batch_norm_inference742, mx741, mx75, mx74, mx73, mx72); + migraphx::op::relu relu743; + auto mx743 = mm->add_instruction(relu743, mx742); + migraphx::op::concat concat744; + concat744.axis = 1; + auto mx744 = mm->add_instruction(concat744, mx740, mx743); + migraphx::op::convolution convolution745; + convolution745.padding = {0, 0}; + convolution745.stride = {1, 1}; + convolution745.dilation = {1, 1}; + convolution745.group = 1; + auto mx745 = mm->add_instruction(convolution745, mx731, mx71); + migraphx::op::batch_norm_inference batch_norm_inference746; + batch_norm_inference746.epsilon = 0.001; + batch_norm_inference746.momentum = 0.9; + auto mx746 = mm->add_instruction(batch_norm_inference746, mx745, mx70, mx69, mx68, mx67); + migraphx::op::relu relu747; + auto mx747 = mm->add_instruction(relu747, mx746); + migraphx::op::convolution convolution748; + convolution748.padding = {1, 1}; + convolution748.stride = {1, 1}; + convolution748.dilation = {1, 1}; + convolution748.group = 1; + auto mx748 = mm->add_instruction(convolution748, mx747, mx66); + migraphx::op::batch_norm_inference batch_norm_inference749; + batch_norm_inference749.epsilon = 0.001; + batch_norm_inference749.momentum = 0.9; + auto mx749 = mm->add_instruction(batch_norm_inference749, mx748, mx65, mx64, mx63, mx62); + migraphx::op::relu relu750; + auto mx750 = mm->add_instruction(relu750, mx749); + migraphx::op::convolution convolution751; + convolution751.padding = {0, 1}; + convolution751.stride = {1, 1}; + convolution751.dilation = {1, 1}; + convolution751.group = 1; + auto mx751 = mm->add_instruction(convolution751, mx750, mx61); + migraphx::op::batch_norm_inference batch_norm_inference752; + batch_norm_inference752.epsilon = 0.001; + batch_norm_inference752.momentum = 0.9; + auto mx752 = mm->add_instruction(batch_norm_inference752, mx751, mx60, mx59, mx58, mx57); + migraphx::op::relu relu753; + auto mx753 = mm->add_instruction(relu753, mx752); + migraphx::op::convolution convolution754; + convolution754.padding = {1, 0}; + convolution754.stride = {1, 1}; + convolution754.dilation = {1, 1}; + convolution754.group = 1; + auto mx754 = mm->add_instruction(convolution754, mx750, mx56); + migraphx::op::batch_norm_inference batch_norm_inference755; + batch_norm_inference755.epsilon = 0.001; + batch_norm_inference755.momentum = 0.9; + auto mx755 = mm->add_instruction(batch_norm_inference755, mx754, mx55, mx54, mx53, mx52); + migraphx::op::relu relu756; + auto mx756 = mm->add_instruction(relu756, mx755); + migraphx::op::concat concat757; + concat757.axis = 1; + auto mx757 = mm->add_instruction(concat757, mx753, mx756); + migraphx::op::pooling pooling758; + pooling758.mode = migraphx::op::pooling_mode::average; + pooling758.padding = {1, 1}; + pooling758.stride = {1, 1}; + pooling758.lengths = {3, 3}; + auto mx758 = mm->add_instruction(pooling758, mx731); + migraphx::op::convolution convolution759; + convolution759.padding = {0, 0}; + convolution759.stride = {1, 1}; + convolution759.dilation = {1, 1}; + convolution759.group = 1; + auto mx759 = mm->add_instruction(convolution759, mx758, mx51); + migraphx::op::batch_norm_inference batch_norm_inference760; + batch_norm_inference760.epsilon = 0.001; + batch_norm_inference760.momentum = 0.9; + auto mx760 = mm->add_instruction(batch_norm_inference760, mx759, mx50, mx49, mx48, mx47); + migraphx::op::relu relu761; + auto mx761 = mm->add_instruction(relu761, mx760); + migraphx::op::concat concat762; + concat762.axis = 1; + auto mx762 = mm->add_instruction(concat762, mx734, mx744, mx757, mx761); + migraphx::op::convolution convolution763; + convolution763.padding = {0, 0}; + convolution763.stride = {1, 1}; + convolution763.dilation = {1, 1}; + convolution763.group = 1; + auto mx763 = mm->add_instruction(convolution763, mx762, mx46); + migraphx::op::batch_norm_inference batch_norm_inference764; + batch_norm_inference764.epsilon = 0.001; + batch_norm_inference764.momentum = 0.9; + auto mx764 = mm->add_instruction(batch_norm_inference764, mx763, mx45, mx44, mx43, mx42); + migraphx::op::relu relu765; + auto mx765 = mm->add_instruction(relu765, mx764); + migraphx::op::convolution convolution766; + convolution766.padding = {0, 0}; + convolution766.stride = {1, 1}; + convolution766.dilation = {1, 1}; + convolution766.group = 1; + auto mx766 = mm->add_instruction(convolution766, mx762, mx41); + migraphx::op::batch_norm_inference batch_norm_inference767; + batch_norm_inference767.epsilon = 0.001; + batch_norm_inference767.momentum = 0.9; + auto mx767 = mm->add_instruction(batch_norm_inference767, mx766, mx40, mx39, mx38, mx37); + migraphx::op::relu relu768; + auto mx768 = mm->add_instruction(relu768, mx767); + migraphx::op::convolution convolution769; + convolution769.padding = {0, 1}; + convolution769.stride = {1, 1}; + convolution769.dilation = {1, 1}; + convolution769.group = 1; + auto mx769 = mm->add_instruction(convolution769, mx768, mx36); + migraphx::op::batch_norm_inference batch_norm_inference770; + batch_norm_inference770.epsilon = 0.001; + batch_norm_inference770.momentum = 0.9; + auto mx770 = mm->add_instruction(batch_norm_inference770, mx769, mx35, mx34, mx33, mx32); + migraphx::op::relu relu771; + auto mx771 = mm->add_instruction(relu771, mx770); + migraphx::op::convolution convolution772; + convolution772.padding = {1, 0}; + convolution772.stride = {1, 1}; + convolution772.dilation = {1, 1}; + convolution772.group = 1; + auto mx772 = mm->add_instruction(convolution772, mx768, mx31); + migraphx::op::batch_norm_inference batch_norm_inference773; + batch_norm_inference773.epsilon = 0.001; + batch_norm_inference773.momentum = 0.9; + auto mx773 = mm->add_instruction(batch_norm_inference773, mx772, mx30, mx29, mx28, mx27); + migraphx::op::relu relu774; + auto mx774 = mm->add_instruction(relu774, mx773); + migraphx::op::concat concat775; + concat775.axis = 1; + auto mx775 = mm->add_instruction(concat775, mx771, mx774); + migraphx::op::convolution convolution776; + convolution776.padding = {0, 0}; + convolution776.stride = {1, 1}; + convolution776.dilation = {1, 1}; + convolution776.group = 1; + auto mx776 = mm->add_instruction(convolution776, mx762, mx26); + migraphx::op::batch_norm_inference batch_norm_inference777; + batch_norm_inference777.epsilon = 0.001; + batch_norm_inference777.momentum = 0.9; + auto mx777 = mm->add_instruction(batch_norm_inference777, mx776, mx25, mx24, mx23, mx22); + migraphx::op::relu relu778; + auto mx778 = mm->add_instruction(relu778, mx777); + migraphx::op::convolution convolution779; + convolution779.padding = {1, 1}; + convolution779.stride = {1, 1}; + convolution779.dilation = {1, 1}; + convolution779.group = 1; + auto mx779 = mm->add_instruction(convolution779, mx778, mx21); + migraphx::op::batch_norm_inference batch_norm_inference780; + batch_norm_inference780.epsilon = 0.001; + batch_norm_inference780.momentum = 0.9; + auto mx780 = mm->add_instruction(batch_norm_inference780, mx779, mx20, mx19, mx18, mx17); + migraphx::op::relu relu781; + auto mx781 = mm->add_instruction(relu781, mx780); + migraphx::op::convolution convolution782; + convolution782.padding = {0, 1}; + convolution782.stride = {1, 1}; + convolution782.dilation = {1, 1}; + convolution782.group = 1; + auto mx782 = mm->add_instruction(convolution782, mx781, mx16); + migraphx::op::batch_norm_inference batch_norm_inference783; + batch_norm_inference783.epsilon = 0.001; + batch_norm_inference783.momentum = 0.9; + auto mx783 = mm->add_instruction(batch_norm_inference783, mx782, mx15, mx14, mx13, mx12); + migraphx::op::relu relu784; + auto mx784 = mm->add_instruction(relu784, mx783); + migraphx::op::convolution convolution785; + convolution785.padding = {1, 0}; + convolution785.stride = {1, 1}; + convolution785.dilation = {1, 1}; + convolution785.group = 1; + auto mx785 = mm->add_instruction(convolution785, mx781, mx11); + migraphx::op::batch_norm_inference batch_norm_inference786; + batch_norm_inference786.epsilon = 0.001; + batch_norm_inference786.momentum = 0.9; + auto mx786 = mm->add_instruction(batch_norm_inference786, mx785, mx10, mx9, mx8, mx7); + migraphx::op::relu relu787; + auto mx787 = mm->add_instruction(relu787, mx786); + migraphx::op::concat concat788; + concat788.axis = 1; + auto mx788 = mm->add_instruction(concat788, mx784, mx787); + migraphx::op::pooling pooling789; + pooling789.mode = migraphx::op::pooling_mode::average; + pooling789.padding = {1, 1}; + pooling789.stride = {1, 1}; + pooling789.lengths = {3, 3}; + auto mx789 = mm->add_instruction(pooling789, mx762); + migraphx::op::convolution convolution790; + convolution790.padding = {0, 0}; + convolution790.stride = {1, 1}; + convolution790.dilation = {1, 1}; + convolution790.group = 1; + auto mx790 = mm->add_instruction(convolution790, mx789, mx6); + migraphx::op::batch_norm_inference batch_norm_inference791; + batch_norm_inference791.epsilon = 0.001; + batch_norm_inference791.momentum = 0.9; + auto mx791 = mm->add_instruction(batch_norm_inference791, mx790, mx5, mx4, mx3, mx2); + migraphx::op::relu relu792; + auto mx792 = mm->add_instruction(relu792, mx791); + migraphx::op::concat concat793; + concat793.axis = 1; + auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792); + migraphx::op::pooling pooling794; + pooling794.mode = migraphx::op::pooling_mode::average; + pooling794.padding = {0, 0}; + pooling794.stride = {8, 8}; + pooling794.lengths = {8, 8}; + auto mx794 = mm->add_instruction(pooling794, mx793); + migraphx::op::identity identity795; + auto mx795 = mm->add_instruction(identity795, mx794); + migraphx::op::flatten flatten796; + flatten796.axis = 1; + auto mx796 = mm->add_instruction(flatten796, mx795); + migraphx::op::transpose transpose797; + transpose797.dims = {1, 0}; + auto mx797 = mm->add_instruction(transpose797, mx1); + migraphx::op::multibroadcast multibroadcast798; + multibroadcast798.output_lens = {batch, 1000}; + auto mx798 = mm->add_instruction(multibroadcast798, mx0); + float dot799_alpha = 1; + float dot799_beta = 1; + migraphx::add_apply_alpha_beta( + *mm, {mx796, mx797, mx798}, migraphx::make_op("dot"), dot799_alpha, dot799_beta); + + return p; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 67b178f1cb2a5df28b43aabeef6e93a23b1a5148..325cffac0d0542d150ebe465bba8edbff6a76fc8 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -1,11 +1,17 @@ +#include "verify.hpp" #include "argument_parser.hpp" #include "command.hpp" -#include "verify.hpp" +#include "precision.hpp" #include "perf.hpp" +#include "models.hpp" +#include "marker_roctx.hpp" #include #include #include +#include +#include +#include #include #include @@ -14,9 +20,11 @@ #include #include #include +#include #include #include #include +#include #include @@ -26,45 +34,162 @@ inline namespace MIGRAPHX_INLINE_NS { struct loader { + std::string model; std::string file; std::string file_type; - bool is_nhwc = true; - unsigned trim = 0; - bool optimize = false; + unsigned batch = 1; + bool is_nhwc = true; + unsigned trim = 0; + bool optimize = false; + bool skip_unknown_operators = false; + bool brief = false; + std::string output_type; + std::string output; + std::vector param_dims; + std::vector output_names; void parse(argument_parser& ap) { ap(file, {}, ap.metavar("")); + ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); + ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx")); + ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json")); + ap(batch, {"--batch"}, ap.help("Set batch size for model")); ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true)); + ap(skip_unknown_operators, + {"--skip-unknown-operators"}, + ap.help("Skip unknown operators when parsing and continue to parse."), + ap.set_value(true)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); + ap(param_dims, + {"--input-dim"}, + ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"), + ap.append(), + ap.nargs(2)); + + ap(output_names, + {"--output-names"}, + ap.help("Names of node output (format: \"name_1 name_2 name_n\")"), + ap.append(), + ap.nargs(2)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); + ap(output_type, + {"--graphviz", "-g"}, + ap.help("Print out a graphviz representation."), + ap.set_value("graphviz")); + ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true)); + ap(output_type, + {"--cpp"}, + ap.help("Print out the program as cpp program."), + ap.set_value("cpp")); + ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json")); + ap(output_type, + {"--text"}, + ap.help("Print out program in text format."), + ap.set_value("text")); + ap(output_type, + {"--binary"}, + ap.help("Print out program in binary format."), + ap.set_value("binary")); + ap(output, {"--output", "-o"}, ap.help("Output to file.")); + } + + static auto parse_param_dims(const std::vector& param_dims_info) + { + std::unordered_map> map_input_dims; + std::string name = ""; + for(auto&& x : param_dims_info) + { + if(x[0] == '@') + { + name = x.substr(1); + } + else + { + map_input_dims[name].push_back(value_parser::apply(x)); + } + } + + return map_input_dims; + } + + static auto parse_output_names(const std::vector& output_names_info) + { + std::vector output_node_names; + std::transform(output_names_info.begin(), + output_names_info.end(), + std::back_inserter(output_node_names), + [&](auto x) { return value_parser::apply(x); }); + + return output_node_names; } program load() { program p; - if(file_type.empty()) + if(model.empty()) { - if(ends_with(file, ".onnx")) - file_type = "onnx"; - else if(ends_with(file, ".pb")) - file_type = "tf"; + auto map_input_dims = parse_param_dims(param_dims); + auto output_node_names = parse_output_names(output_names); + if(file_type.empty()) + { + if(ends_with(file, ".onnx")) + file_type = "onnx"; + else if(ends_with(file, ".pb")) + file_type = "tf"; + else if(ends_with(file, ".json")) + file_type = "json"; + else + file_type = "migraphx"; + } + std::cout << "Reading: " << file << std::endl; + if(file_type == "onnx") + { + onnx_options options; + options.default_dim_value = batch; + options.skip_unknown_operators = skip_unknown_operators; + options.print_program_on_error = true; + options.map_input_dims = map_input_dims; + p = parse_onnx(file, options); + } + else if(file_type == "tf") + { + p = parse_tf(file, tf_options{is_nhwc, batch, map_input_dims, output_node_names}); + } + else if(file_type == "json") + { + file_options options; + options.format = "json"; + p = migraphx::load(file, options); + } + else if(file_type == "migraphx") + { + p = migraphx::load(file); + } + } + else + { + if(model == "resnet50") + p = resnet50(batch); + else if(model == "inceptionv3") + p = inceptionv3(batch); + else if(model == "alexnet") + p = alexnet(batch); + else + MIGRAPHX_THROW("Unknown model: " + model); } - std::cout << "Reading: " << file << std::endl; - if(file_type == "onnx") - p = parse_onnx(file); - else if(file_type == "tf") - p = parse_tf(file, is_nhwc); if(trim > 0) { - auto last = std::prev(p.end(), trim); - p.remove_instructions(last, p.end()); + auto* mm = p.get_main_module(); + auto last = std::prev(mm->end(), trim); + mm->remove_instructions(last, mm->end()); } if(optimize) - migraphx::run_passes(p, + { + migraphx::run_passes(*p.get_main_module(), { migraphx::rewrite_batchnorm{}, migraphx::eliminate_identity{}, @@ -78,59 +203,140 @@ struct loader migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}, }); + } return p; } + + static void write(std::ostream& os, const std::vector& buffer) + { + os.write(buffer.data(), buffer.size()); + } + + void save(const program& p) const + { + auto* os = &std::cout; + std::ofstream fs; + if(not output.empty()) + { + fs.open(output); + os = &fs; + } + + std::string type = output_type; + if(type.empty()) + { + if(output.empty()) + type = "text"; + else + type = "binary"; + } + + if(type == "cpp") + p.print_cpp(*os); + else if(type == "graphviz") + p.print_graph(*os, brief); + else if(type == "text") + *os << p << std::endl; + else if(type == "json") + *os << to_json_string(p.to_value()) << std::endl; + else if(type == "binary") + write(*os, save_buffer(p)); + } +}; + +struct program_params +{ + std::vector fill0{}; + std::vector fill1{}; + void parse(argument_parser& ap) + { + ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2)); + ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2)); + } + + auto generate(const program& p, const target& t, bool offload) + { + parameter_map m; + for(auto&& s : fill0) + m[s] = fill_argument(p.get_parameter_shape(s), 0); + for(auto&& s : fill1) + m[s] = fill_argument(p.get_parameter_shape(s), 1); + fill_param_map(m, p, t, offload); + return m; + } +}; + +struct compiler_target +{ +#ifdef HAVE_GPU + std::string target_name = "gpu"; +#else + std::string target_name = "cpu"; +#endif + + void parse(argument_parser& ap) + { + ap(target_name, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value("gpu")); + ap(target_name, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value("cpu")); + ap(target_name, + {"--ref"}, + ap.help("Compile on the reference implementation"), + ap.set_value("ref")); + } + + target get_target() const { return make_target(target_name); } }; struct compiler { - static const int q_fp16 = 1; - static const int q_int8 = 2; loader l; - bool gpu = true; - bool offload_copy = false; - int quantize = 0; + program_params parameters; + compiler_target ct; + bool offload_copy = false; + bool fast_math = true; + precision quantize = precision::fp32; + std::vector fill0; std::vector fill1; void parse(argument_parser& ap) { l.parse(ap); - ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true)); - ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false)); + parameters.parse(ap); + ct.parse(ap); ap(offload_copy, {"--enable-offload-copy"}, ap.help("Enable implicit offload copying"), + ap.set_value(true)); + ap(fast_math, + {"--disable-fast-math"}, + ap.help("Disable fast math optimization"), ap.set_value(false)); - ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16)); - ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8)); - ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append()); + ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16)); + ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8)); } - auto params(const program& p, bool use_gpu = true) - { - bool gpu_flag = use_gpu && gpu && !offload_copy; - program::parameter_map m; - for(auto&& s : fill1) - m[s] = fill_argument(p.get_parameter_shape(s), 1); - fill_param_map(m, p, gpu_flag); - return m; - } + auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); } program compile() { auto p = l.load(); - auto t = get_target(gpu); - if(quantize == q_fp16) + // Dont compile if its already been compiled + if(p.is_compiled()) + return p; + auto t = ct.get_target(); + if(quantize == precision::fp16) { quantize_fp16(p); } - else if(quantize == q_int8) + else if(quantize == precision::int8) { - quantize_int8(p, t, {params(p, false)}); + quantize_int8(p, t, {params(p)}); } compile_options options; options.offload_copy = offload_copy; + options.fast_math = fast_math; p.compile(t, options); + l.save(p); return p; } }; @@ -138,36 +344,12 @@ struct compiler struct read : command { loader l; - bool graphviz = false; - bool brief = false; - std::string output; - void parse(argument_parser& ap) - { - l.parse(ap); - ap(graphviz, - {"--graphviz", "-g"}, - ap.help("Print out a graphviz representation."), - ap.set_value(true)); - ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true)); - ap(output, {"--output", "-o"}, ap.help("Output to file.")); - } + void parse(argument_parser& ap) { l.parse(ap); } void run() { auto p = l.load(); - - auto* os = &std::cout; - std::ofstream fs; - if(not output.empty()) - { - fs.open(output); - os = &fs; - } - - if(graphviz) - p.print_graph(*os, brief); - else - *os << p << std::endl; + l.save(p); } }; @@ -187,40 +369,73 @@ struct params : command struct verify : command { loader l; + program_params parameters; + compiler_target ct; double tolerance = 80; bool per_instruction = false; bool reduce = false; + bool offload_copy = false; + bool fast_math = true; + precision quantize = precision::fp32; void parse(argument_parser& ap) { l.parse(ap); + parameters.parse(ap); + ct.parse(ap); + ap(offload_copy, + {"--enable-offload-copy"}, + ap.help("Enable implicit offload copying"), + ap.set_value(true)); + ap(fast_math, + {"--disable-fast-math"}, + ap.help("Disable fast math optimization"), + ap.set_value(false)); ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(per_instruction, {"-i", "--per-instruction"}, ap.help("Verify each instruction"), ap.set_value(true)); ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true)); + ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16)); } void run() { auto p = l.load(); + l.save(p); std::cout << p << std::endl; + compile_options options; + options.offload_copy = offload_copy; + options.fast_math = fast_math; + auto t = ct.get_target(); + auto m = parameters.generate(p, t, true); + if(per_instruction) { - verify_instructions(p, tolerance); + verify_instructions(p, t, options, quantize, tolerance); } else if(reduce) { - verify_reduced_program(p, tolerance); + verify_reduced_program(p, t, options, quantize, m, tolerance); } else { - verify_program(l.file, p, tolerance); + verify_program(l.file, p, t, options, quantize, m, tolerance); } } }; +struct version : command +{ + void parse(const argument_parser&) {} + void run() const + { + std::cout << "MIGraphX Version: " << MIGRAPHX_VERSION_MAJOR << "." << MIGRAPHX_VERSION_MINOR + << std::endl; + } +}; + struct compile : command { compiler c; @@ -229,8 +444,7 @@ struct compile : command void run() { std::cout << "Compiling ... " << std::endl; - auto p = c.compile(); - std::cout << p << std::endl; + c.compile(); } }; @@ -267,7 +481,72 @@ struct perf : command std::cout << "Allocating params ... " << std::endl; auto m = c.params(p); std::cout << "Running performance report ... " << std::endl; - p.perf_report(std::cout, n, m); + p.perf_report(std::cout, n, m, c.l.batch); + } +}; + +struct roctx : command +{ + compiler c; + void parse(argument_parser& ap) { c.parse(ap); } + + void run() + { + std::cout << "Compiling ... " << std::endl; + auto p = c.compile(); + std::cout << "Allocating params ... " << std::endl; + auto m = c.params(p); + std::cout << "rocTX:\tLoading rocTX library..." << std::endl; + auto rtx = create_marker_roctx(); + p.mark(m, std::move(rtx)); + } +}; + +struct op : command +{ + bool show_ops = false; + std::string op_name{}; + void parse(argument_parser& ap) + { + ap(op_name, {}, ap.metavar("")); + ap(show_ops, + {"--list", "-l"}, + ap.help("List all the operators of MIGraphX"), + ap.set_value(true)); + } + void run() const + { + if(show_ops) + { + for(const auto& name : get_operators()) + std::cout << name << std::endl; + } + else + { + auto op = load_op(op_name); + std::cout << op_name << ": " << std::endl; + std::cout << to_pretty_json_string(op.to_value()) << std::endl; + } + } +}; + +struct onnx : command +{ + bool show_ops = false; + void parse(argument_parser& ap) + { + ap(show_ops, + {"--list", "-l"}, + ap.help("List all onnx operators supported by MIGraphX"), + ap.set_value(true)); + } + void run() const + { + if(show_ops) + { + for(const auto& name : get_onnx_operators()) + std::cout << name << std::endl; + } } }; @@ -283,7 +562,13 @@ struct main_command } void parse(argument_parser& ap) { + std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + + "." + std::to_string(MIGRAPHX_VERSION_MINOR); ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help())); + ap(nullptr, + {"-v", "--version"}, + ap.help("Show MIGraphX version"), + ap.show_help(version_str)); } void run() {} @@ -297,8 +582,13 @@ using namespace migraphx::driver; // NOLINT int main(int argc, const char* argv[]) { std::vector args(argv + 1, argv + argc); + + // no argument, print the help infomration by default if(args.empty()) - return 0; + { + args.push_back("-h"); + } + auto&& m = get_commands(); auto cmd = args.front(); if(m.count(cmd) > 0) @@ -309,5 +599,6 @@ int main(int argc, const char* argv[]) { run_command(args); } + return 0; } diff --git a/src/driver/marker_roctx.cpp b/src/driver/marker_roctx.cpp new file mode 100644 index 0000000000000000000000000000000000000000..549e9be7f0028ba5086b94435495a13eb39b1406 --- /dev/null +++ b/src/driver/marker_roctx.cpp @@ -0,0 +1,49 @@ +#include "marker_roctx.hpp" + +#include +#include +#include + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +class marker_roctx +{ + std::function sym_roctx_mark; + std::function sym_roctx_range_start; + std::function sym_roctx_range_stop; + + std::function sym_roctx_range_push; + std::function sym_roctx_range_pop; + + uint64_t range_id = 0; + + public: + marker_roctx() + { + dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"}; + sym_roctx_mark = lib.get_function("roctxMarkA"); + sym_roctx_range_start = lib.get_function("roctxRangeStartA"); + sym_roctx_range_stop = lib.get_function("roctxRangeStop"); + + sym_roctx_range_push = lib.get_function("roctxRangePushA"); + sym_roctx_range_pop = lib.get_function("roctxRangePop"); + + sym_roctx_mark("rocTX marker created."); + } + + void mark_start(instruction_ref ins_ref) + { + std::string text = "Marker start: " + ins_ref->name(); + sym_roctx_range_push(text.c_str()); + } + void mark_stop(instruction_ref) { sym_roctx_range_pop(); } + void mark_start(const program&) { range_id = sym_roctx_range_start("0"); } + void mark_stop(const program&) { sym_roctx_range_stop(range_id); } +}; + +marker create_marker_roctx() { return marker_roctx(); } +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/marker_roctx.hpp b/src/driver/marker_roctx.hpp new file mode 100755 index 0000000000000000000000000000000000000000..64b4d672a96c900ac1597ea47fd3557b413b4e83 --- /dev/null +++ b/src/driver/marker_roctx.hpp @@ -0,0 +1,16 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP +#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP + +#include + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +marker create_marker_roctx(); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx + +#endif diff --git a/src/driver/models.hpp b/src/driver/models.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43d9d0643bc837c826388a63b851d8b7cbc21a10 --- /dev/null +++ b/src/driver/models.hpp @@ -0,0 +1,14 @@ + +#include + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +migraphx::program resnet50(unsigned batch); +migraphx::program inceptionv3(unsigned batch); +migraphx::program alexnet(unsigned batch); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/perf.cpp b/src/driver/perf.cpp index 958e48198da5a106727ffc385deaa49d5dfe48f2..a1ba98622f698cd1d57658981c051b146ef29e8e 100644 --- a/src/driver/perf.cpp +++ b/src/driver/perf.cpp @@ -1,9 +1,8 @@ #include "perf.hpp" -#include #include +#include #ifdef HAVE_GPU -#include #include #endif @@ -11,13 +10,32 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { -program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu) +template +auto get_hash(const T& x) +{ + return std::hash{}(x); +} + +parameter_map fill_param_map(parameter_map& m, const program& p, const target& t, bool offload) +{ + for(auto&& x : p.get_parameter_shapes()) + { + argument& arg = m[x.first]; + if(arg.empty()) + arg = generate_argument(x.second, get_hash(x.first)); + if(not offload) + arg = t.copy_to(arg); + } + return m; +} + +parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu) { for(auto&& x : p.get_parameter_shapes()) { argument& arg = m[x.first]; if(arg.empty()) - arg = generate_argument(x.second); + arg = generate_argument(x.second, get_hash(x.first)); #ifdef HAVE_GPU if(gpu) arg = gpu::to_gpu(arg); @@ -28,55 +46,47 @@ program::parameter_map fill_param_map(program::parameter_map& m, const program& return m; } -program::parameter_map create_param_map(const program& p, bool gpu) +parameter_map create_param_map(const program& p, const target& t, bool offload) { - program::parameter_map m; + parameter_map m; for(auto&& x : p.get_parameter_shapes()) { -#ifdef HAVE_GPU - if(gpu) - m[x.first] = gpu::to_gpu(generate_argument(x.second)); + auto arg = generate_argument(x.second, get_hash(x.first)); + if(offload) + m[x.first] = arg; else -#else - (void)gpu; -#endif - m[x.first] = generate_argument(x.second); + m[x.first] = t.copy_to(arg); } return m; } -target get_target(bool gpu) +parameter_map create_param_map(const program& p, bool gpu) { - if(gpu) + parameter_map m; + for(auto&& x : p.get_parameter_shapes()) { #ifdef HAVE_GPU - return gpu::target{}; + if(gpu) + m[x.first] = gpu::to_gpu(generate_argument(x.second, get_hash(x.first))); + else #else - MIGRAPHX_THROW("Gpu not supported."); + (void)gpu; #endif + m[x.first] = generate_argument(x.second, get_hash(x.first)); } - else - { - return cpu::target{}; - } + return m; } -void compile_program(program& p, bool gpu) +target get_target(bool gpu) { if(gpu) - { -#ifdef HAVE_GPU - p.compile(gpu::target{}); -#else - MIGRAPHX_THROW("Gpu not supported."); -#endif - } + return make_target("gpu"); else - { - p.compile(cpu::target{}); - } + return make_target("cpu"); } -} // namespace MIGRAPHX_INLINE_NS +void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); } + +} // namespace MIGRAPHX_INLINE_NS } // namespace driver } // namespace migraphx diff --git a/src/driver/perf.hpp b/src/driver/perf.hpp index ffff7def652acec4970e7a6c736f2f63fb55946d..9e96596a9fbee4d0186b26c7ab5c188e84bc73f6 100644 --- a/src/driver/perf.hpp +++ b/src/driver/perf.hpp @@ -7,8 +7,12 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { -program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu); -program::parameter_map create_param_map(const program& p, bool gpu = true); +parameter_map +fill_param_map(parameter_map& m, const program& p, const target& t, bool offload = false); +parameter_map create_param_map(const program& p, const target& t, bool offload = false); + +parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); +parameter_map create_param_map(const program& p, bool gpu = true); target get_target(bool gpu); void compile_program(program& p, bool gpu = true); diff --git a/src/driver/precision.hpp b/src/driver/precision.hpp new file mode 100644 index 0000000000000000000000000000000000000000..95f637f0700bd8308cb2fd7882d7e7c68f4d5989 --- /dev/null +++ b/src/driver/precision.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP +#define MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +enum class precision +{ + fp32, + fp16, + int8 +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx + +#endif diff --git a/src/driver/resnet50.cpp b/src/driver/resnet50.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c25b2c831e713a6ac5fa7e6a476ef063a49912bb --- /dev/null +++ b/src/driver/resnet50.cpp @@ -0,0 +1,1241 @@ +#include +#include +#include +#include +#include "models.hpp" + +namespace migraphx { +namespace driver { +inline namespace MIGRAPHX_INLINE_NS { + +migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto m0 = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); + auto mx0 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); + auto mx1 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 1)); + auto mx2 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 2))); + auto mx3 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 3)); + auto mx4 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 4)); + auto mx5 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 5))); + auto mx6 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {2048, 512, 1, 1}}, 6)); + auto mx7 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 7))); + auto mx8 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 8)); + auto mx9 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 9)); + auto mx10 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 10))); + auto mx11 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 512, 3, 3}}, 11)); + auto mx12 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 12))); + auto mx13 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 13)); + auto mx14 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 14)); + auto mx15 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 15))); + auto mx16 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 2048, 1, 1}}, 16)); + auto mx17 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 17))); + auto mx18 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 18)); + auto mx19 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 19)); + auto mx20 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 20))); + auto mx21 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {2048, 512, 1, 1}}, 21)); + auto mx22 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 22))); + auto mx23 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 23)); + auto mx24 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 24)); + auto mx25 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 25))); + auto mx26 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 512, 3, 3}}, 26)); + auto mx27 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 27))); + auto mx28 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 28)); + auto mx29 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 29)); + auto mx30 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 30))); + auto mx31 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 2048, 1, 1}}, 31)); + auto mx32 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 32))); + auto mx33 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 33)); + auto mx34 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 34)); + auto mx35 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 35))); + auto mx36 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {2048, 1024, 1, 1}}, 36)); + auto mx37 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 37))); + auto mx38 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 38)); + auto mx39 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 39)); + auto mx40 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {2048}}, 40))); + auto mx41 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {2048, 512, 1, 1}}, 41)); + auto mx42 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 42))); + auto mx43 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 43)); + auto mx44 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 44)); + auto mx45 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 45))); + auto mx46 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 512, 3, 3}}, 46)); + auto mx47 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 47))); + auto mx48 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 48)); + auto mx49 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 49)); + auto mx50 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 50))); + auto mx51 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 1024, 1, 1}}, 51)); + auto mx52 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 52))); + auto mx53 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 53)); + auto mx54 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 54)); + auto mx55 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 55)); + auto mx56 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 56)); + auto mx57 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 57))); + auto mx58 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 58)); + auto mx59 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 59)); + auto mx60 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 60))); + auto mx61 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 61)); + auto mx62 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 62))); + auto mx63 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 63)); + auto mx64 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 64)); + auto mx65 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 65))); + auto mx66 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 1024, 1, 1}}, 66)); + auto mx67 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 67))); + auto mx68 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 68)); + auto mx69 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 69)); + auto mx70 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 70)); + auto mx71 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 71)); + auto mx72 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 72))); + auto mx73 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 73)); + auto mx74 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 74)); + auto mx75 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 75))); + auto mx76 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 76)); + auto mx77 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 77))); + auto mx78 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 78)); + auto mx79 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 79)); + auto mx80 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 80))); + auto mx81 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 1024, 1, 1}}, 81)); + auto mx82 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 82))); + auto mx83 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 83)); + auto mx84 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 84)); + auto mx85 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 85)); + auto mx86 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 86)); + auto mx87 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 87))); + auto mx88 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 88)); + auto mx89 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 89)); + auto mx90 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 90))); + auto mx91 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 91)); + auto mx92 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 92))); + auto mx93 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 93)); + auto mx94 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 94)); + auto mx95 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 95))); + auto mx96 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 1024, 1, 1}}, 96)); + auto mx97 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 97))); + auto mx98 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 98)); + auto mx99 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 99)); + auto mx100 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 100)); + auto mx101 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 101)); + auto mx102 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 102))); + auto mx103 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 103)); + auto mx104 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 104)); + auto mx105 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 105))); + auto mx106 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 106)); + auto mx107 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 107))); + auto mx108 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 108)); + auto mx109 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 109)); + auto mx110 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 110))); + auto mx111 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 1024, 1, 1}}, 111)); + auto mx112 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 112))); + auto mx113 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 113)); + auto mx114 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 114)); + auto mx115 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 115)); + auto mx116 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 116)); + auto mx117 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 117))); + auto mx118 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 118)); + auto mx119 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 119)); + auto mx120 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 120))); + auto mx121 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 121)); + auto mx122 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 122))); + auto mx123 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 123)); + auto mx124 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 124)); + auto mx125 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 125))); + auto mx126 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 1024, 1, 1}}, 126)); + auto mx127 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 127))); + auto mx128 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 128)); + auto mx129 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 129)); + auto mx130 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 130)); + auto mx131 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 512, 1, 1}}, 131)); + auto mx132 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 132))); + auto mx133 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 133)); + auto mx134 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 134)); + auto mx135 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1024}}, 135)); + auto mx136 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1024, 256, 1, 1}}, 136)); + auto mx137 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 137))); + auto mx138 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 138)); + auto mx139 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 139)); + auto mx140 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 140))); + auto mx141 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 141)); + auto mx142 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 142))); + auto mx143 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 143)); + auto mx144 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 144)); + auto mx145 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 145))); + auto mx146 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}, 146)); + auto mx147 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 147))); + auto mx148 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 148)); + auto mx149 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 149)); + auto mx150 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 150)); + auto mx151 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 128, 1, 1}}, 151)); + auto mx152 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 152))); + auto mx153 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 153)); + auto mx154 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 154)); + auto mx155 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 155))); + auto mx156 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 3, 3}}, 156)); + auto mx157 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 157))); + auto mx158 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 158)); + auto mx159 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 159)); + auto mx160 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 160))); + auto mx161 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 512, 1, 1}}, 161)); + auto mx162 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 162))); + auto mx163 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 163)); + auto mx164 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 164)); + auto mx165 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 165)); + auto mx166 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 128, 1, 1}}, 166)); + auto mx167 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 167))); + auto mx168 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 168)); + auto mx169 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 169)); + auto mx170 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 170))); + auto mx171 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 3, 3}}, 171)); + auto mx172 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 172))); + auto mx173 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 173)); + auto mx174 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 174)); + auto mx175 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 175))); + auto mx176 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 512, 1, 1}}, 176)); + auto mx177 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 177))); + auto mx178 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 178)); + auto mx179 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 179)); + auto mx180 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 180)); + auto mx181 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 128, 1, 1}}, 181)); + auto mx182 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 182))); + auto mx183 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 183)); + auto mx184 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 184)); + auto mx185 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 185))); + auto mx186 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 3, 3}}, 186)); + auto mx187 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 187))); + auto mx188 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 188)); + auto mx189 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 189)); + auto mx190 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 190))); + auto mx191 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 512, 1, 1}}, 191)); + auto mx192 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 192))); + auto mx193 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 193)); + auto mx194 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 194)); + auto mx195 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 195)); + auto mx196 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 256, 1, 1}}, 196)); + auto mx197 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 197))); + auto mx198 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 198)); + auto mx199 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 199)); + auto mx200 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {512}}, 200)); + auto mx201 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {512, 128, 1, 1}}, 201)); + auto mx202 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 202))); + auto mx203 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 203)); + auto mx204 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 204)); + auto mx205 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 205))); + auto mx206 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 128, 3, 3}}, 206)); + auto mx207 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 207))); + auto mx208 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 208)); + auto mx209 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 209)); + auto mx210 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {128}}, 210))); + auto mx211 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {128, 256, 1, 1}}, 211)); + auto mx212 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 212))); + auto mx213 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 213)); + auto mx214 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 214)); + auto mx215 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 215)); + auto mx216 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 64, 1, 1}}, 216)); + auto mx217 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 217))); + auto mx218 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 218)); + auto mx219 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 219)); + auto mx220 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 220))); + auto mx221 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 64, 3, 3}}, 221)); + auto mx222 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 222))); + auto mx223 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 223)); + auto mx224 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 224)); + auto mx225 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 225))); + auto mx226 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 256, 1, 1}}, 226)); + auto mx227 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 227))); + auto mx228 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 228)); + auto mx229 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 229)); + auto mx230 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 230)); + auto mx231 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 64, 1, 1}}, 231)); + auto mx232 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 232))); + auto mx233 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 233)); + auto mx234 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 234)); + auto mx235 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 235))); + auto mx236 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 64, 3, 3}}, 236)); + auto mx237 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 237))); + auto mx238 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 238)); + auto mx239 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 239)); + auto mx240 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 240))); + auto mx241 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 256, 1, 1}}, 241)); + auto mx242 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 242))); + auto mx243 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 243)); + auto mx244 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 244)); + auto mx245 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 245)); + auto mx246 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 64, 1, 1}}, 246)); + auto mx247 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 247))); + auto mx248 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 248)); + auto mx249 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 249)); + auto mx250 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 250)); + auto mx251 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {256, 64, 1, 1}}, 251)); + auto mx252 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 252))); + auto mx253 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 253)); + auto mx254 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 254)); + auto mx255 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 255))); + auto mx256 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 64, 3, 3}}, 256)); + auto mx257 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 257))); + auto mx258 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 258)); + auto mx259 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 259)); + auto mx260 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 260))); + auto mx261 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 64, 1, 1}}, 261)); + auto mx262 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 262))); + auto mx263 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 263)); + auto mx264 = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 264)); + auto mx265 = mm->add_literal(migraphx::abs( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 265))); + auto mx266 = mm->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 3, 7, 7}}, 266)); + migraphx::op::convolution convolution267; + convolution267.padding = {3, 3}; + convolution267.stride = {2, 2}; + convolution267.dilation = {1, 1}; + convolution267.group = 1; + auto mx267 = mm->add_instruction(convolution267, m0, mx266); + migraphx::op::batch_norm_inference batch_norm_inference268; + batch_norm_inference268.epsilon = 1e-05; + batch_norm_inference268.momentum = 0.9; + auto mx268 = mm->add_instruction(batch_norm_inference268, mx267, mx265, mx264, mx263, mx262); + migraphx::op::relu relu269; + auto mx269 = mm->add_instruction(relu269, mx268); + migraphx::op::pooling pooling270; + pooling270.mode = migraphx::op::pooling_mode::max; + pooling270.padding = {1, 1}; + pooling270.stride = {2, 2}; + pooling270.lengths = {3, 3}; + auto mx270 = mm->add_instruction(pooling270, mx269); + migraphx::op::convolution convolution271; + convolution271.padding = {0, 0}; + convolution271.stride = {1, 1}; + convolution271.dilation = {1, 1}; + convolution271.group = 1; + auto mx271 = mm->add_instruction(convolution271, mx270, mx261); + migraphx::op::batch_norm_inference batch_norm_inference272; + batch_norm_inference272.epsilon = 1e-05; + batch_norm_inference272.momentum = 0.9; + auto mx272 = mm->add_instruction(batch_norm_inference272, mx271, mx260, mx259, mx258, mx257); + migraphx::op::relu relu273; + auto mx273 = mm->add_instruction(relu273, mx272); + migraphx::op::convolution convolution274; + convolution274.padding = {1, 1}; + convolution274.stride = {1, 1}; + convolution274.dilation = {1, 1}; + convolution274.group = 1; + auto mx274 = mm->add_instruction(convolution274, mx273, mx256); + migraphx::op::batch_norm_inference batch_norm_inference275; + batch_norm_inference275.epsilon = 1e-05; + batch_norm_inference275.momentum = 0.9; + auto mx275 = mm->add_instruction(batch_norm_inference275, mx274, mx255, mx254, mx253, mx252); + migraphx::op::relu relu276; + auto mx276 = mm->add_instruction(relu276, mx275); + migraphx::op::convolution convolution277; + convolution277.padding = {0, 0}; + convolution277.stride = {1, 1}; + convolution277.dilation = {1, 1}; + convolution277.group = 1; + auto mx277 = mm->add_instruction(convolution277, mx276, mx251); + migraphx::op::batch_norm_inference batch_norm_inference278; + batch_norm_inference278.epsilon = 1e-05; + batch_norm_inference278.momentum = 0.9; + auto mx278 = mm->add_instruction(batch_norm_inference278, mx277, mx250, mx249, mx248, mx247); + migraphx::op::convolution convolution279; + convolution279.padding = {0, 0}; + convolution279.stride = {1, 1}; + convolution279.dilation = {1, 1}; + convolution279.group = 1; + auto mx279 = mm->add_instruction(convolution279, mx270, mx246); + migraphx::op::batch_norm_inference batch_norm_inference280; + batch_norm_inference280.epsilon = 1e-05; + batch_norm_inference280.momentum = 0.9; + auto mx280 = mm->add_instruction(batch_norm_inference280, mx279, mx245, mx244, mx243, mx242); + migraphx::op::add add281; + auto mx281 = mm->add_instruction(add281, mx278, mx280); + migraphx::op::relu relu282; + auto mx282 = mm->add_instruction(relu282, mx281); + migraphx::op::convolution convolution283; + convolution283.padding = {0, 0}; + convolution283.stride = {1, 1}; + convolution283.dilation = {1, 1}; + convolution283.group = 1; + auto mx283 = mm->add_instruction(convolution283, mx282, mx241); + migraphx::op::batch_norm_inference batch_norm_inference284; + batch_norm_inference284.epsilon = 1e-05; + batch_norm_inference284.momentum = 0.9; + auto mx284 = mm->add_instruction(batch_norm_inference284, mx283, mx240, mx239, mx238, mx237); + migraphx::op::relu relu285; + auto mx285 = mm->add_instruction(relu285, mx284); + migraphx::op::convolution convolution286; + convolution286.padding = {1, 1}; + convolution286.stride = {1, 1}; + convolution286.dilation = {1, 1}; + convolution286.group = 1; + auto mx286 = mm->add_instruction(convolution286, mx285, mx236); + migraphx::op::batch_norm_inference batch_norm_inference287; + batch_norm_inference287.epsilon = 1e-05; + batch_norm_inference287.momentum = 0.9; + auto mx287 = mm->add_instruction(batch_norm_inference287, mx286, mx235, mx234, mx233, mx232); + migraphx::op::relu relu288; + auto mx288 = mm->add_instruction(relu288, mx287); + migraphx::op::convolution convolution289; + convolution289.padding = {0, 0}; + convolution289.stride = {1, 1}; + convolution289.dilation = {1, 1}; + convolution289.group = 1; + auto mx289 = mm->add_instruction(convolution289, mx288, mx231); + migraphx::op::batch_norm_inference batch_norm_inference290; + batch_norm_inference290.epsilon = 1e-05; + batch_norm_inference290.momentum = 0.9; + auto mx290 = mm->add_instruction(batch_norm_inference290, mx289, mx230, mx229, mx228, mx227); + migraphx::op::add add291; + auto mx291 = mm->add_instruction(add291, mx290, mx282); + migraphx::op::relu relu292; + auto mx292 = mm->add_instruction(relu292, mx291); + migraphx::op::convolution convolution293; + convolution293.padding = {0, 0}; + convolution293.stride = {1, 1}; + convolution293.dilation = {1, 1}; + convolution293.group = 1; + auto mx293 = mm->add_instruction(convolution293, mx292, mx226); + migraphx::op::batch_norm_inference batch_norm_inference294; + batch_norm_inference294.epsilon = 1e-05; + batch_norm_inference294.momentum = 0.9; + auto mx294 = mm->add_instruction(batch_norm_inference294, mx293, mx225, mx224, mx223, mx222); + migraphx::op::relu relu295; + auto mx295 = mm->add_instruction(relu295, mx294); + migraphx::op::convolution convolution296; + convolution296.padding = {1, 1}; + convolution296.stride = {1, 1}; + convolution296.dilation = {1, 1}; + convolution296.group = 1; + auto mx296 = mm->add_instruction(convolution296, mx295, mx221); + migraphx::op::batch_norm_inference batch_norm_inference297; + batch_norm_inference297.epsilon = 1e-05; + batch_norm_inference297.momentum = 0.9; + auto mx297 = mm->add_instruction(batch_norm_inference297, mx296, mx220, mx219, mx218, mx217); + migraphx::op::relu relu298; + auto mx298 = mm->add_instruction(relu298, mx297); + migraphx::op::convolution convolution299; + convolution299.padding = {0, 0}; + convolution299.stride = {1, 1}; + convolution299.dilation = {1, 1}; + convolution299.group = 1; + auto mx299 = mm->add_instruction(convolution299, mx298, mx216); + migraphx::op::batch_norm_inference batch_norm_inference300; + batch_norm_inference300.epsilon = 1e-05; + batch_norm_inference300.momentum = 0.9; + auto mx300 = mm->add_instruction(batch_norm_inference300, mx299, mx215, mx214, mx213, mx212); + migraphx::op::add add301; + auto mx301 = mm->add_instruction(add301, mx300, mx292); + migraphx::op::relu relu302; + auto mx302 = mm->add_instruction(relu302, mx301); + migraphx::op::convolution convolution303; + convolution303.padding = {0, 0}; + convolution303.stride = {1, 1}; + convolution303.dilation = {1, 1}; + convolution303.group = 1; + auto mx303 = mm->add_instruction(convolution303, mx302, mx211); + migraphx::op::batch_norm_inference batch_norm_inference304; + batch_norm_inference304.epsilon = 1e-05; + batch_norm_inference304.momentum = 0.9; + auto mx304 = mm->add_instruction(batch_norm_inference304, mx303, mx210, mx209, mx208, mx207); + migraphx::op::relu relu305; + auto mx305 = mm->add_instruction(relu305, mx304); + migraphx::op::convolution convolution306; + convolution306.padding = {1, 1}; + convolution306.stride = {2, 2}; + convolution306.dilation = {1, 1}; + convolution306.group = 1; + auto mx306 = mm->add_instruction(convolution306, mx305, mx206); + migraphx::op::batch_norm_inference batch_norm_inference307; + batch_norm_inference307.epsilon = 1e-05; + batch_norm_inference307.momentum = 0.9; + auto mx307 = mm->add_instruction(batch_norm_inference307, mx306, mx205, mx204, mx203, mx202); + migraphx::op::relu relu308; + auto mx308 = mm->add_instruction(relu308, mx307); + migraphx::op::convolution convolution309; + convolution309.padding = {0, 0}; + convolution309.stride = {1, 1}; + convolution309.dilation = {1, 1}; + convolution309.group = 1; + auto mx309 = mm->add_instruction(convolution309, mx308, mx201); + migraphx::op::batch_norm_inference batch_norm_inference310; + batch_norm_inference310.epsilon = 1e-05; + batch_norm_inference310.momentum = 0.9; + auto mx310 = mm->add_instruction(batch_norm_inference310, mx309, mx200, mx199, mx198, mx197); + migraphx::op::convolution convolution311; + convolution311.padding = {0, 0}; + convolution311.stride = {2, 2}; + convolution311.dilation = {1, 1}; + convolution311.group = 1; + auto mx311 = mm->add_instruction(convolution311, mx302, mx196); + migraphx::op::batch_norm_inference batch_norm_inference312; + batch_norm_inference312.epsilon = 1e-05; + batch_norm_inference312.momentum = 0.9; + auto mx312 = mm->add_instruction(batch_norm_inference312, mx311, mx195, mx194, mx193, mx192); + migraphx::op::add add313; + auto mx313 = mm->add_instruction(add313, mx310, mx312); + migraphx::op::relu relu314; + auto mx314 = mm->add_instruction(relu314, mx313); + migraphx::op::convolution convolution315; + convolution315.padding = {0, 0}; + convolution315.stride = {1, 1}; + convolution315.dilation = {1, 1}; + convolution315.group = 1; + auto mx315 = mm->add_instruction(convolution315, mx314, mx191); + migraphx::op::batch_norm_inference batch_norm_inference316; + batch_norm_inference316.epsilon = 1e-05; + batch_norm_inference316.momentum = 0.9; + auto mx316 = mm->add_instruction(batch_norm_inference316, mx315, mx190, mx189, mx188, mx187); + migraphx::op::relu relu317; + auto mx317 = mm->add_instruction(relu317, mx316); + migraphx::op::convolution convolution318; + convolution318.padding = {1, 1}; + convolution318.stride = {1, 1}; + convolution318.dilation = {1, 1}; + convolution318.group = 1; + auto mx318 = mm->add_instruction(convolution318, mx317, mx186); + migraphx::op::batch_norm_inference batch_norm_inference319; + batch_norm_inference319.epsilon = 1e-05; + batch_norm_inference319.momentum = 0.9; + auto mx319 = mm->add_instruction(batch_norm_inference319, mx318, mx185, mx184, mx183, mx182); + migraphx::op::relu relu320; + auto mx320 = mm->add_instruction(relu320, mx319); + migraphx::op::convolution convolution321; + convolution321.padding = {0, 0}; + convolution321.stride = {1, 1}; + convolution321.dilation = {1, 1}; + convolution321.group = 1; + auto mx321 = mm->add_instruction(convolution321, mx320, mx181); + migraphx::op::batch_norm_inference batch_norm_inference322; + batch_norm_inference322.epsilon = 1e-05; + batch_norm_inference322.momentum = 0.9; + auto mx322 = mm->add_instruction(batch_norm_inference322, mx321, mx180, mx179, mx178, mx177); + migraphx::op::add add323; + auto mx323 = mm->add_instruction(add323, mx322, mx314); + migraphx::op::relu relu324; + auto mx324 = mm->add_instruction(relu324, mx323); + migraphx::op::convolution convolution325; + convolution325.padding = {0, 0}; + convolution325.stride = {1, 1}; + convolution325.dilation = {1, 1}; + convolution325.group = 1; + auto mx325 = mm->add_instruction(convolution325, mx324, mx176); + migraphx::op::batch_norm_inference batch_norm_inference326; + batch_norm_inference326.epsilon = 1e-05; + batch_norm_inference326.momentum = 0.9; + auto mx326 = mm->add_instruction(batch_norm_inference326, mx325, mx175, mx174, mx173, mx172); + migraphx::op::relu relu327; + auto mx327 = mm->add_instruction(relu327, mx326); + migraphx::op::convolution convolution328; + convolution328.padding = {1, 1}; + convolution328.stride = {1, 1}; + convolution328.dilation = {1, 1}; + convolution328.group = 1; + auto mx328 = mm->add_instruction(convolution328, mx327, mx171); + migraphx::op::batch_norm_inference batch_norm_inference329; + batch_norm_inference329.epsilon = 1e-05; + batch_norm_inference329.momentum = 0.9; + auto mx329 = mm->add_instruction(batch_norm_inference329, mx328, mx170, mx169, mx168, mx167); + migraphx::op::relu relu330; + auto mx330 = mm->add_instruction(relu330, mx329); + migraphx::op::convolution convolution331; + convolution331.padding = {0, 0}; + convolution331.stride = {1, 1}; + convolution331.dilation = {1, 1}; + convolution331.group = 1; + auto mx331 = mm->add_instruction(convolution331, mx330, mx166); + migraphx::op::batch_norm_inference batch_norm_inference332; + batch_norm_inference332.epsilon = 1e-05; + batch_norm_inference332.momentum = 0.9; + auto mx332 = mm->add_instruction(batch_norm_inference332, mx331, mx165, mx164, mx163, mx162); + migraphx::op::add add333; + auto mx333 = mm->add_instruction(add333, mx332, mx324); + migraphx::op::relu relu334; + auto mx334 = mm->add_instruction(relu334, mx333); + migraphx::op::convolution convolution335; + convolution335.padding = {0, 0}; + convolution335.stride = {1, 1}; + convolution335.dilation = {1, 1}; + convolution335.group = 1; + auto mx335 = mm->add_instruction(convolution335, mx334, mx161); + migraphx::op::batch_norm_inference batch_norm_inference336; + batch_norm_inference336.epsilon = 1e-05; + batch_norm_inference336.momentum = 0.9; + auto mx336 = mm->add_instruction(batch_norm_inference336, mx335, mx160, mx159, mx158, mx157); + migraphx::op::relu relu337; + auto mx337 = mm->add_instruction(relu337, mx336); + migraphx::op::convolution convolution338; + convolution338.padding = {1, 1}; + convolution338.stride = {1, 1}; + convolution338.dilation = {1, 1}; + convolution338.group = 1; + auto mx338 = mm->add_instruction(convolution338, mx337, mx156); + migraphx::op::batch_norm_inference batch_norm_inference339; + batch_norm_inference339.epsilon = 1e-05; + batch_norm_inference339.momentum = 0.9; + auto mx339 = mm->add_instruction(batch_norm_inference339, mx338, mx155, mx154, mx153, mx152); + migraphx::op::relu relu340; + auto mx340 = mm->add_instruction(relu340, mx339); + migraphx::op::convolution convolution341; + convolution341.padding = {0, 0}; + convolution341.stride = {1, 1}; + convolution341.dilation = {1, 1}; + convolution341.group = 1; + auto mx341 = mm->add_instruction(convolution341, mx340, mx151); + migraphx::op::batch_norm_inference batch_norm_inference342; + batch_norm_inference342.epsilon = 1e-05; + batch_norm_inference342.momentum = 0.9; + auto mx342 = mm->add_instruction(batch_norm_inference342, mx341, mx150, mx149, mx148, mx147); + migraphx::op::add add343; + auto mx343 = mm->add_instruction(add343, mx342, mx334); + migraphx::op::relu relu344; + auto mx344 = mm->add_instruction(relu344, mx343); + migraphx::op::convolution convolution345; + convolution345.padding = {0, 0}; + convolution345.stride = {1, 1}; + convolution345.dilation = {1, 1}; + convolution345.group = 1; + auto mx345 = mm->add_instruction(convolution345, mx344, mx146); + migraphx::op::batch_norm_inference batch_norm_inference346; + batch_norm_inference346.epsilon = 1e-05; + batch_norm_inference346.momentum = 0.9; + auto mx346 = mm->add_instruction(batch_norm_inference346, mx345, mx145, mx144, mx143, mx142); + migraphx::op::relu relu347; + auto mx347 = mm->add_instruction(relu347, mx346); + migraphx::op::convolution convolution348; + convolution348.padding = {1, 1}; + convolution348.stride = {2, 2}; + convolution348.dilation = {1, 1}; + convolution348.group = 1; + auto mx348 = mm->add_instruction(convolution348, mx347, mx141); + migraphx::op::batch_norm_inference batch_norm_inference349; + batch_norm_inference349.epsilon = 1e-05; + batch_norm_inference349.momentum = 0.9; + auto mx349 = mm->add_instruction(batch_norm_inference349, mx348, mx140, mx139, mx138, mx137); + migraphx::op::relu relu350; + auto mx350 = mm->add_instruction(relu350, mx349); + migraphx::op::convolution convolution351; + convolution351.padding = {0, 0}; + convolution351.stride = {1, 1}; + convolution351.dilation = {1, 1}; + convolution351.group = 1; + auto mx351 = mm->add_instruction(convolution351, mx350, mx136); + migraphx::op::batch_norm_inference batch_norm_inference352; + batch_norm_inference352.epsilon = 1e-05; + batch_norm_inference352.momentum = 0.9; + auto mx352 = mm->add_instruction(batch_norm_inference352, mx351, mx135, mx134, mx133, mx132); + migraphx::op::convolution convolution353; + convolution353.padding = {0, 0}; + convolution353.stride = {2, 2}; + convolution353.dilation = {1, 1}; + convolution353.group = 1; + auto mx353 = mm->add_instruction(convolution353, mx344, mx131); + migraphx::op::batch_norm_inference batch_norm_inference354; + batch_norm_inference354.epsilon = 1e-05; + batch_norm_inference354.momentum = 0.9; + auto mx354 = mm->add_instruction(batch_norm_inference354, mx353, mx130, mx129, mx128, mx127); + migraphx::op::add add355; + auto mx355 = mm->add_instruction(add355, mx352, mx354); + migraphx::op::relu relu356; + auto mx356 = mm->add_instruction(relu356, mx355); + migraphx::op::convolution convolution357; + convolution357.padding = {0, 0}; + convolution357.stride = {1, 1}; + convolution357.dilation = {1, 1}; + convolution357.group = 1; + auto mx357 = mm->add_instruction(convolution357, mx356, mx126); + migraphx::op::batch_norm_inference batch_norm_inference358; + batch_norm_inference358.epsilon = 1e-05; + batch_norm_inference358.momentum = 0.9; + auto mx358 = mm->add_instruction(batch_norm_inference358, mx357, mx125, mx124, mx123, mx122); + migraphx::op::relu relu359; + auto mx359 = mm->add_instruction(relu359, mx358); + migraphx::op::convolution convolution360; + convolution360.padding = {1, 1}; + convolution360.stride = {1, 1}; + convolution360.dilation = {1, 1}; + convolution360.group = 1; + auto mx360 = mm->add_instruction(convolution360, mx359, mx121); + migraphx::op::batch_norm_inference batch_norm_inference361; + batch_norm_inference361.epsilon = 1e-05; + batch_norm_inference361.momentum = 0.9; + auto mx361 = mm->add_instruction(batch_norm_inference361, mx360, mx120, mx119, mx118, mx117); + migraphx::op::relu relu362; + auto mx362 = mm->add_instruction(relu362, mx361); + migraphx::op::convolution convolution363; + convolution363.padding = {0, 0}; + convolution363.stride = {1, 1}; + convolution363.dilation = {1, 1}; + convolution363.group = 1; + auto mx363 = mm->add_instruction(convolution363, mx362, mx116); + migraphx::op::batch_norm_inference batch_norm_inference364; + batch_norm_inference364.epsilon = 1e-05; + batch_norm_inference364.momentum = 0.9; + auto mx364 = mm->add_instruction(batch_norm_inference364, mx363, mx115, mx114, mx113, mx112); + migraphx::op::add add365; + auto mx365 = mm->add_instruction(add365, mx364, mx356); + migraphx::op::relu relu366; + auto mx366 = mm->add_instruction(relu366, mx365); + migraphx::op::convolution convolution367; + convolution367.padding = {0, 0}; + convolution367.stride = {1, 1}; + convolution367.dilation = {1, 1}; + convolution367.group = 1; + auto mx367 = mm->add_instruction(convolution367, mx366, mx111); + migraphx::op::batch_norm_inference batch_norm_inference368; + batch_norm_inference368.epsilon = 1e-05; + batch_norm_inference368.momentum = 0.9; + auto mx368 = mm->add_instruction(batch_norm_inference368, mx367, mx110, mx109, mx108, mx107); + migraphx::op::relu relu369; + auto mx369 = mm->add_instruction(relu369, mx368); + migraphx::op::convolution convolution370; + convolution370.padding = {1, 1}; + convolution370.stride = {1, 1}; + convolution370.dilation = {1, 1}; + convolution370.group = 1; + auto mx370 = mm->add_instruction(convolution370, mx369, mx106); + migraphx::op::batch_norm_inference batch_norm_inference371; + batch_norm_inference371.epsilon = 1e-05; + batch_norm_inference371.momentum = 0.9; + auto mx371 = mm->add_instruction(batch_norm_inference371, mx370, mx105, mx104, mx103, mx102); + migraphx::op::relu relu372; + auto mx372 = mm->add_instruction(relu372, mx371); + migraphx::op::convolution convolution373; + convolution373.padding = {0, 0}; + convolution373.stride = {1, 1}; + convolution373.dilation = {1, 1}; + convolution373.group = 1; + auto mx373 = mm->add_instruction(convolution373, mx372, mx101); + migraphx::op::batch_norm_inference batch_norm_inference374; + batch_norm_inference374.epsilon = 1e-05; + batch_norm_inference374.momentum = 0.9; + auto mx374 = mm->add_instruction(batch_norm_inference374, mx373, mx100, mx99, mx98, mx97); + migraphx::op::add add375; + auto mx375 = mm->add_instruction(add375, mx374, mx366); + migraphx::op::relu relu376; + auto mx376 = mm->add_instruction(relu376, mx375); + migraphx::op::convolution convolution377; + convolution377.padding = {0, 0}; + convolution377.stride = {1, 1}; + convolution377.dilation = {1, 1}; + convolution377.group = 1; + auto mx377 = mm->add_instruction(convolution377, mx376, mx96); + migraphx::op::batch_norm_inference batch_norm_inference378; + batch_norm_inference378.epsilon = 1e-05; + batch_norm_inference378.momentum = 0.9; + auto mx378 = mm->add_instruction(batch_norm_inference378, mx377, mx95, mx94, mx93, mx92); + migraphx::op::relu relu379; + auto mx379 = mm->add_instruction(relu379, mx378); + migraphx::op::convolution convolution380; + convolution380.padding = {1, 1}; + convolution380.stride = {1, 1}; + convolution380.dilation = {1, 1}; + convolution380.group = 1; + auto mx380 = mm->add_instruction(convolution380, mx379, mx91); + migraphx::op::batch_norm_inference batch_norm_inference381; + batch_norm_inference381.epsilon = 1e-05; + batch_norm_inference381.momentum = 0.9; + auto mx381 = mm->add_instruction(batch_norm_inference381, mx380, mx90, mx89, mx88, mx87); + migraphx::op::relu relu382; + auto mx382 = mm->add_instruction(relu382, mx381); + migraphx::op::convolution convolution383; + convolution383.padding = {0, 0}; + convolution383.stride = {1, 1}; + convolution383.dilation = {1, 1}; + convolution383.group = 1; + auto mx383 = mm->add_instruction(convolution383, mx382, mx86); + migraphx::op::batch_norm_inference batch_norm_inference384; + batch_norm_inference384.epsilon = 1e-05; + batch_norm_inference384.momentum = 0.9; + auto mx384 = mm->add_instruction(batch_norm_inference384, mx383, mx85, mx84, mx83, mx82); + migraphx::op::add add385; + auto mx385 = mm->add_instruction(add385, mx384, mx376); + migraphx::op::relu relu386; + auto mx386 = mm->add_instruction(relu386, mx385); + migraphx::op::convolution convolution387; + convolution387.padding = {0, 0}; + convolution387.stride = {1, 1}; + convolution387.dilation = {1, 1}; + convolution387.group = 1; + auto mx387 = mm->add_instruction(convolution387, mx386, mx81); + migraphx::op::batch_norm_inference batch_norm_inference388; + batch_norm_inference388.epsilon = 1e-05; + batch_norm_inference388.momentum = 0.9; + auto mx388 = mm->add_instruction(batch_norm_inference388, mx387, mx80, mx79, mx78, mx77); + migraphx::op::relu relu389; + auto mx389 = mm->add_instruction(relu389, mx388); + migraphx::op::convolution convolution390; + convolution390.padding = {1, 1}; + convolution390.stride = {1, 1}; + convolution390.dilation = {1, 1}; + convolution390.group = 1; + auto mx390 = mm->add_instruction(convolution390, mx389, mx76); + migraphx::op::batch_norm_inference batch_norm_inference391; + batch_norm_inference391.epsilon = 1e-05; + batch_norm_inference391.momentum = 0.9; + auto mx391 = mm->add_instruction(batch_norm_inference391, mx390, mx75, mx74, mx73, mx72); + migraphx::op::relu relu392; + auto mx392 = mm->add_instruction(relu392, mx391); + migraphx::op::convolution convolution393; + convolution393.padding = {0, 0}; + convolution393.stride = {1, 1}; + convolution393.dilation = {1, 1}; + convolution393.group = 1; + auto mx393 = mm->add_instruction(convolution393, mx392, mx71); + migraphx::op::batch_norm_inference batch_norm_inference394; + batch_norm_inference394.epsilon = 1e-05; + batch_norm_inference394.momentum = 0.9; + auto mx394 = mm->add_instruction(batch_norm_inference394, mx393, mx70, mx69, mx68, mx67); + migraphx::op::add add395; + auto mx395 = mm->add_instruction(add395, mx394, mx386); + migraphx::op::relu relu396; + auto mx396 = mm->add_instruction(relu396, mx395); + migraphx::op::convolution convolution397; + convolution397.padding = {0, 0}; + convolution397.stride = {1, 1}; + convolution397.dilation = {1, 1}; + convolution397.group = 1; + auto mx397 = mm->add_instruction(convolution397, mx396, mx66); + migraphx::op::batch_norm_inference batch_norm_inference398; + batch_norm_inference398.epsilon = 1e-05; + batch_norm_inference398.momentum = 0.9; + auto mx398 = mm->add_instruction(batch_norm_inference398, mx397, mx65, mx64, mx63, mx62); + migraphx::op::relu relu399; + auto mx399 = mm->add_instruction(relu399, mx398); + migraphx::op::convolution convolution400; + convolution400.padding = {1, 1}; + convolution400.stride = {1, 1}; + convolution400.dilation = {1, 1}; + convolution400.group = 1; + auto mx400 = mm->add_instruction(convolution400, mx399, mx61); + migraphx::op::batch_norm_inference batch_norm_inference401; + batch_norm_inference401.epsilon = 1e-05; + batch_norm_inference401.momentum = 0.9; + auto mx401 = mm->add_instruction(batch_norm_inference401, mx400, mx60, mx59, mx58, mx57); + migraphx::op::relu relu402; + auto mx402 = mm->add_instruction(relu402, mx401); + migraphx::op::convolution convolution403; + convolution403.padding = {0, 0}; + convolution403.stride = {1, 1}; + convolution403.dilation = {1, 1}; + convolution403.group = 1; + auto mx403 = mm->add_instruction(convolution403, mx402, mx56); + migraphx::op::batch_norm_inference batch_norm_inference404; + batch_norm_inference404.epsilon = 1e-05; + batch_norm_inference404.momentum = 0.9; + auto mx404 = mm->add_instruction(batch_norm_inference404, mx403, mx55, mx54, mx53, mx52); + migraphx::op::add add405; + auto mx405 = mm->add_instruction(add405, mx404, mx396); + migraphx::op::relu relu406; + auto mx406 = mm->add_instruction(relu406, mx405); + migraphx::op::convolution convolution407; + convolution407.padding = {0, 0}; + convolution407.stride = {1, 1}; + convolution407.dilation = {1, 1}; + convolution407.group = 1; + auto mx407 = mm->add_instruction(convolution407, mx406, mx51); + migraphx::op::batch_norm_inference batch_norm_inference408; + batch_norm_inference408.epsilon = 1e-05; + batch_norm_inference408.momentum = 0.9; + auto mx408 = mm->add_instruction(batch_norm_inference408, mx407, mx50, mx49, mx48, mx47); + migraphx::op::relu relu409; + auto mx409 = mm->add_instruction(relu409, mx408); + migraphx::op::convolution convolution410; + convolution410.padding = {1, 1}; + convolution410.stride = {2, 2}; + convolution410.dilation = {1, 1}; + convolution410.group = 1; + auto mx410 = mm->add_instruction(convolution410, mx409, mx46); + migraphx::op::batch_norm_inference batch_norm_inference411; + batch_norm_inference411.epsilon = 1e-05; + batch_norm_inference411.momentum = 0.9; + auto mx411 = mm->add_instruction(batch_norm_inference411, mx410, mx45, mx44, mx43, mx42); + migraphx::op::relu relu412; + auto mx412 = mm->add_instruction(relu412, mx411); + migraphx::op::convolution convolution413; + convolution413.padding = {0, 0}; + convolution413.stride = {1, 1}; + convolution413.dilation = {1, 1}; + convolution413.group = 1; + auto mx413 = mm->add_instruction(convolution413, mx412, mx41); + migraphx::op::batch_norm_inference batch_norm_inference414; + batch_norm_inference414.epsilon = 1e-05; + batch_norm_inference414.momentum = 0.9; + auto mx414 = mm->add_instruction(batch_norm_inference414, mx413, mx40, mx39, mx38, mx37); + migraphx::op::convolution convolution415; + convolution415.padding = {0, 0}; + convolution415.stride = {2, 2}; + convolution415.dilation = {1, 1}; + convolution415.group = 1; + auto mx415 = mm->add_instruction(convolution415, mx406, mx36); + migraphx::op::batch_norm_inference batch_norm_inference416; + batch_norm_inference416.epsilon = 1e-05; + batch_norm_inference416.momentum = 0.9; + auto mx416 = mm->add_instruction(batch_norm_inference416, mx415, mx35, mx34, mx33, mx32); + migraphx::op::add add417; + auto mx417 = mm->add_instruction(add417, mx414, mx416); + migraphx::op::relu relu418; + auto mx418 = mm->add_instruction(relu418, mx417); + migraphx::op::convolution convolution419; + convolution419.padding = {0, 0}; + convolution419.stride = {1, 1}; + convolution419.dilation = {1, 1}; + convolution419.group = 1; + auto mx419 = mm->add_instruction(convolution419, mx418, mx31); + migraphx::op::batch_norm_inference batch_norm_inference420; + batch_norm_inference420.epsilon = 1e-05; + batch_norm_inference420.momentum = 0.9; + auto mx420 = mm->add_instruction(batch_norm_inference420, mx419, mx30, mx29, mx28, mx27); + migraphx::op::relu relu421; + auto mx421 = mm->add_instruction(relu421, mx420); + migraphx::op::convolution convolution422; + convolution422.padding = {1, 1}; + convolution422.stride = {1, 1}; + convolution422.dilation = {1, 1}; + convolution422.group = 1; + auto mx422 = mm->add_instruction(convolution422, mx421, mx26); + migraphx::op::batch_norm_inference batch_norm_inference423; + batch_norm_inference423.epsilon = 1e-05; + batch_norm_inference423.momentum = 0.9; + auto mx423 = mm->add_instruction(batch_norm_inference423, mx422, mx25, mx24, mx23, mx22); + migraphx::op::relu relu424; + auto mx424 = mm->add_instruction(relu424, mx423); + migraphx::op::convolution convolution425; + convolution425.padding = {0, 0}; + convolution425.stride = {1, 1}; + convolution425.dilation = {1, 1}; + convolution425.group = 1; + auto mx425 = mm->add_instruction(convolution425, mx424, mx21); + migraphx::op::batch_norm_inference batch_norm_inference426; + batch_norm_inference426.epsilon = 1e-05; + batch_norm_inference426.momentum = 0.9; + auto mx426 = mm->add_instruction(batch_norm_inference426, mx425, mx20, mx19, mx18, mx17); + migraphx::op::add add427; + auto mx427 = mm->add_instruction(add427, mx426, mx418); + migraphx::op::relu relu428; + auto mx428 = mm->add_instruction(relu428, mx427); + migraphx::op::convolution convolution429; + convolution429.padding = {0, 0}; + convolution429.stride = {1, 1}; + convolution429.dilation = {1, 1}; + convolution429.group = 1; + auto mx429 = mm->add_instruction(convolution429, mx428, mx16); + migraphx::op::batch_norm_inference batch_norm_inference430; + batch_norm_inference430.epsilon = 1e-05; + batch_norm_inference430.momentum = 0.9; + auto mx430 = mm->add_instruction(batch_norm_inference430, mx429, mx15, mx14, mx13, mx12); + migraphx::op::relu relu431; + auto mx431 = mm->add_instruction(relu431, mx430); + migraphx::op::convolution convolution432; + convolution432.padding = {1, 1}; + convolution432.stride = {1, 1}; + convolution432.dilation = {1, 1}; + convolution432.group = 1; + auto mx432 = mm->add_instruction(convolution432, mx431, mx11); + migraphx::op::batch_norm_inference batch_norm_inference433; + batch_norm_inference433.epsilon = 1e-05; + batch_norm_inference433.momentum = 0.9; + auto mx433 = mm->add_instruction(batch_norm_inference433, mx432, mx10, mx9, mx8, mx7); + migraphx::op::relu relu434; + auto mx434 = mm->add_instruction(relu434, mx433); + migraphx::op::convolution convolution435; + convolution435.padding = {0, 0}; + convolution435.stride = {1, 1}; + convolution435.dilation = {1, 1}; + convolution435.group = 1; + auto mx435 = mm->add_instruction(convolution435, mx434, mx6); + migraphx::op::batch_norm_inference batch_norm_inference436; + batch_norm_inference436.epsilon = 1e-05; + batch_norm_inference436.momentum = 0.9; + auto mx436 = mm->add_instruction(batch_norm_inference436, mx435, mx5, mx4, mx3, mx2); + migraphx::op::add add437; + auto mx437 = mm->add_instruction(add437, mx436, mx428); + migraphx::op::relu relu438; + auto mx438 = mm->add_instruction(relu438, mx437); + migraphx::op::pooling pooling439; + pooling439.mode = migraphx::op::pooling_mode::average; + pooling439.padding = {0, 0}; + pooling439.stride = {1, 1}; + pooling439.lengths = {7, 7}; + auto mx439 = mm->add_instruction(pooling439, mx438); + migraphx::op::flatten flatten440; + flatten440.axis = 1; + auto mx440 = mm->add_instruction(flatten440, mx439); + migraphx::op::transpose transpose441; + transpose441.dims = {1, 0}; + auto mx441 = mm->add_instruction(transpose441, mx1); + migraphx::op::multibroadcast multibroadcast442; + multibroadcast442.output_lens = {batch, 1000}; + auto mx442 = mm->add_instruction(multibroadcast442, mx0); + float dot443_alpha = 1; + float dot443_beta = 1; + migraphx::add_apply_alpha_beta( + *mm, {mx440, mx441, mx442}, migraphx::make_op("dot"), dot443_alpha, dot443_beta); + return p; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace driver +} // namespace migraphx diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index a6fd1f6d6f658d117e91fc7a95eda09f0167a835..0665661d3245441d2a3dea2fecee22e4dece89d2 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -1,69 +1,78 @@ #include "verify.hpp" +#include "perf.hpp" -#include +#include #include #include #include - -#ifdef HAVE_GPU -#include -#include -#endif +#include +#include namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { -template -auto get_hash(const T& x) -{ - return std::hash{}(x); -} - -argument run_cpu(program p) +std::vector run_ref(program p, const parameter_map& inputs) { - p.compile(cpu::target{}); - program::parameter_map m; - for(auto&& x : p.get_parameter_shapes()) - { - m[x.first] = generate_argument(x.second, get_hash(x.first)); - } - auto out = p.eval(m); + p.compile(ref::target{}); + auto out = p.eval(inputs); std::cout << p << std::endl; return out; } -argument run_gpu(program p) +std::vector run_target(program p, + const target& t, + const compile_options& options, + precision quantize, + const parameter_map& inputs) { -#ifdef HAVE_GPU - p.compile(gpu::target{}); + if(quantize == precision::fp16) + { + quantize_fp16(p); + } + p.compile(t, options); - program::parameter_map m; + parameter_map m; for(auto&& x : p.get_parameter_shapes()) { - m[x.first] = gpu::to_gpu(generate_argument(x.second, get_hash(x.first))); + auto arg = inputs.count(x.first) == 0 ? generate_argument(x.second) : inputs.at(x.first); + m[x.first] = options.offload_copy ? arg : t.copy_to(arg); } - auto out = gpu::from_gpu(p.eval(m)); + auto gpu_out = p.eval(m); + std::vector output(gpu_out.size()); std::cout << p << std::endl; - return gpu::from_gpu(out); -#else - (void)p; - MIGRAPHX_THROW("Gpu unsupported!"); -#endif + std::transform(gpu_out.begin(), gpu_out.end(), output.begin(), [&](auto& argu) { + return options.offload_copy ? argu : t.copy_from(argu); + }); + return output; } -void verify_program(const std::string& name, const program& p, double tolerance) +void verify_program(const std::string& name, + const program& p, + const target& t, + compile_options options, + precision quantize, + const parameter_map& inputs, + double tolerance) { - auto x = run_cpu(p); - auto y = run_gpu(p); - verify_args(name, x, y, tolerance); - // std::cout << "cpu: " << x << std::endl; - // std::cout << "gpu: " << y << std::endl; + auto x = run_ref(p, inputs); + auto y = run_target(p, t, options, quantize, inputs); + + std::size_t output_num = x.size(); + for(std::size_t i = 0; i < output_num; ++i) + { + verify_args(name, x[i], y[i], tolerance); + } } -void verify_instructions(const program& prog, double tolerance) +void verify_instructions(const program& prog, + const target& t, + compile_options options, + precision quantize, + double tolerance) { - for(auto&& ins : prog) + const auto* mm_prog = prog.get_main_module(); + for(auto&& ins : (*mm_prog)) { if(ins.name().front() == '@') continue; @@ -73,21 +82,26 @@ void verify_instructions(const program& prog, double tolerance) continue; if(ins.name() == "reshape") continue; + if(ins.name() == "undefined") + continue; program p; + auto* mm_p = p.get_main_module(); std::vector inputs; for(auto&& arg : ins.inputs()) { if(arg->name() == "@literal") - inputs.push_back(p.add_literal(arg->get_literal())); + inputs.push_back(mm_p->add_literal(arg->get_literal())); else - inputs.push_back(p.add_parameter(std::to_string(inputs.size()), arg->get_shape())); + inputs.push_back( + mm_p->add_parameter(std::to_string(inputs.size()), arg->get_shape())); } - p.add_instruction(ins.get_operator(), inputs); + mm_p->add_instruction(ins.get_operator(), inputs); try { std::cout << "Verify: " << ins.name() << std::endl; std::cout << p << std::endl; - verify_program(ins.name(), p, tolerance); + verify_program( + ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance); } catch(...) { @@ -97,21 +111,34 @@ void verify_instructions(const program& prog, double tolerance) } } -void verify_reduced(program p, int n, double tolerance) +void verify_reduced(program p, + int n, + const target& t, + compile_options options, + precision quantize, + const parameter_map& inputs, + double tolerance) { - auto last = std::prev(p.end(), n + 1); - p.remove_instructions(last, p.end()); + auto* mm = p.get_main_module(); + auto last = std::prev(mm->end(), n + 1); + mm->remove_instructions(last, mm->end()); std::cout << "Verify: " << std::endl; std::cout << p << std::endl; - verify_program(std::to_string(n), p, tolerance); + verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); } -void verify_reduced_program(const program& p, double tolerance) +void verify_reduced_program(const program& p, + const target& t, + compile_options options, + precision quantize, + const parameter_map& inputs, + double tolerance) { - auto n = std::distance(p.begin(), p.end()); + const auto* mm = p.get_main_module(); + auto n = std::distance(mm->begin(), mm->end()); for(std::size_t i = 0; i < n; i++) { - verify_reduced(p, i, tolerance); + verify_reduced(p, i, t, options, quantize, inputs, tolerance); } } diff --git a/src/driver/verify.hpp b/src/driver/verify.hpp index 8b420e137e49474a37b07f456410f7b2072aad33..9f010d59e89e1261f44c925cbe173fbad9b32615 100644 --- a/src/driver/verify.hpp +++ b/src/driver/verify.hpp @@ -1,17 +1,31 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP #define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP +#include "precision.hpp" #include namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { -argument run_cpu(program p); -argument run_gpu(program p); -void verify_program(const std::string& name, const program& p, double tolerance = 100); -void verify_instructions(const program& prog, double tolerance = 80); -void verify_reduced_program(const program& p, double tolerance = 80); +void verify_program(const std::string& name, + const program& p, + const target& t, + compile_options options = compile_options{}, + precision quantize = precision::fp32, + const parameter_map& inputs = {}, + double tolerance = 100); +void verify_instructions(const program& prog, + const target& t, + compile_options options = compile_options{}, + precision quantize = precision::fp32, + double tolerance = 80); +void verify_reduced_program(const program& p, + const target& t, + compile_options options = compile_options{}, + precision quantize = precision::fp32, + const parameter_map& inputs = {}, + double tolerance = 80); } // namespace MIGRAPHX_INLINE_NS } // namespace driver diff --git a/src/dynamic_loader.cpp b/src/dynamic_loader.cpp new file mode 100755 index 0000000000000000000000000000000000000000..ab0cbbecb44863054431585cca414674de24ab6b --- /dev/null +++ b/src/dynamic_loader.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct dynamic_loader_impl +{ + dynamic_loader_impl() = default; + dynamic_loader_impl(const fs::path& p, std::shared_ptr t = nullptr) + : handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t)) + { + } + + static std::shared_ptr from_buffer(const char* image, std::size_t size) + { + auto t = std::make_shared("dloader"); + auto f = t->path / "libtmp.so"; + write_buffer(f.string(), image, size); + return std::make_shared(f, t); + } + + std::shared_ptr handle = nullptr; + std::shared_ptr temp = nullptr; +}; + +dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared(p)) +{ +} + +dynamic_loader::dynamic_loader(const char* image, std::size_t size) + : impl(dynamic_loader_impl::from_buffer(image, size)) +{ +} + +dynamic_loader::dynamic_loader(const std::vector& buffer) + : impl(dynamic_loader_impl::from_buffer(buffer.data(), buffer.size())) +{ +} + +std::shared_ptr dynamic_loader::get_symbol(const std::string& name) const +{ + dlerror(); + void* symbol = dlsym(impl->handle.get(), name.c_str()); + if(symbol == nullptr) + MIGRAPHX_THROW("Symbol not found: " + name); + return {impl, symbol}; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/eliminate_allocation.cpp b/src/eliminate_allocation.cpp old mode 100644 new mode 100755 index bfd82ae16abd16e40efcf2f27f9b49fbc5e4256c..ffd52116e78a2d0e12410567156bfefed336e314 --- a/src/eliminate_allocation.cpp +++ b/src/eliminate_allocation.cpp @@ -1,22 +1,25 @@ #include #include #include -#include #include #include #include +#include + +#include + #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void eliminate_allocation::apply(program& p) const +void eliminate_allocation::apply(module& m) const { assert(alignment > 0); std::size_t n = 0; std::vector> allocs; - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(m)) { if(ins->name() != allocation_op) continue; @@ -27,13 +30,14 @@ void eliminate_allocation::apply(program& p) const } if(n > 0) { - auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}}); + auto mem = m.add_parameter("memory", shape{shape::int8_type, {n}}); for(auto&& pp : allocs) { auto ins = pp.first; auto s = ins->get_shape(); auto offset = pp.second; - p.replace_instruction(ins, op::load{s, offset}, mem); + m.replace_instruction( + ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem); } } } diff --git a/src/eliminate_common_subexpression.cpp b/src/eliminate_common_subexpression.cpp index de7497d2d05cf899c4358f3d83ae10324ca04fae..03308019daef3d95c29519d601bf10f33164448f 100644 --- a/src/eliminate_common_subexpression.cpp +++ b/src/eliminate_common_subexpression.cpp @@ -11,33 +11,43 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { template -void cse_range(program& p, Range&& r) +void cse_range(module& m, Range&& r) { std::unordered_multimap instructions; + std::unordered_set processed_ins; for(auto ins : r) { // Skip dead instructions if(ins->outputs().empty()) continue; + // Find instruction with the same name auto found_instructions = range(instructions.equal_range(ins->name())); for(const auto& pp : found_instructions) { auto eq = pp.second; + if(contains(processed_ins, eq)) + continue; if(*eq != *ins) continue; - p.replace_instruction(ins, eq); - auto outputs = eq->outputs(); + m.replace_instruction(ins, eq); + processed_ins.emplace(ins); + std::vector outputs; + std::copy_if(eq->outputs().begin(), + eq->outputs().end(), + std::back_inserter(outputs), + [&](auto x) { return m.has_instruction(x); }); + std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) { return std::distance(eq, x) < std::distance(eq, y); }); - cse_range(p, outputs); + cse_range(m, outputs); } instructions.emplace(ins->name(), ins); } } -void eliminate_common_subexpression::apply(program& p) const { cse_range(p, iterator_for(p)); } +void eliminate_common_subexpression::apply(module& m) const { cse_range(m, iterator_for(m)); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/eliminate_concat.cpp b/src/eliminate_concat.cpp index b3290a83d8df9026196dc94511e2eaf8b1a0428a..9c37d416fcd7dcd45bee4cb69ce95b018c570ad1 100644 --- a/src/eliminate_concat.cpp +++ b/src/eliminate_concat.cpp @@ -6,13 +6,16 @@ #include #include #include +#include + #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void eliminate_concat::apply(program& p) const +void eliminate_concat::apply(module& m) const { - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(m)) { // Look for the concat operator if(ins->name() != concat_opt.name()) @@ -31,12 +34,11 @@ void eliminate_concat::apply(program& p) const // axis OR the sizes to the left of this axis are all equal to 1 // Since we've already checked that the non-axis dimensions are identical // we only need to check the first input - auto lens = ins->inputs().front()->get_shape().lens(); - auto concat_op = concat_opt.get_concat(ins->get_operator()); - std::size_t axis_index = - (concat_op.axis < 0) ? (concat_op.axis + lens.size()) : concat_op.axis; + auto lens = ins->inputs().front()->get_shape().lens(); + auto concat_op = concat_opt.get_concat(ins->get_operator()); + std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name()); if(axis_index == 0 || - std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; })) + std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; })) { // Last input should be an allocation auto last = ins->inputs().back(); @@ -58,24 +60,26 @@ void eliminate_concat::apply(program& p) const // Need to sort the allocations, so that we know where to // insert the "super"-allocation - std::sort( - allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) { - return std::distance(p.begin(), x) < std::distance(p.begin(), y); - }); + auto sorted_allocations = allocations; + std::sort(sorted_allocations.begin(), + sorted_allocations.end(), + [&](instruction_ref x, instruction_ref y) { + return std::distance(m.begin(), x) < std::distance(m.begin(), y); + }); // Move "super" allocation to the front - auto first = allocations.front(); - auto super = p.move_instruction(last, first); + auto first = sorted_allocations.front(); + auto super = m.move_instruction(last, first); // Replace each allocation with a load std::size_t offset = 0; for(auto alloc : allocations) { op::load op{alloc->get_shape(), offset}; - p.replace_instruction(alloc, op, {super}); + m.replace_instruction(alloc, op, {super}); offset += alloc->get_shape().bytes(); } std::vector args = {super}; std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); - p.replace_instruction(ins, migraphx::op::identity{}, args); + m.replace_instruction(ins, migraphx::make_op("identity"), args); } } } diff --git a/src/eliminate_contiguous.cpp b/src/eliminate_contiguous.cpp index 1a2fe0f81de25374d0adfefdf6549c92ee3af0f0..217103e54894b31b846a5cd04016bc7ba42581d4 100644 --- a/src/eliminate_contiguous.cpp +++ b/src/eliminate_contiguous.cpp @@ -6,16 +6,19 @@ #include #include #include +#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -static bool try_compute_shape(instruction_ref ins, const std::vector& inputs) +static bool try_compute_shape(instruction_ref ins, + const std::vector& inputs, + const std::vector& mods) { try { - shape new_shape = ins->get_operator().compute_shape(inputs); + shape new_shape = ins->get_operator().compute_shape(inputs, mods); // If the output shape is a standard shape, no need to try its output if(new_shape.standard()) { @@ -45,7 +48,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector& inp return (arg == ins) ? new_shape : arg->get_shape(); }); - if(!try_compute_shape(output, input_shapes)) + if(!try_compute_shape(output, input_shapes, mods)) { return false; } @@ -59,42 +62,60 @@ static bool try_compute_shape(instruction_ref ins, const std::vector& inp return true; } -static bool try_compute_shape(instruction_ref ins, const std::vector& args) +static bool try_compute_shape(instruction_ref ins, + const std::vector& args, + const std::vector& mods) { auto inputs = to_shapes(args); - return try_compute_shape(ins, inputs); + return try_compute_shape(ins, inputs, mods); } -void eliminate_contiguous::apply(program& p) const +void eliminate_contiguous::apply(module& m) const { - for(auto ins : iterator_for(p)) + std::vector const_instruction; + + for(auto ins : iterator_for(m)) { + // return instruction should have inputs with standard shape + if(ins->name() == "@return") + continue; + // Make a copy so we can modify it while we iterate - auto args = ins->inputs(); + auto args = ins->inputs(); + auto new_args = args; + auto mod_args = ins->module_inputs(); + for(auto arg : ins->inputs()) { - // TODO: Pass in names for the operator in the constructor instead - // of using ends_with - if(ends_with(arg->name(), "contiguous")) + if(arg->name() == op_name) { - auto new_args = args; - auto prev = arg->inputs().front(); + auto prev = arg->inputs().front(); replace(new_args, arg, prev); - if(try_compute_shape(ins, new_args)) + if(try_compute_shape(ins, new_args, mod_args)) { instruction::replace_argument(ins, arg, prev); } else if(prev->can_eval()) { - auto c = op::contiguous{}; - auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); - - auto l = p.add_literal(r.get_shape(), r.data()); - p.replace_instruction(arg, l); + const_instruction.push_back(arg); } } } } + + // Perform evaluations in parallel + std::vector literals(const_instruction.size()); + par_for(const_instruction.size(), 1, [&](const auto i) { + auto c = op::contiguous{}; + auto prev = const_instruction[i]->inputs().front(); + literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()}); + }); + + for(size_t i = 0; i < const_instruction.size(); i++) + { + auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); + m.replace_instruction(const_instruction[i], l); + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e51fc2eb9c368b1e9990fb809b6685ba2aa870e0 --- /dev/null +++ b/src/eliminate_data_type.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void eliminate_data_type::apply(module& m) const +{ + static const std::vector skip_op_names = {"convert", + "get_tuple_elem", + "if", + "loop", + "roialign", + "scatternd_add", + "scatternd_mul", + "scatternd_none"}; + for(auto ins : iterator_for(m)) + { + if(ins->name()[0] == '@') + continue; + if(contains(skip_op_names, ins->name())) + continue; + auto inputs = ins->inputs(); + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) { + if(types.count(i->get_shape().type()) == 0) + return i; + return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i); + }); + if(inputs == ins->inputs()) + continue; + auto op = ins->get_operator(); + auto attributes = op.attributes(); + if(attributes.contains("general_data_type")) + { + op = make_op(attributes["general_data_type"].to(), op.to_value()); + } + auto old_type = ins->get_shape().type(); + auto out = m.insert_instruction(ins, op, inputs); + auto convert = + m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out); + m.replace_instruction(ins, convert); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/eliminate_identity.cpp b/src/eliminate_identity.cpp index 5795d59a2bafd9861d38b229f2cfff3d5c48fcfc..de339e36b439d91790e829280d296206a53af59f 100644 --- a/src/eliminate_identity.cpp +++ b/src/eliminate_identity.cpp @@ -8,21 +8,21 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void eliminate_identity::apply(program& p) const +void eliminate_identity::apply(module& m) const { - auto last = std::prev(p.end()); - for(auto ins : iterator_for(p)) + auto last = std::prev(m.end()); + for(auto ins : iterator_for(m)) { // Skip the first instruction, since we always process the previous // instruction - if(ins == p.begin()) + if(ins == m.begin()) continue; const auto i = std::prev(ins); if(i->name() == "identity") { - p.replace_instruction(i, i->inputs().front()); - p.move_instruction(i, p.end()); + m.replace_instruction(i, i->inputs().front()); + m.move_instruction(i, m.end()); } if(ins == last) { @@ -31,7 +31,7 @@ void eliminate_identity::apply(program& p) const const instruction_ref& identity_input = ins->inputs().front(); if(identity_input->outputs().size() == 1) { - p.move_instruction(identity_input, i); + m.move_instruction(identity_input, i); // since this is the last instruction, removing it only // requires changing "last" and calling remove below last = std::prev(last); @@ -40,7 +40,7 @@ void eliminate_identity::apply(program& p) const break; } } - p.remove_instructions(std::next(last), p.end()); + m.remove_instructions(std::next(last), m.end()); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/eliminate_pad.cpp b/src/eliminate_pad.cpp index 524e7434d4f22872d6694ba8f4aa08cc12e6affd..04a0845243c64406eaf67b350fdf615ae9eb3fca 100644 --- a/src/eliminate_pad.cpp +++ b/src/eliminate_pad.cpp @@ -5,51 +5,86 @@ #include #include #include +#include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void eliminate_pad::apply(program& p) const +static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m) { - for(auto ins : iterator_for(p)) - { - const std::string& op_name = ins->name(); - if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling") - continue; - auto input = ins->inputs().front(); - if(input->name() != "pad") - continue; - if(op_name == "convolution") - update_op(op::convolution{}, input, ins, p); - else if(op_name == "im2col") - update_op(op::im2col{}, input, ins, p); - else if(op_name == "pooling") - update_op(op::pooling{}, input, ins, p); - } + auto pad_op = any_cast(input->get_operator()); + + auto kdims = input->get_shape().lens().size() - 2; + auto kdims_it = pad_op.pads.begin() + 2; + + std::vector pads_l(kdims_it, kdims_it + kdims); + std::vector pads_r(kdims_it + kdims + 2, pad_op.pads.end()); + + auto op = ins->get_operator(); + std::vector padding(kdims * 2, 0); + + std::transform( + pads_l.begin(), pads_l.end(), padding.begin(), padding.begin(), std::plus()); + std::transform(pads_r.begin(), + pads_r.end(), + padding.begin() + kdims, + padding.begin() + kdims, + std::plus()); + + op.from_value({{"padding", padding}}); + + std::vector new_inputs{ins->inputs()}; + new_inputs.front() = input->inputs().front(); + + m.replace_instruction(ins, op, new_inputs); } -template -void eliminate_pad::update_op(T, - const instruction_ref& input, - const instruction_ref& ins, - program& p) const +static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) { - auto pad_op = any_cast(input->get_operator()); - if(!pad_op.symmetric()) + auto op = any_cast(ins->get_operator()); + if(op.mode == op::pooling_mode::average) + { return; + } + auto pad_op = any_cast(input->get_operator()); - std::vector pads = pad_op.pads; - std::array new_pads{static_cast(pads[2]), static_cast(pads[3])}; + auto kdims = input->get_shape().lens().size() - 2; + auto kdims_it = pad_op.pads.begin() + 2; - T op = any_cast(ins->get_operator()); - op.padding = new_pads; + std::vector pads_l(kdims_it, kdims_it + kdims); + std::vector pads_r(kdims_it + kdims + 2, pad_op.pads.end()); + + std::transform( + pads_l.begin(), pads_l.end(), op.padding.begin(), op.padding.begin(), std::plus()); + std::transform(pads_r.begin(), + pads_r.end(), + op.padding.begin() + kdims, + op.padding.begin() + kdims, + std::plus()); std::vector new_inputs{ins->inputs()}; new_inputs.front() = input->inputs().front(); - p.replace_instruction(ins, op, new_inputs); + m.replace_instruction(ins, op, new_inputs); +} + +void eliminate_pad::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + const std::string& op_name = ins->name(); + if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling") + continue; + auto input = ins->inputs().front(); + if(input->name() != "pad") + continue; + if(op_name == "convolution" or op_name == "im2col") + update_op(input, ins, m); + else if(op_name == "pooling") + update_pooling(input, ins, m); + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/env.cpp b/src/env.cpp index 204f92b67789231ae2a46135406ab37e55281132..22c29fe342b0b1321328b4c367f010468181501b 100644 --- a/src/env.cpp +++ b/src/env.cpp @@ -29,9 +29,17 @@ std::size_t value_of(const char* name, std::size_t fallback) return std::stoul(e.front()); } +std::string string_value_of(const char* name, std::string fallback) +{ + auto e = env(name); + if(e.empty()) + return fallback; + return e.front(); +} + std::vector env(const char* name) { - auto p = std::getenv(name); + auto* p = std::getenv(name); if(p == nullptr) return {}; else diff --git a/src/file_buffer.cpp b/src/file_buffer.cpp new file mode 100755 index 0000000000000000000000000000000000000000..abefa46f003030d552e798d3d0311a5e81b51db7 --- /dev/null +++ b/src/file_buffer.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +T generic_read_file(const std::string& filename) +{ + std::ifstream is(filename, std::ios::binary | std::ios::ate); + std::streamsize size = is.tellg(); + if(size < 1) + MIGRAPHX_THROW("Invalid size for: " + filename); + is.seekg(0, std::ios::beg); + + T buffer(size, 0); + if(!is.read(&buffer[0], size)) + MIGRAPHX_THROW("Error reading file: " + filename); + return buffer; +} + +std::vector read_buffer(const std::string& filename) +{ + return generic_read_file>(filename); +} + +std::string read_string(const std::string& filename) +{ + return generic_read_file(filename); +} + +void write_buffer(const std::string& filename, const char* buffer, std::size_t size) +{ + std::ofstream os(filename); + os.write(buffer, size); +} +void write_buffer(const std::string& filename, const std::vector& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..799a5c5b3d27087ed3d0e5352df271af42504d18 --- /dev/null +++ b/src/fuse_pointwise.cpp @@ -0,0 +1,166 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +static literal get_scalar(instruction_ref ins) +{ + if(ins->name() == "contiguous") + return get_scalar(ins->inputs().front()); + const auto& s = ins->get_shape(); + if(not(s.elements() == 1 or s.scalar())) + return {}; + if(not ins->can_eval()) + return {}; + auto e = ins->eval(); + literal r{}; + e.visit_at([&](auto x) { r = literal{x}; }); + return r; +} + +static void create_pointwise_modules(module_pass_manager& mpm) +{ + std::size_t n = 0; + for(auto ins : iterator_for(mpm.get_module())) + { + if(not ins->get_operator().attributes().get("pointwise", false)) + continue; + assert(ins->get_operator().attributes().contains("point_op")); + auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++)); + pm->set_bypass(); + + std::unordered_map param_map; + std::vector pointwise_inputs; + std::size_t i = 0; + for(auto input : ins->inputs()) + { + if(contains(param_map, input)) + continue; + auto scalar = get_scalar(input); + if(scalar.empty()) + { + pointwise_inputs.push_back(input); + param_map[input] = + pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()}); + i++; + } + else + { + param_map[input] = pm->add_literal(scalar); + } + } + + std::vector inputs; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(inputs), + [&](auto input) { return param_map[input]; }); + auto r = pm->add_instruction(ins->get_operator(), inputs); + pm->add_return({r}); + + mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm}); + } +} + +static std::vector append_pointwise_module(instruction_ref ins, + instruction_ref output) +{ + assert(contains(output->inputs(), ins)); + module_ref pm = ins->module_inputs().at(0); + module_ref xm = output->module_inputs().at(0); + + auto last = std::prev(pm->end()); + assert(last->name() == "@return"); + assert(last->inputs().size() == 1); + + assert(pm->get_parameter_names().size() == ins->inputs().size()); + assert(xm->get_parameter_names().size() == output->inputs().size()); + + std::vector inputs = ins->inputs(); + std::unordered_map map_ins; + std::unordered_map input_map; + // Copy inputs to input_map + for(auto i : range(inputs.size())) + { + auto input = inputs[i]; + auto param = pm->get_parameter("x" + std::to_string(i)); + assert(param != pm->end()); + input_map[input] = param; + } + // Add the new parameter and additional inputs + for(auto i : range(output->inputs().size())) + { + auto input = output->inputs()[i]; + auto param = xm->get_parameter("x" + std::to_string(i)); + assert(param != xm->end()); + if(input == ins) + { + map_ins[param] = last->inputs().front(); + input_map[input] = map_ins[param]; + } + // Avoid duplicate paramter inputs + else if(contains(input_map, input)) + { + map_ins[param] = input_map[input]; + } + else + { + map_ins[param] = + pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); + inputs.push_back(input); + input_map[input] = map_ins[param]; + } + } + pm->replace_return(pm->insert_module_instructions(last, xm, map_ins)); + return inputs; +} + +static bool find_pointwise_modules(module& m) +{ + bool changed = false; + auto last = std::prev(m.end()); + for(auto ins : iterator_for(m)) + { + if(ins->name() != "pointwise") + continue; + if(ins->outputs().empty() and ins != last) + continue; + auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { + return i->name() == "pointwise" and i->outputs().size() == 1; + }); + if(it == ins->inputs().end()) + continue; + auto input = *it; + + auto new_inputs = append_pointwise_module(input, ins); + m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs()); + m.replace_instruction(ins, input); + m.move_instruction(input, ins); + + changed = true; + } + return changed; +} + +void fuse_pointwise::apply(module_pass_manager& mpm) const +{ + create_pointwise_modules(mpm); + mpm.run_pass(dead_code_elimination{}); + for(int i = 0; i < 8; i++) + { + if(not find_pointwise_modules(mpm.get_module())) + break; + mpm.run_pass(dead_code_elimination{}); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/generate.cpp b/src/generate.cpp index d8aa8960f17315708ac2bbf31cf932c45503e6af..aa875b66185a9d98a8f4bf3990f58ed9ca90df2e 100644 --- a/src/generate.cpp +++ b/src/generate.cpp @@ -6,22 +6,59 @@ inline namespace MIGRAPHX_INLINE_NS { argument fill_argument(shape s, unsigned long value) { argument result; - s.visit_type([&](auto as) { - using type = typename decltype(as)::type; - auto v = fill_tensor_data(s, value); - result = {s, v}; - }); + if(s.type() == shape::tuple_type) + { + std::vector sub_args; + const auto& sub_ss = s.sub_shapes(); + std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) { + return fill_argument(ss, value); + }); + + result = argument(sub_args); + } + else + { + s.visit_type([&](auto as) { + using type = typename decltype(as)::type; + auto v = fill_tensor_data(s, value); + result = {s, v}; + }); + } return result; } argument generate_argument(shape s, unsigned long seed) { argument result; - s.visit_type([&](auto as) { - using type = typename decltype(as)::type; - auto v = generate_tensor_data(s, seed); - result = {s, v}; - }); + if(s.type() == shape::tuple_type) + { + const auto& sub_ss = s.sub_shapes(); + std::vector sub_args; + std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) { + return generate_argument(ss, seed); + }); + + result = argument(sub_args); + } + else + { + s.visit_type([&](auto as) { + // we use char type to store bool type internally, so bool_type + // needs special processing to generate data + if(s.type() == shape::bool_type) + { + auto v = generate_tensor_data(s, seed); + result = {s, v}; + } + else + { + using type = typename decltype(as)::type; + auto v = generate_tensor_data(s, seed); + result = {s, v}; + } + }); + } + return result; } diff --git a/src/targets/gpu/include/migraphx/gpu/adjust_allocation.hpp b/src/include/migraphx/adjust_allocation.hpp old mode 100644 new mode 100755 similarity index 59% rename from src/targets/gpu/include/migraphx/gpu/adjust_allocation.hpp rename to src/include/migraphx/adjust_allocation.hpp index dfefc348529472330b69394df0ce7df9186251ad..5c09a9534d543efc01f94c025b6628096bd6f05e --- a/src/targets/gpu/include/migraphx/gpu/adjust_allocation.hpp +++ b/src/include/migraphx/adjust_allocation.hpp @@ -1,22 +1,21 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP #define MIGRAPHX_GUARD_RTGLIB_ADJUST_ALLOCATION_HPP -#include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { +struct module; struct adjust_allocation { - std::string name() const { return "gpu::adjust_allocation"; } - void apply(program& p) const; + allocation_model model; + std::string name() const { return "adjust_allocation"; } + void apply(module& m) const; }; -} // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/algorithm.hpp b/src/include/migraphx/algorithm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..72daa611da57c5d16941c85283f185ac881537cc --- /dev/null +++ b/src/include/migraphx/algorithm.hpp @@ -0,0 +1,57 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP +#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f) +{ + while(start != last) + { + if(pred(*start)) + { + *out = f(*start); + ++out; + } + ++start; + } +} + +template +T transform_accumulate(Iterator first, Iterator last, T init, BinaryOp binop, UnaryOp unaryop) +{ + return std::inner_product( + first, last, first, init, binop, [&](auto&& x, auto&&) { return unaryop(x); }); +} + +template +void group_by(Iterator start, Iterator last, Output out, Predicate pred) +{ + while(start != last) + { + auto it = std::partition(start, last, [&](auto&& x) { return pred(x, *start); }); + out(start, it); + start = it; + } +} + +template +void group_unique(Iterator start, Iterator last, Output out, Predicate pred) +{ + while(start != last) + { + auto it = std::find_if(start, last, [&](auto&& x) { return not pred(*start, x); }); + out(start, it); + start = it; + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/allocation_model.hpp b/src/include/migraphx/allocation_model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af0b562c2761e8d9b6b66679338f34e408282944 --- /dev/null +++ b/src/include/migraphx/allocation_model.hpp @@ -0,0 +1,291 @@ +#ifndef MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP +#define MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// An interface for target-dependent allocation +struct allocation_model +{ + /// A name of the target-dependent allocate operator + std::string name() const; + /// A name of the target-dependent copy operator + std::string copy() const; + /// Create an allocation operator for the given shape + operation allocate(const shape& s) const; + /// Create a preallocated operator for the given shape + operation preallocate(const shape& s, const std::string& id) const; + /// Check if outputs are to be inserted + bool needs_out_params() const; +}; + +#else + +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct allocation_model +{ + // + std::string name() const; + // + std::string copy() const; + // + operation allocate(const shape& s) const; + // + operation preallocate(const shape& s, std::string id) const; + // + bool needs_out_params() const; +}; + +#else + +struct allocation_model +{ + // Constructors + allocation_model() = default; + + template + allocation_model(PrivateDetailTypeErasedT value) + : private_detail_te_handle_mem_var( + std::make_shared::type>>( + std::forward(value))) + { + } + + // Assignment + template + allocation_model& operator=(PrivateDetailTypeErasedT value) + { + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + allocation_model rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } + return *this; + } + + // Cast + template + PrivateDetailTypeErasedT* any_cast() + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + template + const typename std::remove_cv::type* any_cast() const + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + const std::type_info& type_id() const + { + if(private_detail_te_handle_empty()) + return typeid(std::nullptr_t); + else + return private_detail_te_get_handle().type(); + } + + std::string name() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().name(); + } + + std::string copy() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().copy(); + } + + operation allocate(const shape& s) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().allocate(s); + } + + operation preallocate(const shape& s, std::string id) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().preallocate(s, std::move(id)); + } + + bool needs_out_params() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().needs_out_params(); + } + + friend bool is_shared(const allocation_model& private_detail_x, + const allocation_model& private_detail_y) + { + return private_detail_x.private_detail_te_handle_mem_var == + private_detail_y.private_detail_te_handle_mem_var; + } + + private: + struct private_detail_te_handle_base_type + { + virtual ~private_detail_te_handle_base_type() {} + virtual std::shared_ptr clone() const = 0; + virtual const std::type_info& type() const = 0; + + virtual std::string name() const = 0; + virtual std::string copy() const = 0; + virtual operation allocate(const shape& s) const = 0; + virtual operation preallocate(const shape& s, std::string id) const = 0; + virtual bool needs_out_params() const = 0; + }; + + template + struct private_detail_te_handle_type : private_detail_te_handle_base_type + { + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value>::type* = + nullptr) + : private_detail_te_value(value) + { + } + + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value, + int>::type* = nullptr) noexcept + : private_detail_te_value(std::move(value)) + { + } + + std::shared_ptr clone() const override + { + return std::make_shared(private_detail_te_value); + } + + const std::type_info& type() const override { return typeid(private_detail_te_value); } + + std::string name() const override { return private_detail_te_value.name(); } + + std::string copy() const override { return private_detail_te_value.copy(); } + + operation allocate(const shape& s) const override + { + + return private_detail_te_value.allocate(s); + } + + operation preallocate(const shape& s, std::string id) const override + { + + return private_detail_te_value.preallocate(s, std::move(id)); + } + + bool needs_out_params() const override + { + + return private_detail_te_value.needs_out_params(); + } + + PrivateDetailTypeErasedT private_detail_te_value; + }; + + template + struct private_detail_te_handle_type> + : private_detail_te_handle_type + { + private_detail_te_handle_type(std::reference_wrapper ref) + : private_detail_te_handle_type(ref.get()) + { + } + }; + + bool private_detail_te_handle_empty() const + { + return private_detail_te_handle_mem_var == nullptr; + } + + const private_detail_te_handle_base_type& private_detail_te_get_handle() const + { + assert(private_detail_te_handle_mem_var != nullptr); + return *private_detail_te_handle_mem_var; + } + + private_detail_te_handle_base_type& private_detail_te_get_handle() + { + assert(private_detail_te_handle_mem_var != nullptr); + if(!private_detail_te_handle_mem_var.unique()) + private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); + return *private_detail_te_handle_mem_var; + } + + std::shared_ptr private_detail_te_handle_mem_var; +}; + +template +inline const ValueType* any_cast(const allocation_model* x) +{ + return x->any_cast(); +} + +template +inline ValueType* any_cast(allocation_model* x) +{ + return x->any_cast(); +} + +template +inline ValueType& any_cast(allocation_model& x) +{ + auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} + +template +inline const ValueType& any_cast(const allocation_model& x) +{ + const auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} +#endif + +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/analyze_streams.hpp b/src/include/migraphx/analyze_streams.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd0b245209478f657ef52fe1f009d8d1ab3a4d8b --- /dev/null +++ b/src/include/migraphx/analyze_streams.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ANALYZE_STREAMS_HPP +#define MIGRAPHX_GUARD_RTGLIB_ANALYZE_STREAMS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct stream_race +{ + instruction_ref ins; + // The instruction that should before + instruction_ref before; +}; + +std::vector analyze_streams(const module& m, const stream_model& strmm); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/any_ptr.hpp b/src/include/migraphx/any_ptr.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ce3177b7c4b4651573a7a25ebfb24fe9b3062f89 --- /dev/null +++ b/src/include/migraphx/any_ptr.hpp @@ -0,0 +1,61 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct any_ptr +{ + any_ptr() = default; + template + any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name()) + { + } + + any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {} + + void* get(std::string_view n) const + { + if(name != n) + MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + + " != " + std::string{n}); + return ptr; + } + + template + T get() const + { + static_assert(std::is_pointer{}, "Must be a pointer"); + assert(ptr != nullptr); + if(ti and std::type_index{typeid(T)} != *ti) + MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name()); + else if(name != get_name()) + MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name()); + return reinterpret_cast(ptr); + } + void* unsafe_get() const { return ptr; } + + private: + void* ptr = nullptr; + optional ti = nullopt; + std::string_view name = ""; + + template + static const std::string& get_name() + { + return get_type_name>>(); + } +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP diff --git a/src/include/migraphx/apply_alpha_beta.hpp b/src/include/migraphx/apply_alpha_beta.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b48a80a6a222254252cecc305a6e82e7c339cdb --- /dev/null +++ b/src/include/migraphx/apply_alpha_beta.hpp @@ -0,0 +1,43 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_APPLY_ALPHA_BETA_HPP + +#include "migraphx/make_op.hpp" +#include "migraphx/normalize_attributes.hpp" +#include "migraphx/operation.hpp" +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +instruction_ref insert_apply_alpha_beta(module& m, + instruction_ref pos, + const std::vector& args, + const operation& op, + const literal& alpha, + const literal& beta); + +template +instruction_ref insert_apply_alpha_beta(module& m, + instruction_ref pos, + const std::vector& args, + const operation& op, + T alpha = 1, + T beta = 0) +{ + return insert_apply_alpha_beta(m, pos, args, op, literal{T{alpha}}, literal{T{beta}}); +} + +template +instruction_ref add_apply_alpha_beta(module& m, + const std::vector& args, + const operation& op, + T alpha = 1, + T beta = 0) +{ + return insert_apply_alpha_beta(m, m.end(), args, op, alpha, beta); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_APPLY_ALPHA_BETA_HPP diff --git a/src/include/migraphx/argument.hpp b/src/include/migraphx/argument.hpp old mode 100644 new mode 100755 index f591d72110ce4dacdd3f73c8fe04bdc5ae4b743c..f89a46143140a4aad4a7d5f1a9b407847e5c48ce --- a/src/include/migraphx/argument.hpp +++ b/src/include/migraphx/argument.hpp @@ -4,9 +4,11 @@ #include #include #include +#include #include #include +// clang-format off namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -19,61 +21,74 @@ inline namespace MIGRAPHX_INLINE_NS { */ struct argument : raw_data { - argument() {} + argument() = default; - argument(const shape& s) : m_shape(s) - { - std::vector buffer(s.bytes()); - // TODO: Move vector - data = [=]() mutable { return buffer.data(); }; - } + argument(const shape& s); template ()())>{})> argument(shape s, F d) - : data([f = std::move(d)]() mutable { return reinterpret_cast(f()); }), - m_shape(std::move(s)) + : m_shape(std::move(s)) + { + assign_buffer([f = std::move(d)]() mutable { return reinterpret_cast(f()); }); } template argument(shape s, T* d) - : data([d] { return reinterpret_cast(d); }), m_shape(std::move(s)) + : m_shape(std::move(s)) { + assign_buffer([d] { return reinterpret_cast(d); }); } template argument(shape s, std::shared_ptr d) - : data([d] { return reinterpret_cast(d.get()); }), m_shape(std::move(s)) + : m_shape(std::move(s)) { + assign_buffer([d] { return reinterpret_cast(d.get()); }); } - argument(shape s, std::nullptr_t) : data([] { return nullptr; }), m_shape(std::move(s)) {} + argument(shape s, std::nullptr_t); + + argument(const std::vector& args); /// Provides a raw pointer to the data - std::function data = nullptr; + char* data() const; /// Whether data is available - bool empty() const { return not data; } + bool empty() const; - const shape& get_shape() const { return this->m_shape; } + const shape& get_shape() const; - argument reshape(const shape& s) const - { - argument self = *this; - return {s, [=]() mutable { return self.data(); }}; - } + argument reshape(const shape& s) const; + + argument copy() const; /// Make copy of the argument that is always sharing the data - argument share() const - { - auto self = std::make_shared(*this); - return {m_shape, [self]() mutable { return self->data(); }}; - } + argument share() const; + + std::vector get_sub_objects() const; + + /// Return the ith element + argument element(std::size_t i) const; private: + void assign_buffer(std::function d); + struct data_t + { + std::function get = nullptr; + std::vector sub = {}; + data_t share() const; + static data_t from_args(const std::vector& args); + }; + argument(const shape& s, const data_t& d); shape m_shape; + data_t m_data{}; }; +void migraphx_to_value(value& v, const argument& a); +void migraphx_from_value(const value& v, argument& a); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx +// clang-format on #endif diff --git a/src/include/migraphx/assert.hpp b/src/include/migraphx/assert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..51e4cadec64eff5eaa18fe469d630fe368b71d39 --- /dev/null +++ b/src/include/migraphx/assert.hpp @@ -0,0 +1,38 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +auto abort_on_throw(F f) -> decltype(f()) +{ + try + { + return f(); + } + catch(const std::exception& e) + { + std::cerr << e.what() << std::endl; + std::abort(); + } + catch(...) + { + std::cerr << "Unknown exception" << std::endl; + std::abort(); + } +} +#ifdef NDEBUG +#define MIGRAPHX_ASSERT_NO_THROW(...) __VA_ARGS__ +#else +#define MIGRAPHX_ASSERT_NO_THROW(...) \ + migraphx::abort_on_throw([&]() -> decltype(__VA_ARGS__) { return __VA_ARGS__; }) +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSERT_HPP diff --git a/src/include/migraphx/auto_contiguous.hpp b/src/include/migraphx/auto_contiguous.hpp index 065626a05af86e537b78e16cab91b43b86d3da36..8576010b90896d12fdf411e3cc622c9bdd6a722d 100644 --- a/src/include/migraphx/auto_contiguous.hpp +++ b/src/include/migraphx/auto_contiguous.hpp @@ -8,12 +8,12 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; struct auto_contiguous { std::string name() const { return "auto_contiguous"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/auto_register.hpp b/src/include/migraphx/auto_register.hpp new file mode 100644 index 0000000000000000000000000000000000000000..67a455aa6289ce83ada14d616789162c6076a91f --- /dev/null +++ b/src/include/migraphx/auto_register.hpp @@ -0,0 +1,50 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP +#define MIGRAPHX_GUARD_RTGLIB_AUTO_REGISTER_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +int auto_register_action() +{ + Action::template apply(); + return 0; +} + +template +struct auto_register +{ + const static int static_register; + // This typedef ensures that the static member will be instantiated if + // the class itself is instantiated + using static_register_type = + std::integral_constant; +}; + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +template +const int auto_register::static_register = auto_register_action(); // NOLINT + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x +#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) +// NOLINTNEXTLINE +#define MIGRAPHX_AUTO_REGISTER(...) \ + void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \ + migraphx::auto_register<__VA_ARGS__>{}) \ + __attribute__((unused)); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/builtin.hpp b/src/include/migraphx/builtin.hpp old mode 100644 new mode 100755 index 585ce1630635107f56d8f43106819009ce57d8d1..cc1d2b22825d6af1cc44291a56aa0e75986ac961 --- a/src/include/migraphx/builtin.hpp +++ b/src/include/migraphx/builtin.hpp @@ -43,6 +43,7 @@ struct outline struct param { std::string parameter; + uint32_t order = 0; template static auto reflect(Self& self, F f) @@ -63,6 +64,16 @@ struct param } }; +struct returns +{ + std::string name() const { return "@return"; } + shape compute_shape(const std::vector&) const { return {}; } + argument compute(context&, const shape&, const std::vector&) const + { + MIGRAPHX_THROW("builtin"); + } +}; + } // namespace builtin } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/check_context.hpp b/src/include/migraphx/check_context.hpp index ab36dd500896fc4fd7f1a2408d076595d9343d9d..77743fd3d94e78add5b15400574bad6d11c2184c 100644 --- a/src/include/migraphx/check_context.hpp +++ b/src/include/migraphx/check_context.hpp @@ -3,6 +3,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -10,9 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS { template struct check_context { - struct op + struct op : auto_register_op { - std::string name() const { return "check_context"; } + std::string name() const { return "check_context::" + get_type_name(); } shape compute_shape(const std::vector&) const { return {}; } argument compute(context& ctx, const shape&, const std::vector&) const { @@ -32,7 +33,7 @@ struct check_context }; std::string name() const { return "check_context"; } - void apply(program& p) const { p.insert_instruction(p.begin(), op{}); } + void apply(module& m) const { m.insert_instruction(m.begin(), op{}); } }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/check_shapes.hpp b/src/include/migraphx/check_shapes.hpp old mode 100644 new mode 100755 index fb7dd5eeac2708992b80c388bdf90d892f08965e..350d523256b8b55488d775b5a0bef019990b922e --- a/src/include/migraphx/check_shapes.hpp +++ b/src/include/migraphx/check_shapes.hpp @@ -25,8 +25,6 @@ struct check_shapes { } - check_shapes(const std::vector& s) : begin(s.data()), end(s.data() + s.size()) {} - template check_shapes(const std::vector& s, const Op& op) : begin(s.data()), end(s.data() + s.size()), name(op.name()) @@ -59,6 +57,13 @@ struct check_shapes return *this; } + const check_shapes& nelements(std::size_t n) const + { + if(!this->all_of([&](const shape& s) { return s.elements() == n; })) + MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements"); + return *this; + } + const check_shapes& only_dims(std::size_t n) const { assert(begin != nullptr); @@ -71,6 +76,32 @@ struct check_shapes return *this; } + const check_shapes& max_ndims(std::size_t n) const + { + assert(begin != nullptr); + assert(end != nullptr); + if(begin != end) + { + if(begin->lens().size() > n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + + " dimensions"); + } + return *this; + } + + const check_shapes& min_ndims(std::size_t n) const + { + assert(begin != nullptr); + assert(end != nullptr); + if(begin != end) + { + if(begin->lens().size() < n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + + " dimensions"); + } + return *this; + } + const check_shapes& same_shape() const { if(!this->same([](const shape& s) { return s; })) @@ -120,6 +151,20 @@ struct check_shapes return *this; } + const check_shapes& packed_or_broadcasted() const + { + if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) + MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted"); + return *this; + } + + const check_shapes& tuple_type() const + { + if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) + MIGRAPHX_THROW(prefix() + "Shapes are not tuple!"); + return *this; + } + const check_shapes& not_transposed() const { if(!this->all_of([](const shape& s) { return not s.transposed(); })) @@ -141,6 +186,13 @@ struct check_shapes return *this; } + const check_shapes& batch_not_transposed() const + { + if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) + MIGRAPHX_THROW(prefix() + "Batch size is transposed"); + return *this; + } + template bool same(F f) const { @@ -162,7 +214,7 @@ struct check_shapes return std::all_of(begin, end, p); } - const shape* get(long i) + const shape* get(long i) const { if(i >= size()) MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds"); @@ -173,9 +225,31 @@ struct check_shapes return begin + i; } - check_shapes slice(long start) { return {get(start), end, name}; } + check_shapes slice(long start) const { return {get(start), end, name}; } + + check_shapes slice(long start, long last) const { return {get(start), get(last), name}; } - check_shapes slice(long start, long last) { return {get(start), get(last), name}; } + private: + static bool batch_not_transposed_strides(const std::vector& strides) + { + if(strides.size() <= 2) + return true; + auto dim_0 = strides.size() - 2; + auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]); + std::vector batch(strides.begin(), strides.begin() + dim_0); + if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); })) + { + return false; + } + + if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) { + return (i < j or i < matrix_size or j < matrix_size); + }) != batch.end()) + { + return false; + } + return true; + } }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/clamp.hpp b/src/include/migraphx/clamp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19ba71bd08ba3caf31e91fe430c2544e5991d5c9 --- /dev/null +++ b/src/include/migraphx/clamp.hpp @@ -0,0 +1,25 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP +#define MIGRAPHX_GUARD_RTGLIB_CLAMP_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +U pad_clamp(T x) +{ + if(float_equal(x, std::numeric_limits::lowest())) + return std::numeric_limits::lowest(); + if(float_equal(x, std::numeric_limits::max())) + return std::numeric_limits::max(); + return (x < std::numeric_limits::lowest()) + ? std::numeric_limits::lowest() + : (std::numeric_limits::max() < x) ? std::numeric_limits::max() : U(x); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/cloneable.hpp b/src/include/migraphx/cloneable.hpp new file mode 100755 index 0000000000000000000000000000000000000000..a8d59194d279cdcfc05fb0a3a801585eb1371073 --- /dev/null +++ b/src/include/migraphx/cloneable.hpp @@ -0,0 +1,49 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_CLONEABLE_HPP +#define MIGRAPHX_GUARD_RTGLIB_CLONEABLE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +struct cloneable +{ + friend Base; + + virtual std::shared_ptr clone() = 0; + + template + struct derive : Base + { + friend Derived; + + std::shared_ptr clone() override + { + return std::make_shared(static_cast(*this)); + } + template + derive(Args&&... args) : Base(std::forward(args)...) + { + } + }; + + struct share : Base, std::enable_shared_from_this + { + std::shared_ptr clone() override { return this->shared_from_this(); } + template + share(Args&&... args) : Base(std::forward(args)...) + { + } + }; + cloneable() = default; + cloneable(const cloneable&) = default; + cloneable& operator=(const cloneable&) = default; + virtual ~cloneable() {} +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/common.hpp b/src/include/migraphx/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4a387b15116fdc55a95c32a937ed0f9df2bba71 --- /dev/null +++ b/src/include/migraphx/common.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; +struct operation; + +std::vector compute_broadcasted_lens(std::vector s0, + std::vector s1); +shape common_shape(const std::vector& shapes); + +instruction_ref insert_common_op(module& m, + instruction_ref ins, + const operation& op, + std::vector inputs); +instruction_ref add_common_op(module& m, const operation& op, std::vector inputs); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_HPP diff --git a/src/include/migraphx/compile_options.hpp b/src/include/migraphx/compile_options.hpp index 78221833f08162871a579e731f3f81c0113ecbc4..f76438c2667dcce5687883fe94cd4d219f66af73 100644 --- a/src/include/migraphx/compile_options.hpp +++ b/src/include/migraphx/compile_options.hpp @@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { struct compile_options { bool offload_copy = false; + bool fast_math = true; tracer trace{}; }; diff --git a/src/include/migraphx/compile_src.hpp b/src/include/migraphx/compile_src.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b0ad8d5127e7266d2deb02f87a486001e1596489 --- /dev/null +++ b/src/include/migraphx/compile_src.hpp @@ -0,0 +1,34 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct src_file +{ + fs::path path; + std::pair content; + std::size_t len() const { return content.second - content.first; } +}; + +struct src_compiler +{ + std::string compiler = "c++"; + std::string flags = ""; + std::string output = ""; + std::string launcher = ""; + std::string out_ext = ".o"; + std::function process = nullptr; + std::vector compile(const std::vector& srcs) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_COMPILE_SRC_HPP diff --git a/src/include/migraphx/concat_opt.hpp b/src/include/migraphx/concat_opt.hpp index e1ad1cfbaeb6d8292fc436f6f707c603327ec2d3..fe93cbc10c41fc37e389d0b6e20c3703fc195b65 100644 --- a/src/include/migraphx/concat_opt.hpp +++ b/src/include/migraphx/concat_opt.hpp @@ -15,8 +15,6 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; - #ifdef DOXYGEN /// An interface for target-dependent optimization for the concat instruction @@ -32,17 +30,20 @@ struct concat_optimization #else -/* - * Type-erased interface for: - * - * struct concat_optimization - * { - * std::string name() const; - * std::string allocate() const; - * op::concat get_concat(const operation& op) const; - * }; - * - */ +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct concat_optimization +{ + // + std::string name() const; + // + std::string allocate() const; + // + op::concat get_concat(const operation& op) const; +}; + +#else struct concat_optimization { @@ -62,11 +63,17 @@ struct concat_optimization template concat_optimization& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + concat_optimization rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -74,7 +81,7 @@ struct concat_optimization template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -85,7 +92,7 @@ struct concat_optimization template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -240,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x) throw std::bad_cast(); return *y; } +#endif #endif diff --git a/src/include/migraphx/config.hpp b/src/include/migraphx/config.hpp old mode 100644 new mode 100755 index 6230a6ef74300f993d3d71fcda9064830af211c5..974c302476c19002fa9b7aeae7c88e3f0f649e3e --- a/src/include/migraphx/config.hpp +++ b/src/include/migraphx/config.hpp @@ -7,6 +7,16 @@ namespace migraphx { #define MIGRAPHX_INLINE_NS version_1 #endif +#ifdef DOXYGEN +#define MIGRAPHX_INLINE_NS internal +#endif + +#ifdef MIGRAPHX_USE_CLANG_TIDY +#define MIGRAPHX_TIDY_CONST const +#else +#define MIGRAPHX_TIDY_CONST +#endif + } // namespace migraphx #endif diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index 5c0cab49f9f75f071759d43e50ec20446189c998..b0ab8dbe0b1581e70e6e682bf6904a25163858d7 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -25,15 +27,39 @@ struct context #else -/* - * Type-erased interface for: - * - * struct context - * { - * void finish() const; - * }; - * - */ +template +value to_value_context(const T&) +{ + return value{}; +} + +template +void from_value_context(T&, const value&) +{ +} + +template +any_ptr get_queue_context(T&) +{ + return {}; +} + +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct context +{ + // (optional) + value to_value() const; + // (optional) + void from_value(const value& v); + // (optional) + any_ptr get_queue(); + // + void finish() const; +}; + +#else struct context { @@ -53,11 +79,17 @@ struct context template context& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + context rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -65,7 +97,7 @@ struct context template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -76,7 +108,7 @@ struct context template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -92,6 +124,24 @@ struct context return private_detail_te_get_handle().type(); } + value to_value() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().to_value(); + } + + void from_value(const value& v) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().from_value(v); + } + + any_ptr get_queue() + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_queue(); + } + void finish() const { assert((*this).private_detail_te_handle_mem_var); @@ -111,9 +161,53 @@ struct context virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; - virtual void finish() const = 0; + virtual value to_value() const = 0; + virtual void from_value(const value& v) = 0; + virtual any_ptr get_queue() = 0; + virtual void finish() const = 0; }; + template + static auto private_detail_te_default_to_value(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.to_value()) + { + return private_detail_te_self.to_value(); + } + + template + static value private_detail_te_default_to_value(float, T&& private_detail_te_self) + { + return to_value_context(private_detail_te_self); + } + + template + static auto + private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v) + -> decltype(private_detail_te_self.from_value(v)) + { + private_detail_te_self.from_value(v); + } + + template + static void + private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v) + { + from_value_context(private_detail_te_self, v); + } + + template + static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.get_queue()) + { + return private_detail_te_self.get_queue(); + } + + template + static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self) + { + return get_queue_context(private_detail_te_self); + } + template struct private_detail_te_handle_type : private_detail_te_handle_base_type { @@ -142,6 +236,24 @@ struct context const std::type_info& type() const override { return typeid(private_detail_te_value); } + value to_value() const override + { + + return private_detail_te_default_to_value(char(0), private_detail_te_value); + } + + void from_value(const value& v) override + { + + private_detail_te_default_from_value(char(0), private_detail_te_value, v); + } + + any_ptr get_queue() override + { + + return private_detail_te_default_get_queue(char(0), private_detail_te_value); + } + void finish() const override { private_detail_te_value.finish(); } PrivateDetailTypeErasedT private_detail_te_value; @@ -208,6 +320,10 @@ inline const ValueType& any_cast(const context& x) throw std::bad_cast(); return *y; } +#endif + +inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } +inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } #endif diff --git a/src/include/migraphx/convert_to_json.hpp b/src/include/migraphx/convert_to_json.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9cbd9d34f14bec740edd8a11d69d1cb35492122d --- /dev/null +++ b/src/include/migraphx/convert_to_json.hpp @@ -0,0 +1,15 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_CONVERT_TO_JSON_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_CONVERT_TO_JSON_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string convert_to_json(const std::string& str); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/cpp_generator.hpp b/src/include/migraphx/cpp_generator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8d8f965bba7177b60ea65c5fff399579d7730741 --- /dev/null +++ b/src/include/migraphx/cpp_generator.hpp @@ -0,0 +1,91 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct operation; +struct module; +struct shape; + +struct cpp_generator_impl; + +struct cpp_generator +{ + using generate_module_callback = std::function&)>; + struct param + { + std::string name; + std::string type; + }; + + struct function + { + std::vector params = {}; + std::string body = ""; + std::string return_type = "void"; + std::string name = ""; + std::vector attributes = {}; + std::vector tparams = {}; + function& set_body(const module& m, const generate_module_callback& g); + function& set_body(const std::string& s) + { + body = s; + return *this; + } + function& set_name(const std::string& s) + { + name = s; + return *this; + } + function& set_attributes(std::vector attrs) + { + attributes = std::move(attrs); + return *this; + } + function& set_types(const module& m); + function& set_types(const module& m, const std::function& parse); + function& set_generic_types(const module& m); + }; + + cpp_generator(); + + // move constructor + cpp_generator(cpp_generator&&) noexcept; + + // copy assignment operator + cpp_generator& operator=(cpp_generator rhs); + + ~cpp_generator() noexcept; + + void fmap(const std::function& f); + + void fresult(const std::function& f); + + void add_point_op(const std::string& op_name, const std::string& code); + + std::string generate_point_op(const operation& op, const std::vector& args); + + std::string str() const; + + function generate_module(const module& m, const generate_module_callback& g); + + function generate_module(const module& m); + + std::string create_function(const function& f); + + private: + std::unique_ptr impl; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_CPP_GENERATOR_HPP diff --git a/src/include/migraphx/dead_code_elimination.hpp b/src/include/migraphx/dead_code_elimination.hpp index 2413e202e3a41771bc22de47ccd82290205cc64a..95c1a868870b1497fe334634360eea4a6ccc3a59 100644 --- a/src/include/migraphx/dead_code_elimination.hpp +++ b/src/include/migraphx/dead_code_elimination.hpp @@ -8,6 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct module; struct program; /** @@ -16,6 +17,7 @@ struct program; struct dead_code_elimination { std::string name() const { return "dead_code_elimination"; } + void apply(module& m) const; void apply(program& p) const; }; diff --git a/src/include/migraphx/dom_info.hpp b/src/include/migraphx/dom_info.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0c920da5600836a6c3b3eebd59aa52091f62c16f --- /dev/null +++ b/src/include/migraphx/dom_info.hpp @@ -0,0 +1,27 @@ + +#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP +#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct dominator_info +{ + bool strictly_dominate(instruction_ref ins1, instruction_ref ins2); + + std::unordered_map ins2idom; +}; + +dominator_info compute_dominator(module& m); +// dominator_info compute_dominator_naive(const module& m); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/dynamic_loader.hpp b/src/include/migraphx/dynamic_loader.hpp new file mode 100755 index 0000000000000000000000000000000000000000..939a3712b331a67e9e08b61f1c00aef341544a09 --- /dev/null +++ b/src/include/migraphx/dynamic_loader.hpp @@ -0,0 +1,43 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct dynamic_loader_impl; + +struct dynamic_loader +{ + dynamic_loader() = default; + + dynamic_loader(const fs::path& p); + + dynamic_loader(const char* image, std::size_t size); + + dynamic_loader(const std::vector& buffer); + + std::shared_ptr get_symbol(const std::string& name) const; + + template + std::function get_function(const std::string& name) const + { + auto s = get_symbol(name); + return [=](auto&&... xs) -> decltype(auto) { + auto f = reinterpret_cast>(s.get()); + return f(std::forward(xs)...); + }; + } + + private: + std::shared_ptr impl; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_DYNAMIC_LOADER_HPP diff --git a/src/include/migraphx/eliminate_allocation.hpp b/src/include/migraphx/eliminate_allocation.hpp index b541c85f0dfc228d6a1fd8ca5b24162c8b6e5f78..c9cb6573aac205a12536d4655e7b252469e71548 100644 --- a/src/include/migraphx/eliminate_allocation.hpp +++ b/src/include/migraphx/eliminate_allocation.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove memory allocations. This will create a parameter which is the max of all memory used in @@ -19,7 +19,7 @@ struct eliminate_allocation std::string allocation_op{}; std::size_t alignment = 32; std::string name() const { return "eliminate_allocation"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/eliminate_common_subexpression.hpp b/src/include/migraphx/eliminate_common_subexpression.hpp index 0f44ebeb9492fd9f38297ce021392d3de82a4354..26c96f98e92dc29a0474d2f306b0dbdc7124a6b5 100644 --- a/src/include/migraphx/eliminate_common_subexpression.hpp +++ b/src/include/migraphx/eliminate_common_subexpression.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove identical instructions. @@ -16,7 +16,7 @@ struct program; struct eliminate_common_subexpression { std::string name() const { return "eliminate_common_subexpression"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/eliminate_concat.hpp b/src/include/migraphx/eliminate_concat.hpp index 5f0029db8b3fd6398e0c53c2bf4be4440e7c4250..b9cbefa7a5c60abe7587f214d39ff3b88b1d72c9 100644 --- a/src/include/migraphx/eliminate_concat.hpp +++ b/src/include/migraphx/eliminate_concat.hpp @@ -9,7 +9,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove concat operators by having each operator can write to different chunk of memory. @@ -18,7 +18,7 @@ struct eliminate_concat { concat_optimization concat_opt; std::string name() const { return "eliminate_concat"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/eliminate_contiguous.hpp b/src/include/migraphx/eliminate_contiguous.hpp old mode 100644 new mode 100755 index b809ccceb6bb4fc16fce3afc23bf3a13f61435a8..95b44f0f3d14b6dacad994515eb54f11a62498b7 --- a/src/include/migraphx/eliminate_contiguous.hpp +++ b/src/include/migraphx/eliminate_contiguous.hpp @@ -8,15 +8,16 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove contiguous instructions by checking if the operator can use non-standard shapes. */ struct eliminate_contiguous { + std::string op_name; std::string name() const { return "eliminate_contiguous"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/eliminate_data_type.hpp b/src/include/migraphx/eliminate_data_type.hpp new file mode 100755 index 0000000000000000000000000000000000000000..8ebf88393f38ec05c9c28ebe8d1462e537520ff3 --- /dev/null +++ b/src/include/migraphx/eliminate_data_type.hpp @@ -0,0 +1,29 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_DATA_TYPE_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_DATA_TYPE_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * Remove data types. This will instert convert operators so the data type + * is not used by any operator. + */ +struct eliminate_data_type +{ + std::set types; + shape::type_t target_type; + std::string name() const { return "eliminate_data_type"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/eliminate_identity.hpp b/src/include/migraphx/eliminate_identity.hpp index a8267832c4707979aeaceeaef69b9fcc527d1cb5..a2e23c6bbddbd8f140db8674dece3d8fd98c5b5e 100644 --- a/src/include/migraphx/eliminate_identity.hpp +++ b/src/include/migraphx/eliminate_identity.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove identity instructions. Currently when used as the last pass, it will @@ -18,7 +18,7 @@ struct program; struct eliminate_identity { std::string name() const { return "eliminate_identity"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/eliminate_pad.hpp b/src/include/migraphx/eliminate_pad.hpp index 427f272921606a5510497d117cba53db28d65915..316adf0c61f07c3c716499f2a84c81842d5fc9c7 100644 --- a/src/include/migraphx/eliminate_pad.hpp +++ b/src/include/migraphx/eliminate_pad.hpp @@ -10,7 +10,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove pads if they can be written as an @@ -19,9 +19,8 @@ struct program; struct eliminate_pad { std::string name() const { return "eliminate_pad"; } - void apply(program& p) const; - template - void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const; + + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/env.hpp b/src/include/migraphx/env.hpp index 8504ea48bf8cf33ad8c4bd9f190541d45e0534e3..94a3da30485ad559fa7267063f15996661ce8cac 100644 --- a/src/include/migraphx/env.hpp +++ b/src/include/migraphx/env.hpp @@ -21,6 +21,8 @@ std::vector env(const char* name); std::size_t value_of(const char* name, std::size_t fallback = 0); +std::string string_value_of(const char* name, std::string fallback = ""); + template bool enabled(T) { @@ -42,6 +44,13 @@ std::size_t value_of(T, std::size_t fallback = 0) return result; } +template +std::string string_value_of(T, std::string fallback = "") +{ + static const std::string result = string_value_of(T::value(), fallback); + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/erase.hpp b/src/include/migraphx/erase.hpp index 46b8742a276f45eb32776597b931cabec5270128..cbef1e4ce363baa8ace3dcd6c3c2428170597a84 100644 --- a/src/include/migraphx/erase.hpp +++ b/src/include/migraphx/erase.hpp @@ -25,12 +25,19 @@ auto erase(R&& r, const T& value) * * @param r The container to erase elements from * @param pred Predicate function that selects which elements should be erased. - * @return Returns iterator to erased element */ template -auto erase_if(R&& r, P&& pred) +void erase_if(R&& r, P&& pred) { - return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end()); + auto first = r.begin(); + auto last = r.end(); + while(first != last) + { + if(pred(*first)) + first = r.erase(first); + else + first++; + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/errors.hpp b/src/include/migraphx/errors.hpp old mode 100644 new mode 100755 index 61dcb81f0095505ab041fc25efde09d7a81dbd02..c85717db79b923df9426b87e8073f4f10450d3e6 --- a/src/include/migraphx/errors.hpp +++ b/src/include/migraphx/errors.hpp @@ -12,7 +12,10 @@ inline namespace MIGRAPHX_INLINE_NS { /// Represents exceptions that can be thrown by migraphxlib struct exception : std::runtime_error { - exception(const std::string& msg = "") : std::runtime_error(msg) {} + unsigned int error; + exception(unsigned int e = 0, const std::string& msg = "") : std::runtime_error(msg), error(e) + { + } }; /** @@ -24,7 +27,13 @@ struct exception : std::runtime_error */ inline exception make_exception(const std::string& context, const std::string& message = "") { - return {context + ": " + message}; + return {0, context + ": " + message}; +} + +inline exception +make_exception(const std::string& context, unsigned int e, const std::string& message = "") +{ + return {e, context + ": " + message}; } /** @@ -35,16 +44,18 @@ inline exception make_exception(const std::string& context, const std::string& m * * @return A string that represents the file location */ -inline std::string make_source_context(const std::string& file, int line) +inline std::string make_source_context(const std::string& file, int line, const std::string& fname) { - return file + ":" + std::to_string(line); + return file + ":" + std::to_string(line) + ": " + fname; } +// NOLINTNEXTLINE +#define MIGRAPHX_MAKE_SOURCE_CTX() migraphx::make_source_context(__FILE__, __LINE__, __func__) + /** * @brief Throw an exception with context information */ -#define MIGRAPHX_THROW(...) \ - throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__), __VA_ARGS__) +#define MIGRAPHX_THROW(...) throw migraphx::make_exception(MIGRAPHX_MAKE_SOURCE_CTX(), __VA_ARGS__) } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/file_buffer.hpp b/src/include/migraphx/file_buffer.hpp new file mode 100755 index 0000000000000000000000000000000000000000..74002214b102facbec88914b8cb8321ca4463022 --- /dev/null +++ b/src/include/migraphx/file_buffer.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_FILE_BUFFER_HPP +#define MIGRAPHX_GUARD_RTGLIB_FILE_BUFFER_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::vector read_buffer(const std::string& filename); +std::string read_string(const std::string& filename); + +void write_buffer(const std::string& filename, const char* buffer, std::size_t size); +void write_buffer(const std::string& filename, const std::vector& buffer); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/filesystem.hpp b/src/include/migraphx/filesystem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1cbeca1c45221ff0c8294fe01be69d42a11ff040 --- /dev/null +++ b/src/include/migraphx/filesystem.hpp @@ -0,0 +1,45 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP +#define MIGRAPHX_GUARD_RTGLIB_FILESYSTEM_HPP + +#include + +#if defined(CPPCHECK) +#define MIGRAPHX_HAS_FILESYSTEM 1 +#define MIGRAPHX_HAS_FILESYSTEM_TS 1 +#elif defined(__has_include) +#if __has_include() && __cplusplus >= 201703L +#define MIGRAPHX_HAS_FILESYSTEM 1 +#else +#define MIGRAPHX_HAS_FILESYSTEM 0 +#endif +#if __has_include() && __cplusplus >= 201103L +#define MIGRAPHX_HAS_FILESYSTEM_TS 1 +#else +#define MIGRAPHX_HAS_FILESYSTEM_TS 0 +#endif +#else +#define MIGRAPHX_HAS_FILESYSTEM 0 +#define MIGRAPHX_HAS_FILESYSTEM_TS 0 +#endif + +#if MIGRAPHX_HAS_FILESYSTEM +#include +#elif MIGRAPHX_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#if MIGRAPHX_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif MIGRAPHX_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/float_equal.hpp b/src/include/migraphx/float_equal.hpp index 1653380bb9dcaa171c574ab73829fa165559c32e..7753f0aaa28e30c478e4147c33a2db532fe67fce 100644 --- a/src/include/migraphx/float_equal.hpp +++ b/src/include/migraphx/float_equal.hpp @@ -10,6 +10,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -19,7 +20,7 @@ using common_type = typename std::common_type::type; struct float_equal_fn { - template {})> + template {})> static bool apply(T x, T y) { return std::isfinite(x) and std::isfinite(y) and @@ -27,7 +28,7 @@ struct float_equal_fn std::nextafter(x, std::numeric_limits::max()) >= y; } - template {})> + template {})> static bool apply(T x, T y) { return x == y; diff --git a/src/include/migraphx/functional.hpp b/src/include/migraphx/functional.hpp old mode 100644 new mode 100755 index dbf7323d65184e45440c5c345ee58e60884a8482..2815372d842f70c0d321007b0d5fceb5365a2c2d --- a/src/include/migraphx/functional.hpp +++ b/src/include/migraphx/functional.hpp @@ -125,16 +125,10 @@ auto fix(F f) return fix(f); } -template -auto pack(Ts... xs) -{ - return [=](auto f) { return f(xs...); }; -} - template auto fold_impl(F&&, T&& x) { - return x; + return std::forward(x); } template @@ -149,6 +143,22 @@ auto fold(F f) return [=](auto&&... xs) { return fold_impl(f, std::forward(xs)...); }; } +template +auto pack(Ts... xs) +{ + return [=](auto f) { return f(xs...); }; +} + +inline auto pack_join() { return pack(); } + +template +auto pack_join(Ps... ps) +{ + return fold([](auto p1, auto p2) { + return p1([=](auto... xs) { return p2([=](auto... ys) { return pack(xs..., ys...); }); }); + })(ps...); +} + template auto by(F f, Proj proj) { @@ -216,6 +226,11 @@ struct id } }; +template +void nop(Ts&&...) +{ +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/fuse_pointwise.hpp b/src/include/migraphx/fuse_pointwise.hpp new file mode 100755 index 0000000000000000000000000000000000000000..9071cdeac2da5d320285dcd3188cd71792dcd4f7 --- /dev/null +++ b/src/include/migraphx/fuse_pointwise.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +struct fuse_pointwise +{ + std::string name() const { return "fuse_pointwise"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP diff --git a/src/include/migraphx/gemm.hpp b/src/include/migraphx/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab881ad3b91b2bef5a69d434dd6201a952b49e4f --- /dev/null +++ b/src/include/migraphx/gemm.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP +#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +void gemm(tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta) +{ + std::size_t n_dims = cmat.get_shape().lens().size(); + std::size_t dim_0 = n_dims - 2; + std::size_t dim_1 = n_dims - 1; + auto k = amat.get_shape().lens()[dim_1]; + + assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); + auto cs = cmat.get_shape(); + + par_for(cs.elements(), [&](auto i) { + auto c_idx = cs.multi(i); + auto a_idx = c_idx; + auto b_idx = c_idx; + double s = 0.0; + dfor(k)([&](auto kk) { + a_idx[dim_1] = b_idx[dim_0] = kk; + s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); + }); + cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; + }); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/generate.hpp b/src/include/migraphx/generate.hpp old mode 100644 new mode 100755 index 2cabc12173e439eb2153fd2178bcaf422811ff29..9ef429e645788455906c0afcd257a0c66575347d --- a/src/include/migraphx/generate.hpp +++ b/src/include/migraphx/generate.hpp @@ -25,18 +25,26 @@ constexpr T normalize(unsigned long z) template {} and not is_floating_point{})> constexpr T normalize(unsigned long z) { - const auto max = std::numeric_limits::max() / 64; + const auto max = 1UL << (sizeof(T) * 5); const auto half_max = max / 2; return half_max - (z % max); } -template {} and std::is_integral{})> +template {} and std::is_integral{} and + not std::is_same{})> constexpr T normalize(unsigned long z) { - const auto max = std::numeric_limits::max() / 64; + const auto max = 1UL << (sizeof(T) * 5); return z % max; } +template {})> +constexpr bool normalize(unsigned long z) +{ + return static_cast(z % 2); +} + template struct xorshf96_generator { @@ -80,16 +88,16 @@ struct xorshift_generator template auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) { - auto result = make_shared_array(s.elements()); - std::generate(result.get(), result.get() + s.elements(), xorshf96_generator{seed}); + auto result = make_shared_array(s.element_space()); + std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator{seed}); return result; } template auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) { - auto result = make_shared_array(s.elements()); - std::generate(result.get(), result.get() + s.elements(), [=] { return value; }); + auto result = make_shared_array(s.element_space()); + std::generate(result.get(), result.get() + s.element_space(), [=] { return value; }); return result; } diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index ceeaffd2a7b7ed7850e152ec230c393bb8fc67ee..d8c20c3a59f8910f0b405fe7f97f242177a95571 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -23,11 +23,13 @@ struct deduce using type = T; }; +#ifdef HAS_HALF_V1 template <> struct deduce { using type = half; }; +#endif } // namespace detail template @@ -36,4 +38,24 @@ using deduce = typename detail::deduce::type; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx +namespace std { + +template +struct common_type : std::common_type +{ +}; + +template +struct common_type : std::common_type +{ +}; + +template <> +struct common_type +{ + using type = migraphx::half; +}; + +} // namespace std + #endif diff --git a/src/include/migraphx/inline_module.hpp b/src/include/migraphx/inline_module.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ead29c9e6ba8b93efac49f936ff6d11b874d408 --- /dev/null +++ b/src/include/migraphx/inline_module.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP +#define MIGRAPHX_GUARD_RTGLIB_INLINE_MODULE_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct inline_module +{ + std::string name() const { return "inline_module"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/insert_pad.hpp b/src/include/migraphx/insert_pad.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0877463eb0061014958d8c851428a0f8fdb9e0a --- /dev/null +++ b/src/include/migraphx/insert_pad.hpp @@ -0,0 +1,28 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP +#define MIGRAPHX_GUARD_RTGLIB_INSERT_PAD_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * insert pads if attribute of padding is asymmetrical + */ +struct insert_pad +{ + std::string name() const { return "insert_pad"; } + + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 2201c73136cf62e5b48cce2f2e485b0812b4423f..dfcaed7c83046a78790b0abf96b370d9b49ceefa 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -14,7 +15,11 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { shape compute_shape(const operation& op, const std::vector& args); +shape compute_shape(const operation& op, + const std::vector& args, + const std::vector& mods); std::vector to_shapes(const std::vector& args); +std::vector try_compute_shape(const operation& op, const std::vector& inputs); struct instruction { @@ -22,6 +27,11 @@ struct instruction instruction(operation o, shape r, std::vector args); + instruction(operation o, + shape r, + std::vector args, + std::vector modules); + instruction(literal l); void replace(operation o); @@ -32,7 +42,7 @@ struct instruction friend bool operator==(const instruction& i, instruction_ref ref); - bool valid(instruction_ref start) const; + bool valid(instruction_ref start, bool check_order = false) const; bool valid() const; @@ -45,6 +55,8 @@ struct instruction const std::vector& inputs() const; + const std::vector& module_inputs() const; + const std::vector& outputs() const; friend bool operator==(const instruction& x, const instruction& y); @@ -65,13 +77,25 @@ struct instruction migraphx::erase(output, ins); } + static void replace_refs(instruction_ref ins, + const std::unordered_map& map_insts, + const std::unordered_map& map_mods); + static void backreference(instruction_ref ref); static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins); + static void replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod); + static void replace(instruction_ref ins, operation o, const shape& r, std::vector args); + static void replace(instruction_ref ins, + operation o, + const shape& r, + std::vector args, + std::vector module_args); + bool can_eval() const; argument eval(bool check_eval = true) const; @@ -80,39 +104,52 @@ struct instruction static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false); + void set_normalized(bool value = true); + bool is_normalized() const; + + bool need_normalization() const; + + operation normalized_operator() const; + + void debug_print() const; + + static void print(std::ostream& os, + instruction_ref ins, + const std::unordered_map& names); + private: // internal void replace(operation o, const shape& r, std::vector args); + // internal + void replace(operation o, + const shape& r, + std::vector args, + std::vector mdl_args); + // internal void replace(std::vector args); + // internal + void replace(std::vector args, std::vector mdl_args); + // internal void replace_argument(instruction_ref old, instruction_ref new_ins); + // internal + void replace_mod_argument(module_ref old, module_ref new_mod); + void replace(const shape& r); operation op; - shape result; + shape result{}; std::vector output; std::vector arguments; + std::vector module_args; literal lit; + bool normalized = false; }; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -namespace std { -template <> -struct hash -{ - using argument_type = migraphx::instruction_ref; - using result_type = std::size_t; - result_type operator()(const argument_type& x) const noexcept - { - return std::hash{}(&*x); - } -}; - -} // namespace std - #endif diff --git a/src/include/migraphx/instruction_ref.hpp b/src/include/migraphx/instruction_ref.hpp old mode 100644 new mode 100755 index 4f99cf3f990d457aaf91e653f6471bb8cb4e014c..b07286cbb0722ee32f26cd1622e5006fc1f62f97 --- a/src/include/migraphx/instruction_ref.hpp +++ b/src/include/migraphx/instruction_ref.hpp @@ -11,7 +11,35 @@ inline namespace MIGRAPHX_INLINE_NS { struct instruction; using instruction_ref = std::list::iterator; +migraphx::instruction* as_address(const instruction_ref& ins) noexcept; + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx +namespace std { +template <> +struct hash +{ + using argument_type = migraphx::instruction_ref; + using result_type = std::size_t; + result_type operator()(const migraphx::instruction_ref& x) const noexcept + { + return std::hash{}(migraphx::as_address(x)); + } +}; + +template <> +struct equal_to +{ + using argument_type = migraphx::instruction_ref; + using result_type = bool; + result_type operator()(const migraphx::instruction_ref& x, + const migraphx::instruction_ref& y) const noexcept + { + return migraphx::as_address(x) == migraphx::as_address(y); + } +}; + +} // namespace std + #endif diff --git a/src/include/migraphx/iota_iterator.hpp b/src/include/migraphx/iota_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..15a4ae47084d3b24c0de776e92caf5101de44e41 --- /dev/null +++ b/src/include/migraphx/iota_iterator.hpp @@ -0,0 +1,140 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP +#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +struct basic_iota_iterator +{ + Iterator index; + F f; + + using difference_type = std::ptrdiff_t; + using reference = decltype(f(std::declval())); + using value_type = typename std::remove_reference::type; + using pointer = typename std::add_pointer::type; + using iterator_category = std::random_access_iterator_tag; + + basic_iota_iterator& operator+=(int n) + { + index += n; + return *this; + } + + basic_iota_iterator& operator-=(int n) + { + index -= n; + return *this; + } + + basic_iota_iterator& operator++() + { + index++; + return *this; + } + + basic_iota_iterator& operator--() + { + index--; + return *this; + } + + basic_iota_iterator operator++(int) // NOLINT + { + basic_iota_iterator it = *this; + index++; + return it; + } + + basic_iota_iterator operator--(int) // NOLINT + { + basic_iota_iterator it = *this; + index--; + return it; + } + // TODO: operator-> + reference operator*() const { return f(index); } +}; + +template +inline basic_iota_iterator make_basic_iota_iterator(T x, F f) +{ + return basic_iota_iterator{x, f}; +} + +template +inline basic_iota_iterator operator+(basic_iota_iterator x, + std::ptrdiff_t y) +{ + return x += y; +} + +template +inline basic_iota_iterator operator+(std::ptrdiff_t x, + basic_iota_iterator y) +{ + return y + x; +} + +template +inline std::ptrdiff_t operator-(basic_iota_iterator x, + basic_iota_iterator y) +{ + return x.index - y.index; +} + +template +inline basic_iota_iterator operator-(basic_iota_iterator x, + std::ptrdiff_t y) +{ + return x -= y; +} + +template +inline bool operator==(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index == y.index; +} + +template +inline bool operator!=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index != y.index; +} + +template +inline bool operator<(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index < y.index; +} + +template +inline bool operator>(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index > y.index; +} + +template +inline bool operator>=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index >= y.index; +} + +template +inline bool operator<=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index <= y.index; +} + +using iota_iterator = basic_iota_iterator; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/iterator.hpp b/src/include/migraphx/iterator.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ed032b29b4bd58d88552f8ea1ebe00f59524d5ac --- /dev/null +++ b/src/include/migraphx/iterator.hpp @@ -0,0 +1,30 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable()) +{ + return !it._M_dereferenceable(); +} + +template +auto is_end(rank<1>, Iterator it, EndIterator last) +{ + return it == last; +} + +template +bool is_end(Iterator it, EndIterator last) +{ + return is_end(rank<2>{}, it, last); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_ITERATOR_HPP diff --git a/src/include/migraphx/iterator_for.hpp b/src/include/migraphx/iterator_for.hpp old mode 100644 new mode 100755 index ca2a2123084f48cb60c85e5c98ed93c52662cc24..5f557eb566a39a0fc0c50d64c613eb624d5bc735 --- a/src/include/migraphx/iterator_for.hpp +++ b/src/include/migraphx/iterator_for.hpp @@ -59,18 +59,24 @@ struct iterator_for_range struct iterator { + using difference_type = std::ptrdiff_t; + using reference = decltype(std::declval()); + using value_type = std::remove_reference_t; + using pointer = std::add_pointer_t; + using iterator_category = std::input_iterator_tag; base_iterator i; auto operator*() const { return Selector::deref(i); } base_iterator operator++() { return ++i; } + bool operator==(const iterator& rhs) const { return i == rhs.i; } bool operator!=(const iterator& rhs) const { return i != rhs.i; } }; - iterator begin() + iterator begin() const { assert(base != nullptr); return {Selector::begin(base)}; } - iterator end() + iterator end() const { assert(base != nullptr); return {Selector::end(base)}; diff --git a/src/include/migraphx/json.hpp b/src/include/migraphx/json.hpp new file mode 100644 index 0000000000000000000000000000000000000000..173e38c8180f7ac5c55074a7bbfc40b08c738ca7 --- /dev/null +++ b/src/include/migraphx/json.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_JSON_HPP +#define MIGRAPHX_GUARD_RTGLIB_JSON_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::string to_pretty_json_string(const value& val, std::size_t indent = 4); +std::string to_json_string(const value& val); +value from_json_string(const std::string& str); +value from_json_string(const char* str, std::size_t size); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/lifetime.hpp b/src/include/migraphx/lifetime.hpp new file mode 100755 index 0000000000000000000000000000000000000000..7a21e5302ebed9308e413b36525c0ca8e4d644de --- /dev/null +++ b/src/include/migraphx/lifetime.hpp @@ -0,0 +1,18 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +enum class lifetime +{ + local, + global, + borrow +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_LIFETIME_HPP diff --git a/src/include/migraphx/literal.hpp b/src/include/migraphx/literal.hpp index d86cd0e35fa6b2f8c40f48cfe1b148c5db94c6f0..038615bb3a6eba586d342c8ba523c29f1c2bac95 100644 --- a/src/include/migraphx/literal.hpp +++ b/src/include/migraphx/literal.hpp @@ -52,7 +52,8 @@ struct literal : raw_data fill(start, end); } - literal(const shape& s, const char* x) : buffer(make_shared_array(s.bytes())), m_shape(s) + template + literal(const shape& s, T* x) : buffer(make_shared_array(s.bytes())), m_shape(s) { std::copy(x, x + s.bytes(), buffer.get()); } @@ -65,11 +66,13 @@ struct literal : raw_data const shape& get_shape() const { return this->m_shape; } + std::vector get_sub_objects() const { return {}; } + /// Convert the data to an argument argument get_argument() const { - std::vector b(buffer.get(), buffer.get() + m_shape.bytes()); - return {m_shape, [b]() mutable { return b.data(); }}; + auto b = make_shared_array(buffer.get(), buffer.get() + m_shape.bytes()); + return {m_shape, [b]() { return b.get(); }}; } private: @@ -90,7 +93,7 @@ struct literal : raw_data m_shape.visit_type([&](auto as) { auto output = make_view(m_shape, as.from(buffer.get())); shape_for_each(output.get_shape(), [&](const auto& idx) { - output(idx.begin(), idx.end()) = *it; + output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse) it++; }); }); @@ -125,6 +128,9 @@ literal transform(literal l1, literal l2, F f) return result; } +void migraphx_to_value(value& v, const literal& l); +void migraphx_from_value(const value& v, literal& l); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/load_save.hpp b/src/include/migraphx/load_save.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6db67d6e725061bd0f16a06c6b50795ac2af686 --- /dev/null +++ b/src/include/migraphx/load_save.hpp @@ -0,0 +1,29 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP +#define MIGRAPHX_GUARD_RTGLIB_LOAD_SAVE_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct file_options +{ + std::string format = "msgpack"; +}; + +program load(const std::string& filename, const file_options& options = file_options{}); +program load_buffer(const std::vector& buffer, const file_options& options = file_options{}); +program +load_buffer(const char* buffer, std::size_t size, const file_options& options = file_options{}); + +void save(const program& p, + const std::string& filename, + const file_options& options = file_options{}); +std::vector save_buffer(const program& p, const file_options& options = file_options{}); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/make_op.hpp b/src/include/migraphx/make_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d736c9fc19c2bdd15448afdb744686ea0729a8a7 --- /dev/null +++ b/src/include/migraphx/make_op.hpp @@ -0,0 +1,29 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP +#define MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +operation make_op(const std::string& name); +operation make_op(const std::string& name, + const std::initializer_list>& v); +operation make_op_from_value(const std::string& name, const value& v); + +// A template overload is added for migraphx::value so the initializer_list +// cannot be passed in directly. This is to enforce at compile-time that all +// initializer_list are key-value pairs, whereas migraphx::value allows other +// types of initializer_list such as for arrays. +template +operation make_op(const std::string& name, const Value& v) +{ + return make_op_from_value(name, v); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/make_shared_array.hpp b/src/include/migraphx/make_shared_array.hpp old mode 100644 new mode 100755 index f0a2d574c70a4dce76450df86fd24cf60b7055cb..09bfa2c58b055ec96006a6fb2e6d96f27d4603be --- a/src/include/migraphx/make_shared_array.hpp +++ b/src/include/migraphx/make_shared_array.hpp @@ -10,7 +10,15 @@ inline namespace MIGRAPHX_INLINE_NS { template std::shared_ptr make_shared_array(size_t size) { - return std::shared_ptr(new T[size], std::default_delete()); // NOLINT + return std::shared_ptr(new T[size](), std::default_delete()); // NOLINT +} + +template +std::shared_ptr make_shared_array(Iterator start, Iterator last) +{ + auto result = make_shared_array(std::distance(start, last)); + std::copy(start, last, result.get()); + return result; } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/manage_ptr.hpp b/src/include/migraphx/manage_ptr.hpp index 99e7972e627d593e16275bedda3d524ae15c7444..8231e8a199f389d943a465c6687d4bb38926191f 100644 --- a/src/include/migraphx/manage_ptr.hpp +++ b/src/include/migraphx/manage_ptr.hpp @@ -16,7 +16,7 @@ struct manage_deleter { if(x != nullptr) { - f(x); + (void)f(x); } } }; diff --git a/src/include/migraphx/marker.hpp b/src/include/migraphx/marker.hpp new file mode 100755 index 0000000000000000000000000000000000000000..5e3075fc6514cf97c096344932c6467b7601ad73 --- /dev/null +++ b/src/include/migraphx/marker.hpp @@ -0,0 +1,257 @@ +#ifndef MIGRAPHX_GUARD_MARKER_HPP +#define MIGRAPHX_GUARD_MARKER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// Marker is an interface to general marking functions, such as rocTX markers. + +#else + +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct marker +{ + // + void mark_start(instruction_ref ins_ref); + // + void mark_start(const program& prog); + // + void mark_stop(instruction_ref ins); + // + void mark_stop(const program& prog); +}; + +#else + +struct marker +{ + // Constructors + marker() = default; + + template + marker(PrivateDetailTypeErasedT value) + : private_detail_te_handle_mem_var( + std::make_shared::type>>( + std::forward(value))) + { + } + + // Assignment + template + marker& operator=(PrivateDetailTypeErasedT value) + { + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + marker rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } + return *this; + } + + // Cast + template + PrivateDetailTypeErasedT* any_cast() + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + template + const typename std::remove_cv::type* any_cast() const + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + const std::type_info& type_id() const + { + if(private_detail_te_handle_empty()) + return typeid(std::nullptr_t); + else + return private_detail_te_get_handle().type(); + } + + void mark_start(instruction_ref ins_ref) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().mark_start(ins_ref); + } + + void mark_start(const program& prog) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().mark_start(prog); + } + + void mark_stop(instruction_ref ins) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().mark_stop(ins); + } + + void mark_stop(const program& prog) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().mark_stop(prog); + } + + friend bool is_shared(const marker& private_detail_x, const marker& private_detail_y) + { + return private_detail_x.private_detail_te_handle_mem_var == + private_detail_y.private_detail_te_handle_mem_var; + } + + private: + struct private_detail_te_handle_base_type + { + virtual ~private_detail_te_handle_base_type() {} + virtual std::shared_ptr clone() const = 0; + virtual const std::type_info& type() const = 0; + + virtual void mark_start(instruction_ref ins_ref) = 0; + virtual void mark_start(const program& prog) = 0; + virtual void mark_stop(instruction_ref ins) = 0; + virtual void mark_stop(const program& prog) = 0; + }; + + template + struct private_detail_te_handle_type : private_detail_te_handle_base_type + { + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value>::type* = + nullptr) + : private_detail_te_value(value) + { + } + + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value, + int>::type* = nullptr) noexcept + : private_detail_te_value(std::move(value)) + { + } + + std::shared_ptr clone() const override + { + return std::make_shared(private_detail_te_value); + } + + const std::type_info& type() const override { return typeid(private_detail_te_value); } + + void mark_start(instruction_ref ins_ref) override + { + + private_detail_te_value.mark_start(ins_ref); + } + + void mark_start(const program& prog) override { private_detail_te_value.mark_start(prog); } + + void mark_stop(instruction_ref ins) override { private_detail_te_value.mark_stop(ins); } + + void mark_stop(const program& prog) override { private_detail_te_value.mark_stop(prog); } + + PrivateDetailTypeErasedT private_detail_te_value; + }; + + template + struct private_detail_te_handle_type> + : private_detail_te_handle_type + { + private_detail_te_handle_type(std::reference_wrapper ref) + : private_detail_te_handle_type(ref.get()) + { + } + }; + + bool private_detail_te_handle_empty() const + { + return private_detail_te_handle_mem_var == nullptr; + } + + const private_detail_te_handle_base_type& private_detail_te_get_handle() const + { + assert(private_detail_te_handle_mem_var != nullptr); + return *private_detail_te_handle_mem_var; + } + + private_detail_te_handle_base_type& private_detail_te_get_handle() + { + assert(private_detail_te_handle_mem_var != nullptr); + if(!private_detail_te_handle_mem_var.unique()) + private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); + return *private_detail_te_handle_mem_var; + } + + std::shared_ptr private_detail_te_handle_mem_var; +}; + +template +inline const ValueType* any_cast(const marker* x) +{ + return x->any_cast(); +} + +template +inline ValueType* any_cast(marker* x) +{ + return x->any_cast(); +} + +template +inline ValueType& any_cast(marker& x) +{ + auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} + +template +inline const ValueType& any_cast(const marker& x) +{ + const auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} +#endif + +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/match/gelu_erf.hpp b/src/include/migraphx/match/gelu_erf.hpp new file mode 100755 index 0000000000000000000000000000000000000000..681fe3648ba74c0ab24ec296d1af76688b27d382 --- /dev/null +++ b/src/include/migraphx/match/gelu_erf.hpp @@ -0,0 +1,50 @@ +#ifndef MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP +#define MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +namespace detail { +template +struct gelu_erf_matcher +{ + F f; + auto erf_fn() const + { + return f("erf")( + used_once(), + arg(0)(used_once(), + f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"), + has_value(M_SQRT1_2, 1e-3))))); + } + + auto add_erf() const + { + return f("add")(used_once(), either_arg(0, 1)(erf_fn(), has_value(1.0f))); + } + + auto one_half() const { return has_value(0.5f); } + + auto matcher() const { return unordered_tree(f("mul"), one_half(), add_erf(), any()); } +}; +} // namespace detail + +template +auto gelu_erf(F f) +{ + return detail::gelu_erf_matcher{f}.matcher(); +} + +inline auto gelu_erf() +{ + return gelu_erf([](auto x) { return name(x); }); +} + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP diff --git a/src/include/migraphx/match/gelu_tanh.hpp b/src/include/migraphx/match/gelu_tanh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8da91019d6427d1929ab3daf597b05e66ce23ccc --- /dev/null +++ b/src/include/migraphx/match/gelu_tanh.hpp @@ -0,0 +1,51 @@ +#ifndef MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP +#define MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +namespace detail { +template +struct gelu_tanh_matcher +{ + F f; + auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); } + + auto tanh_fn() const + { + return f("tanh")( + used_once(), + arg(0)(f("mul")(either_arg(0, 1)(has_value(sqrt(M_2_PI), 1e-3), + f("add")(any_arg(0, 1)(f("mul")(either_arg(0, 1)( + has_value(0.044715f), pow_fn())))))))); + } + + auto matcher() const + { + return f("mul")(used_once(), + either_arg(0, 1)(any().bind("x"), + f("add")(any_arg(0, 1)(f("mul")( + either_arg(0, 1)(has_value(0.5f), tanh_fn())))))); + } +}; +} // namespace detail + +template +auto gelu_tanh(F f) +{ + return detail::gelu_tanh_matcher{f}.matcher(); +} + +inline auto gelu_tanh() +{ + return gelu_tanh([](auto x) { return name(x); }); +} + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP diff --git a/src/include/migraphx/match/layernorm.hpp b/src/include/migraphx/match/layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03828a251d1632de62210b6a29c7c82a61df0228 --- /dev/null +++ b/src/include/migraphx/match/layernorm.hpp @@ -0,0 +1,53 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +namespace detail { +template +struct layernorm_matcher +{ + F f; + auto x_minus_mean() const + { + return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean")))); + } + + auto variance() const + { + return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))))); + } + + auto layernorm_onnx() const + { + return f("div")(arg(0)(x_minus_mean()), + + arg(1)(skip_broadcasts(f("sqrt")( + arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f)))))))); + } + + auto matcher() const { return layernorm_onnx(); } +}; +} // namespace detail + +template +auto layernorm(F f) +{ + return detail::layernorm_matcher{f}.matcher(); +} + +inline auto layernorm() +{ + return layernorm([](auto x) { return name(x); }); +} + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 7b8450b1d975db5bc8b7dd946bd1440cf75a2ecc..1417b48b001b1516f0878bb5d9d0e2b91fc482dd 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -4,8 +4,10 @@ #include #include #include -#include +#include +#include #include +#include #include #include #include @@ -17,18 +19,51 @@ namespace match { struct matcher_context { - matcher_context(instruction_ref i) : last(i) {} + matcher_context(module& m) : mod(&m) {} std::unordered_map instructions; - instruction_ref not_found() const { return last; } template bool matched(M m, instruction_ref ins) { - return m.match(*this, ins) != this->not_found(); + return has_value(m.match(*this, ins)); + } + + template + bool matched(M m, optional ins) + { + if(ins) + return has_value(m.match(*this, *ins)); + return false; + } + + template + auto lazy_match(M m, I ins) + { + return [=] { return this->matched(m, ins); }; + } + + bool has_instruction(instruction_ref ins) const + { + if(mod == nullptr) + return true; + return mod->has_instruction(ins); + } + bool has_instruction(optional ins) const + { + if(ins) + return this->has_instruction(*ins); + return false; + } + + bool is_last(instruction_ref ins) const + { + assert(mod->begin() != mod->end()); + assert(this->has_instruction(ins)); + return ins == std::prev(mod->end()); } private: - instruction_ref last; + module* mod = nullptr; }; /// Convert a predicate function into a matcher @@ -37,12 +72,11 @@ struct predicate_matcher { P p; - instruction_ref match(matcher_context& ctx, instruction_ref ins) const + optional match(const matcher_context&, instruction_ref ins) const { - assert(ins != ctx.not_found()); if(p(ins)) - return ins; - return ctx.not_found(); + return optional{ins}; + return nullopt; } }; @@ -52,11 +86,7 @@ struct function_matcher { F f; - instruction_ref match(matcher_context& ctx, instruction_ref ins) const - { - assert(ins != ctx.not_found()); - return f(ctx, ins); - } + auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); } }; /// Convert a function into a matcher @@ -71,10 +101,15 @@ template auto bind_match(M m, std::string name) { return make_function_matcher( - [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) { + [=, name = std::move(name)](matcher_context& ctx, + instruction_ref ins) -> optional { auto result = m.match(ctx, ins); - if(result != ctx.not_found()) + if(result) + { + if(not ctx.has_instruction(ins)) + return nullopt; ctx.instructions[name] = ins; + } return result; }); } @@ -87,10 +122,7 @@ struct bindable_matcher auto bind(std::string name) const { return bind_match(m, std::move(name)); } - instruction_ref match(matcher_context& ctx, instruction_ref ins) const - { - return m.match(ctx, ins); - } + auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } }; /// Create a bindable matcher @@ -118,9 +150,25 @@ using bool_list = std::initializer_list; struct id_matcher { - instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; } + auto match(matcher_context&, instruction_ref ins) const + { + return optional{ins}; + } }; +// Forward declare class and constructors +template +struct basic_matcher; + +template +basic_matcher make_basic_matcher(M m); + +template +basic_matcher> make_basic_fun_matcher(F f); + +template +basic_matcher> make_basic_pred_matcher(P p); + /// The basic matcher provides the all_of composability of the matcher template struct basic_matcher @@ -132,26 +180,23 @@ struct basic_matcher { // Copy m because we cant capture `this` by value auto mm = m; - return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { + return make_basic_fun_matcher([=](matcher_context& ctx, + instruction_ref ins) -> optional { auto result = mm.match(ctx, ins); - if(result != ctx.not_found()) + if(result) { - bool matches = fold([&](auto x, auto y) { - return x and y.match(ctx, result) != ctx.not_found(); - })(true, ms...); + bool matches = + fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...); if(matches) return result; } - return ctx.not_found(); + return nullopt; }); } auto bind(std::string name) const { return bind_match(m, std::move(name)); } - instruction_ref match(matcher_context& ctx, instruction_ref ins) const - { - return m.match(ctx, ins); - } + auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } }; /// Create a basic matcher from a matcher @@ -175,14 +220,25 @@ basic_matcher> make_basic_pred_matcher(P p) return {{p}}; } +/// Create a typed-erased matcher +using any_matcher_base = basic_matcher< + function_matcher(matcher_context&, instruction_ref)>>>; +struct any_matcher : any_matcher_base +{ + template + any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }}) + { + } +}; + /// This macro takes care of the boilerplate for defining a matcher #define MIGRAPHX_BASIC_MATCHER(name, ...) \ struct name##_m \ { \ - instruction_ref match(__VA_ARGS__) const; \ + optional match(__VA_ARGS__) const; \ }; \ const constexpr auto name = migraphx::match::basic_matcher{{}}; \ - inline instruction_ref name##_m::match(__VA_ARGS__) const + inline optional name##_m::match(__VA_ARGS__) const /// This macro takes care of the boilerplate for defining a predicate matcher #define MIGRAPHX_PRED_MATCHER(name, ...) \ @@ -196,57 +252,139 @@ basic_matcher> make_basic_pred_matcher(P p) struct matcher_result { - std::unordered_map instructions; + struct instruction_container + { + instruction_container() = default; + instruction_container(std::unordered_map x) + : ins_map(std::move(x)) + { + } + + instruction_ref operator[](const std::string& name) const + { + auto it = ins_map.find(name); + if(it == ins_map.end()) + MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name); + return it->second; + } + + auto find(const std::string& name) const { return ins_map.find(name); } + + auto begin() const { return ins_map.cbegin(); } + + auto end() const { return ins_map.cend(); } + + bool has_instructions_in(const module& mod) const + { + return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) { + return mod.has_instruction(p.second); + }); + } + + private: + std::unordered_map ins_map; + }; + instruction_container instructions; instruction_ref result; }; /// Match a single instruction template -matcher_result match_instruction(program& p, instruction_ref ins, M&& m) +matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) { - assert(ins != p.end()); + assert(ins != mod.end()); + assert(mod.has_instruction(ins)); + matcher_context ctx{mod}; matcher_result result; - matcher_context ctx{p.end()}; - result.result = m.match(ctx, ins); - result.instructions = ctx.instructions; + if(m.match(ctx, ins)) + { + result.result = ins; + result.instructions = ctx.instructions; + assert(result.instructions.has_instructions_in(mod)); + } + else + { + result.result = mod.end(); + } + return result; +} + +/// Find first instance of a matching instruction in a module +template +match::matcher_result find_match(module& modl, M&& m) +{ + match::matcher_result result; + for(auto ins : iterator_for(modl)) + { + result = match::match_instruction(modl, ins, m); + if(result.result != modl.end()) + return result; + } return result; } -/// Find matches for an instruction in the program +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) + +/// Find matches for an instruction in the module template -void find_matches(program& p, instruction_ref ins, Ms&&... ms) +void find_matches(module& mod, instruction_ref ins, Ms&&... ms) { - bool match = false; +#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 + const +#endif + bool trace = enabled(MIGRAPHX_TRACE_MATCHES{}); + bool match = false; each_args( [&](auto&& m) { if(match) return; - auto r = match_instruction(p, ins, m.matcher()); - if(r.result == p.end()) + auto r = match_instruction(mod, ins, m.matcher()); + if(r.result == mod.end()) return; - m.apply(p, r); + if(trace) + { + std::cout << "Matched by " << get_type_name(m) << std::endl; + mod.debug_print(ins); + } + m.apply(mod, r); match = true; }, ms...); } -/// Find matches in a program +/// Find matches in a module template -void find_matches(program& p, Ms&&... ms) +void find_matches(module& mod, Ms&&... ms) { - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(mod)) { - find_matches(p, ins, ms...); + find_matches(mod, ins, ms...); } } +template +struct find_generic_match +{ + M m; + F f; + M matcher() const { return m; } + + void apply(module& mod, const matcher_result& mr) const { f(mod, mr); } +}; + +template +find_generic_match make_match_finder(M m, F f) +{ + return {m, f}; +} + template struct find_skip { M m; M matcher() const { return m; } - void apply(program&, const matcher_result&) const {} + void apply(module&, const matcher_result&) const {} }; template @@ -258,18 +396,18 @@ find_skip make_find_skip(M m) struct lazy_and { template - bool operator()(F f, G g) const + auto operator()(F f, G g) const { - return f() and g(); + return [=] { return f() and g(); }; } }; struct lazy_or { template - bool operator()(F f, G g) const + auto operator()(F f, G g) const { - return f() or g(); + return [=] { return f() or g(); }; } }; @@ -281,7 +419,7 @@ struct match_fold_f { Op op; auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; }; - return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...); + return fold(op)(always(Start), matched(ms)...)(); } template @@ -293,12 +431,13 @@ struct match_fold_f template auto operator()(Ts... ms) const { - return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { - bool matches = match_fold_f::fold_matchers(ctx, ins, ms...); - if(matches == Matches) - return ins; - return ctx.not_found(); - }); + return make_bf_matcher( + [=](matcher_context& ctx, instruction_ref ins) -> optional { + bool matches = match_fold_f::fold_matchers(ctx, ins, ms...); + if(matches == Matches) + return {ins}; + return nullopt; + }); } template @@ -307,17 +446,18 @@ struct match_fold_f return [=](auto... ms) { // Workaround ICE on gcc by packing matchers into an object auto mpack = pack(ms...); - return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) { - Op op; - bool matches = Start; - select(start, [&](auto ins) { - auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); }; - matches = op(always(matches), fm); + return make_bf_matcher( + [=](matcher_context& ctx, instruction_ref start) -> optional { + Op op; + bool matches = Start; + select(start, [&](auto ins) { + auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); }; + matches = op(always(matches), fm)(); + }); + if(matches == Matches) + return {start}; + return nullopt; }); - if(matches == Matches) - return start; - return ctx.not_found(); - }); }; } }; @@ -374,64 +514,46 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); } -MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins) +MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins) { if(ins->outputs().size() == 1) - return ins->outputs().front(); - return ctx.not_found(); + return {ins->outputs().front()}; + return nullopt; } MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins) { if(ins->outputs().size() == 1) - return ins; - if(ins->outputs().empty() and std::next(ins) == ctx.not_found()) - return ins; - return ctx.not_found(); -} - -inline auto used_once_recursive(std::size_t depth) -{ - return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) { - // Used once - if(start->outputs().size() == 1) - return start; - // Unused - if(start->outputs().empty()) - { - if(std::next(start) == ctx.not_found()) - return start; - else - return ctx.not_found(); - } - // Check for dead instructions - auto is_dead = fix([&](auto self, auto ins, auto n) { - if(n == 0) - return false; - if(ins->get_shape().elements() == 0) - return false; - if(ins->outputs().empty() and std::next(ins) != ctx.not_found()) - return true; - return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) { - return self(i, n - 1); - }); - }); - auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) { - return is_dead(i, depth); - }); - if(dead + 1 == start->outputs().size()) - return start; - return ctx.not_found(); - }); + return {ins}; + if(ins->outputs().empty() and ctx.is_last(ins)) + return {ins}; + return nullopt; } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins) { - if(ins->outputs().empty() and ins != std::prev(ctx.not_found())) - return ins; - return ctx.not_found(); + if(ins->outputs().empty() and not ctx.is_last(ins)) + return {ins}; + return nullopt; +} + +template +auto skip(Ms... ms) +{ + auto m = any_of(ms...); + return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { + return fix>( + [&](auto self, auto ins) -> optional { + if(ins->inputs().size() == 1 and ctx.matched(m, ins)) + { + auto next = ins->inputs().front(); + return self(next); + } + return ins; + })(start); + }); } template @@ -439,32 +561,51 @@ auto skip_output(Ms... ms) { auto m = any_of(ms...); return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { - return fix([&](auto self, auto ins) { - if(ins->outputs().size() == 1) - { - auto next = ins->outputs().front(); - if(ctx.matched(m, next)) + return fix>( + [&](auto self, auto ins) -> optional { + if(ins->outputs().size() == 1) { - auto skipped_next = self(next); - if(skipped_next != ctx.not_found()) - return skipped_next; + auto next = ins->outputs().front(); + if(ctx.matched(m, next)) + { + auto skipped_next = self(next); + if(skipped_next) + return skipped_next; + } + return next; } - return next; - } - return ctx.not_found(); - })(start); + return nullopt; + })(start); }); } +inline auto var(std::string s) +{ + return make_basic_fun_matcher( + [=, s = std::move(s)](const matcher_context& ctx, + instruction_ref) -> optional { + auto it = ctx.instructions.find(s); + if(it == ctx.instructions.end()) + return nullopt; + return it->second; + }); +} + inline auto name(std::string s) { return make_basic_pred_matcher( - [ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; }); + [=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; }); +} + +inline auto name_contains(const std::string& name) +{ + return make_basic_pred_matcher( + [=](instruction_ref ins) { return contains(ins->get_operator().name(), name); }); } inline auto name(std::unordered_set names) { - return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) { + return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) { return names.count(ins->name()) > 0; }); } @@ -482,11 +623,12 @@ inline auto nargs(std::size_t n) inline auto arg(std::size_t i) { - return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) { - if(i < ins->inputs().size()) - return ins->inputs()[i]; - return ctx.not_found(); - }); + return make_basic_fun_matcher( + [=](const matcher_context&, instruction_ref ins) -> optional { + if(i < ins->inputs().size()) + return ins->inputs()[i]; + return nullopt; + }); } // Workaround for bugs in clang @@ -518,15 +660,86 @@ inline auto either_arg(std::size_t i, std::size_t j) }; } +inline auto any_arg(std::size_t i, std::size_t j) +{ + return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); }; +} + +template +std::size_t tree_leafs_impl(matcher_context& ctx, + std::array& leafs, + M m, + instruction_ref ins) +{ + std::size_t idx = 0; + fix([&](auto self, auto i) { + if(idx == leafs.size()) + return; + if(ctx.matched(m, i) and i->inputs().size() >= 2) + { + self(i->inputs()[0]); + self(i->inputs()[1]); + return; + } + leafs[idx] = i; + idx++; + })(ins); + return idx; +} + +template +auto tree(M main_op, Ms... ms) +{ + return make_basic_fun_matcher( + [=](matcher_context& ctx, instruction_ref ins) -> optional { + // Flatten leaf nodes + std::array leafs; + std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); + if(idx != leafs.size()) + return nullopt; + // Use explicit captures to workaround ICE on gcc + // Capture by value to workaround compile error on gcc 9 + bool found = sequence_c([ms..., &ctx, &leafs](auto... is) { + return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)(); + }); + if(not found) + return nullopt; + return ins; + }); +} + +template +auto unordered_tree(M main_op, Ms... ms) +{ + return make_basic_fun_matcher( + [=](matcher_context& ctx, instruction_ref ins) -> optional { + // Flatten leaf nodes + std::array leafs; + std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); + if(idx != leafs.size()) + return nullopt; + // Use explicit captures to workaround ICE on gcc + bool found = sequence_c([ms..., &ctx, &leafs](auto... is) { + return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) { + return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...); + })(ms...)(); + }); + if(not found) + return nullopt; + return ins; + }); +} + template auto same_shape(M m) { - return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { - auto i = m.match(ctx, ins); - if(i != ctx.not_found() and i->get_shape() == ins->get_shape()) - return ins; - return ctx.not_found(); - }); + return make_basic_fun_matcher( + [=](matcher_context& ctx, instruction_ref ins) -> optional { + auto i = m.match(ctx, ins); + if(i and (*i)->get_shape() == ins->get_shape()) + return ins; + return nullopt; + }); } template @@ -535,6 +748,50 @@ auto same_shape(Ms... ms) return all_of(same_shape(ms)...); } +template +auto skip_broadcasts(Ms... ms) +{ + return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...); +} + +template +auto skip_broadcasts_converts(Ms... ms) +{ + return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...); +} + +template +inline auto has_value(T x, float tolerance = 1e-6) +{ + return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) { + if(ins->name() != "@literal") + return false; + auto l = ins->get_literal(); + if(l.empty()) + return false; + bool b = false; + l.visit([&](auto v) { + if(std::all_of( + v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; })) + b = true; + }); + return b; + })); +} + +inline auto has_attribute(const std::string& name) +{ + return make_basic_pred_matcher( + [=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); }); +} + +template +auto pointwise(Ms... ms) +{ + return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)), + ms...); +} + } // namespace match } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/memory_coloring.hpp b/src/include/migraphx/memory_coloring.hpp index efa4e1782330c5a5e17c4a759419f0eab047429d..41c8d3abb7e3382a7884df978773e59e761154fd 100644 --- a/src/include/migraphx/memory_coloring.hpp +++ b/src/include/migraphx/memory_coloring.hpp @@ -7,7 +7,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Remove memory allocations. It uses graph coloring to find memory allocations that can be reused. @@ -17,7 +17,7 @@ struct memory_coloring std::string allocation_op{}; bool verify = false; std::string name() const { return "memory coloring"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cdbb60c088cc949b4c3460cc00216a4e36a18f5a --- /dev/null +++ b/src/include/migraphx/module.hpp @@ -0,0 +1,184 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +const operation& get_operation(instruction_ref ins); + +struct module_impl; + +using parameter_map = std::unordered_map; +using ins_dep_map = std::unordered_map>; + +/** + * @brief Stores the instruction stream + */ +struct module +{ + module(const std::string& name = ""); + + // move constructor + module(module&&) noexcept; + + // copy constructor + module(const module&); + + // copy assignment operator + module& operator=(module); + + ~module() noexcept; + + std::string name() const; + + bool bypass() const; + void set_bypass(bool b = true); + + template {}...)> + instruction_ref add_instruction(operation op, Ts... args) + { + return add_instruction(op, {args...}); + } + + instruction_ref add_instruction(const operation& op, std::vector args); + + instruction_ref add_instruction(const operation& op, + std::vector args, + std::vector module_args); + + template {}...)> + instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) + { + return insert_instruction(ins, op, {args...}); + } + instruction_ref + insert_instruction(instruction_ref ins, const operation& op, std::vector args); + + instruction_ref insert_instruction(instruction_ref ins, + const operation& op, + std::vector args, + std::vector module_args); + + template {}...)> + instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) + { + return replace_instruction(ins, op, {args...}); + } + instruction_ref replace_instruction(instruction_ref ins, + const operation& op, + std::vector args) MIGRAPHX_TIDY_CONST; + + instruction_ref replace_instruction(instruction_ref ins, + const operation& op, + std::vector args, + std::vector module_args) MIGRAPHX_TIDY_CONST; + + instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); + + instruction_ref remove_instruction(instruction_ref ins); + instruction_ref remove_instructions(instruction_ref first, instruction_ref last); + + instruction_ref move_instruction(instruction_ref src, instruction_ref dst); + instruction_ref move_instructions(instruction_ref src, instruction_ref dst); + + std::vector + insert_module_instructions(instruction_ref ins, + module_ref m, + std::unordered_map map_ins = {}); + + template + instruction_ref add_literal(Ts&&... xs) + { + return add_literal(literal{std::forward(xs)...}); + } + + instruction_ref add_literal(literal l); + + instruction_ref add_outline(const shape& s); + + instruction_ref add_parameter(std::string name, shape s); + + instruction_ref add_return(std::vector args); + + instruction_ref replace_return(std::vector args); + + std::vector get_parameter_names() const; + + shape get_parameter_shape(std::string name) const; + + instruction_ref get_parameter(std::string name) const; + + std::unordered_map get_parameter_shapes() const; + + bool has_instruction(instruction_ref ins) const; + + std::size_t size() const; + instruction_ref begin() const; + instruction_ref end() const; + + std::vector get_output_shapes() const; + + instruction_ref validate() const; + instruction_ref find_dangling_reference() const; + + void finalize(context& ctx); + + void debug_print() const; + void debug_print(instruction_ref ins) const; + void debug_print(instruction_ref ins, + std::unordered_map& names) const; + void debug_print(const std::vector& inss) const; + + std::unordered_map print( + const std::function&)>& print_func, + std::unordered_map names) const; + void print(const std::function&)>& + print_func) const; + + void print_graph(std::ostream& os, bool brief = false) const; + + void print_cpp(std::ostream& os) const; + std::unordered_map + print_cpp(std::ostream& os, std::unordered_map names) const; + + void annotate(std::ostream& os, std::function a) const; + + std::vector get_sub_modules() const; + module& sort(); + ins_dep_map calc_implicit_deps() const; + + friend std::ostream& operator<<(std::ostream& os, const module& m); + friend bool operator==(const module& x, const module& y); + friend bool operator!=(const module& x, const module& y) { return !(x == y); } + + private: + void assign(const module& m); + void calc_implicit_deps(const module& smod, + const module& pmod, + instruction_ref ins, + ins_dep_map& deps) const; + + std::unique_ptr impl; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/module_ref.hpp b/src/include/migraphx/module_ref.hpp new file mode 100644 index 0000000000000000000000000000000000000000..134f72b1043e8b21092d3c5fc643152f84d8f28f --- /dev/null +++ b/src/include/migraphx/module_ref.hpp @@ -0,0 +1,17 @@ +#ifndef MIGRAPHX_GUARD_MODULE_REF_HPP +#define MIGRAPHX_GUARD_MODULE_REF_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; +using module_ref = module*; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/msgpack.hpp b/src/include/migraphx/msgpack.hpp new file mode 100644 index 0000000000000000000000000000000000000000..962230fda98664b7b5a36a2903b2d64b5d32d5c7 --- /dev/null +++ b/src/include/migraphx/msgpack.hpp @@ -0,0 +1,17 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_MSGPACK_HPP +#define MIGRAPHX_GUARD_RTGLIB_MSGPACK_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::vector to_msgpack(const value& v); +value from_msgpack(const std::vector& buffer); +value from_msgpack(const char* buffer, std::size_t size); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/normalize_attributes.hpp b/src/include/migraphx/normalize_attributes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ba1366e3879eef2c84c38287e245310233fec40d --- /dev/null +++ b/src/include/migraphx/normalize_attributes.hpp @@ -0,0 +1,27 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP +#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct operation; + +template +struct select_dependent_type +{ + using type = T; +}; +template +using dependent_type = typename select_dependent_type::type; + +bool normalize_attributes(operation& op, const std::vector& lens); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/normalize_ops.hpp b/src/include/migraphx/normalize_ops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a953eb47a6373eb66823964461ca7a4de40d9e72 --- /dev/null +++ b/src/include/migraphx/normalize_ops.hpp @@ -0,0 +1,28 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP +#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * Process negative axis attributes of ops + */ + +struct normalize_ops +{ + std::string name() const { return "normalize_ops"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/onnx.hpp b/src/include/migraphx/onnx.hpp index dc17482bcc1753b9b5e1c698ed63fd4b96acca49..44e0c4858e767d883dd2caa6d5d58bbc2c69088a 100644 --- a/src/include/migraphx/onnx.hpp +++ b/src/include/migraphx/onnx.hpp @@ -7,8 +7,31 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +/// struct to pass in onnx options to parser +struct onnx_options +{ + /// default batch size to use (if not specified in onnx file) + std::size_t default_dim_value = 1; + /// Explicitly specify the dims of an input + std::unordered_map> map_input_dims = {}; + /// Continue parsing onnx file if an unknown operator is found + bool skip_unknown_operators = false; + /// Print program if an error occurs + bool print_program_on_error = false; + /// Max iter num for the loop operator + int64_t max_loop_iterations = 10; +}; + /// Create a program from an onnx file -program parse_onnx(const std::string& name); +program parse_onnx(const std::string& name, const onnx_options& = onnx_options{}); + +/// Create a program from an onnx buffer +program parse_onnx_buffer(const std::string& buffer, const onnx_options& options); + +/// Create a program from an onnx buffer +program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options); + +std::vector get_onnx_operators(); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/op/abnormal_ops.hpp b/src/include/migraphx/op/abnormal_ops.hpp deleted file mode 100644 index 9187249378b8f8027f7cf04e751d5aba59ff7420..0000000000000000000000000000000000000000 --- a/src/include/migraphx/op/abnormal_ops.hpp +++ /dev/null @@ -1,67 +0,0 @@ -#ifndef MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP -#define MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace op { - -struct not_computable -{ - argument compute(const shape&, const std::vector&) const - { - MIGRAPHX_THROW("not computable"); - } -}; - -struct undefined -{ - std::string name() const { return "undefined"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(0); - return {}; - } - - argument compute(const shape&, const std::vector&) const { return {{}, nullptr}; } -}; - -struct unknown -{ - std::string op; - template - static auto reflect(Self& self, F f) - { - return pack(f(self.op, "op")); - } - std::string name() const { return "unknown:" + op; } - shape compute_shape(std::vector input) const - { - if(input.empty()) - return {}; - else - return input.front(); - } - - friend std::ostream& operator<<(std::ostream& os, const unknown& x) - { - os << x.name(); - return os; - } -}; - -} // namespace op -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/include/migraphx/op/abs.hpp b/src/include/migraphx/op/abs.hpp index 510403aef1429130cabbf469386cb049596311d6..034aa2ab4c390e034cee201c492a0415f7470186 100644 --- a/src/include/migraphx/op/abs.hpp +++ b/src/include/migraphx/op/abs.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/acos.hpp b/src/include/migraphx/op/acos.hpp index 1e1ca47161a004e737b6baf34c2acd56ddc3efd3..3ce959f63769de9cb2c7c68eda1fc62de058117a 100644 --- a/src/include/migraphx/op/acos.hpp +++ b/src/include/migraphx/op/acos.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/acosh.hpp b/src/include/migraphx/op/acosh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9ff8ea59ed6cfbedb6ea34de19d98fc39d083b14 --- /dev/null +++ b/src/include/migraphx/op/acosh.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP +#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct acosh : unary +{ + auto apply() const + { + return [](auto x) { return std::acosh(x); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/add.hpp b/src/include/migraphx/op/add.hpp old mode 100644 new mode 100755 index 53e910a52a62630fa53450edba35eb095a94e410..a4dae809b29dfd3166cf21ba8177191413c70dc1 --- a/src/include/migraphx/op/add.hpp +++ b/src/include/migraphx/op/add.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,13 @@ namespace op { struct add : binary { + value attributes() const + { + auto a = base_attributes(); + a["commutative"] = true; + return a; + } + std::string point_function() const { return "+"; } auto apply() const { return [](auto x, auto y) { return x + y; }; diff --git a/src/include/migraphx/op/allocate.hpp b/src/include/migraphx/op/allocate.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fe7741a4226aa4e9ace7a8b21055d34f3c1c6f07 --- /dev/null +++ b/src/include/migraphx/op/allocate.hpp @@ -0,0 +1,43 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP +#define MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct allocate +{ + shape s{}; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.s, "shape")); + } + std::string name() const { return "allocate"; } + shape compute_shape(const std::vector& inputs) const + { + migraphx::check_shapes{inputs, *this}.has(0); + return s; + } + argument compute(const shape& output_shape, const std::vector&) const + { + return {output_shape}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/argmax.hpp b/src/include/migraphx/op/argmax.hpp index d99d495f7c9ab38e9c21a1e1889aa3a9c0cf398e..d6e3349026dc24f36e638309245d482aa6b5142c 100644 --- a/src/include/migraphx/op/argmax.hpp +++ b/src/include/migraphx/op/argmax.hpp @@ -1,10 +1,14 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP -#include #include -#include +#include +#include +#include #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -20,17 +24,19 @@ struct argmax return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "argmax"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1).standard(); - auto lens = inputs[0].lens(); - int64_t n_dim = static_cast(lens.size()); - if(axis >= n_dim || axis < 0) - { - MIGRAPHX_THROW("ARGMAX: axis is out of range."); - } + check_shapes{inputs, *this}.has(1); + auto lens = inputs[0].lens(); lens[axis] = 1; diff --git a/src/include/migraphx/op/argmin.hpp b/src/include/migraphx/op/argmin.hpp index 0b19665ba4cdf2ae77eb9c0e547c92c1e31f4983..c17322bf5f0d1880c91cae498f6719a9dd1217e6 100644 --- a/src/include/migraphx/op/argmin.hpp +++ b/src/include/migraphx/op/argmin.hpp @@ -1,10 +1,14 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP #define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP -#include #include -#include +#include +#include +#include #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -20,17 +24,19 @@ struct argmin return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "argmin"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1).standard(); - auto lens = inputs[0].lens(); - int64_t n_dim = static_cast(lens.size()); - if(axis >= n_dim || axis < 0) - { - MIGRAPHX_THROW("ARGMIN: axis is out of range."); - } + check_shapes{inputs, *this}.has(1); + auto lens = inputs[0].lens(); lens[axis] = 1; diff --git a/src/include/migraphx/op/as_shape.hpp b/src/include/migraphx/op/as_shape.hpp old mode 100644 new mode 100755 index d02e18cf742b35ba878766e3bdc0722d45d02817..57c1f0ce2ee9dd18dd249937bf41786a5b22e601 --- a/src/include/migraphx/op/as_shape.hpp +++ b/src/include/migraphx/op/as_shape.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -34,7 +34,7 @@ struct as_shape } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args.front().reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/asin.hpp b/src/include/migraphx/op/asin.hpp index 8a16c2cae705661e37e8aa498a1e524a59dd0a27..29395891e6600312b4960344f8f64cead02cbca8 100644 --- a/src/include/migraphx/op/asin.hpp +++ b/src/include/migraphx/op/asin.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/asinh.hpp b/src/include/migraphx/op/asinh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4679c4094f169610ec7cf4826c02bd1d9b3127d6 --- /dev/null +++ b/src/include/migraphx/op/asinh.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP +#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct asinh : unary +{ + auto apply() const + { + return [](auto x) { return std::asinh(x); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/atan.hpp b/src/include/migraphx/op/atan.hpp index 68617186f495481e58bd7d61263ca71076e2f0a4..1864159ada016c8a287cd503125a8653622cce6a 100644 --- a/src/include/migraphx/op/atan.hpp +++ b/src/include/migraphx/op/atan.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/atanh.hpp b/src/include/migraphx/op/atanh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..210c5dfbeffccaacb5c98a9ab2cd339deb19404c --- /dev/null +++ b/src/include/migraphx/op/atanh.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP +#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct atanh : unary +{ + auto apply() const + { + return [](auto x) { return std::atanh(x); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/batch_norm.hpp b/src/include/migraphx/op/batch_norm_inference.hpp similarity index 87% rename from src/include/migraphx/op/batch_norm.hpp rename to src/include/migraphx/op/batch_norm_inference.hpp index 0921fba1ec10278d1af5541b765c75a4c74dca5f..427c59164cfa1637bd329fe5f9136b4d230e6525 100644 --- a/src/include/migraphx/op/batch_norm.hpp +++ b/src/include/migraphx/op/batch_norm_inference.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -42,9 +41,8 @@ struct batch_norm_inference shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(5); - check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4); - check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements( - inputs.front().lens()[1]); + check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims(); + check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape(); return inputs.front(); } }; diff --git a/src/include/migraphx/op/binary.hpp b/src/include/migraphx/op/binary.hpp old mode 100644 new mode 100755 index 18f485566dc4b1ba783960057efe739c962c280b..422a5888074bd8f9fd4bb67c019bf7c31ceb6375 --- a/src/include/migraphx/op/binary.hpp +++ b/src/include/migraphx/op/binary.hpp @@ -2,6 +2,11 @@ #define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -10,15 +15,45 @@ namespace op { template struct binary : op_name { + std::string point_function() const { return this->name(); } + std::string point_op() const + { + const auto& self = static_cast(*this); + auto pf = self.point_function(); + if(pf.empty()) + return {}; + if(with_char(::ispunct)(pf.front())) + { + return "${0} " + pf + " ${1}"; + } + else + { + return "${function:" + pf + "}(${0}, ${1})"; + } + } + value base_attributes() const + { + const auto& self = static_cast(*this); + return {{"pointwise", true}, {"point_op", self.point_op()}}; + } + value attributes() const { return base_attributes(); } shape compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(2).same_type().same_dims(); + check_shapes{inputs, static_cast(*this)}.has(2).same_type().same_dims(); auto s0 = inputs.at(0); auto s1 = inputs.at(1); if(s0 == s1 and s0.packed()) { return s0; } + else if(s0.packed() != s1.packed()) + { + return s0.packed() ? s0 : s1; + } + else if(s0.broadcasted() != s1.broadcasted()) + { + return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens()); + } else { return {s0.type(), s0.lens()}; @@ -28,32 +63,13 @@ struct binary : op_name argument compute(const shape& output_shape, std::vector args) const { argument result{output_shape}; - auto s1 = args[0].get_shape(); - auto s2 = args[1].get_shape(); - if(s1 == s2 and s1.packed()) - { - shape std_shape{s1.type(), s1.lens()}; - argument std_result{std_shape, result.data()}; - argument std_arg0{std_shape, args[0].data()}; - argument std_arg1{std_shape, args[1].data()}; - visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) { - std::transform(input1.begin(), - input1.end(), - input2.begin(), - output.begin(), - static_cast(*this).apply()); - }); - } - else - { - visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { - shape_for_each(output.get_shape(), [&](const auto& idx) { - output(idx.begin(), idx.end()) = static_cast(*this).apply()( - input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end())); - }); - }); - } - + visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { + std::transform(input1.begin(), + input1.end(), + input2.begin(), + output.begin(), + static_cast(*this).apply()); + }); return result; } }; diff --git a/src/include/migraphx/op/broadcast.hpp b/src/include/migraphx/op/broadcast.hpp old mode 100644 new mode 100755 index 83d6f971fd9a31c805e20f943fbd56a9417f1ef4..fe0d8e2df0d3b7b7c6831306c229bd5d14f07e72 --- a/src/include/migraphx/op/broadcast.hpp +++ b/src/include/migraphx/op/broadcast.hpp @@ -2,13 +2,11 @@ #define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP #include -#include #include -#include -#include -#include -#include +#include +#include #include +#include #include #include @@ -32,36 +30,42 @@ struct broadcast template static auto reflect(Self& self, F f) { - return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims")); + return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens")); } std::string name() const { return "broadcast"; } shape compute_shape(std::vector inputs) const { - auto t = inputs.at(0).type(); auto input = inputs.at(0); + auto t = input.type(); std::vector bcast_strides(broadcast_lens.size(), 0); + // the broacast op is deprecated now, so not handling the negative + // value of axis anymore + if(axis >= broadcast_lens.size()) + { + MIGRAPHX_THROW("BROADCAST : axis is out of range"); + } - if(std::all_of( - broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; })) + if(broadcast_lens.size() - axis < input.lens().size()) { - if(axis != 0) - MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0"); - return {t, broadcast_lens, std::move(bcast_strides)}; + MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims"); } - else + + if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis)) { - assert(broadcast_lens.size() - axis >= input.lens().size()); - if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis)) - MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match"); - std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); - return {t, broadcast_lens, std::move(bcast_strides)}; + MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match"); } + std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); + + shape output{t, broadcast_lens, std::move(bcast_strides)}; + if(output.elements() < input.elements()) + MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size"); + return output; } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.at(0).data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/capture.hpp b/src/include/migraphx/op/capture.hpp index 4665c037485183e78204188fb3c819cbbbcf03db..a450ec1d5bb5435425eaf96db94c6e1a7c901f28 100644 --- a/src/include/migraphx/op/capture.hpp +++ b/src/include/migraphx/op/capture.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -30,7 +30,9 @@ struct capture shape compute_shape(std::vector inputs) const { return inputs.front(); } - argument compute(const shape&, std::vector args) const + // the context argument is added to prevent the op from be eliminated by + // constant propagation + argument compute(context&, const shape&, const std::vector& args) const { if(f) { @@ -43,6 +45,8 @@ struct capture return args.front(); } + + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; } // namespace op diff --git a/src/include/migraphx/op/clip.hpp b/src/include/migraphx/op/clip.hpp index 3bb7e6c1997438bc60b60f506e0276594888f29c..3c05f1df95fe3b5999b956a5cd12b82a82bbf71f 100644 --- a/src/include/migraphx/op/clip.hpp +++ b/src/include/migraphx/op/clip.hpp @@ -3,12 +3,11 @@ #include #include -#include #include #include #include #include -#include +#include #include #include #include @@ -18,29 +17,31 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { -struct clip : unary +struct clip { - float max_val = std::numeric_limits::max(); - float min_val = std::numeric_limits::min(); + std::string name() const { return "clip"; } - clip() {} - - clip(float max, float min) : max_val(max), min_val(min) {} + value attributes() const + { + return {{"pointwise", true}, + {"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}}; + } - auto apply() const + shape compute_shape(std::vector inputs) const { - auto max = max_val; - auto min = min_val; - return [max, min](auto x) { - using type = decltype(x); - return std::min(std::max(type(min), x), type(max)); - }; + check_shapes{inputs, *this}.has(3).same_type().same_dims(); + return inputs.front(); } - template - static auto reflect(Self& self, F f) + argument compute(const shape& output_shape, std::vector args) const { - return pack(f(self.max_val, "max"), f(self.min_val, "min")); + argument result{output_shape}; + visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) { + par_for(output_shape.elements(), + [&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); }); + }); + + return result; } }; diff --git a/src/include/migraphx/op/common.hpp b/src/include/migraphx/op/common.hpp index 49c7fb46f2b4b15acb5ac18d40385c1256271ff4..8d9cfc9d89aed36c4e7635b31a9e04fb3dabc404 100644 --- a/src/include/migraphx/op/common.hpp +++ b/src/include/migraphx/op/common.hpp @@ -1,15 +1,9 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP -#include -#include -#include -#include -#include -#include -#include +#include +#include #include -#include #include namespace migraphx { @@ -23,6 +17,15 @@ enum padding_mode_t valid }; +// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling. +// Used in pooling and roialign operators. +enum class pooling_mode +{ + average, + max, + lpnorm +}; + // indicate rnn computation direction enum class rnn_direction { @@ -31,6 +34,7 @@ enum class rnn_direction bidirectional, }; +std::ostream& operator<<(std::ostream& os, pooling_mode v); std::ostream& operator<<(std::ostream& os, rnn_direction v); } // namespace op diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 94ca5d972dff9c803d5986a41b752a0399e32b63..0556b0a6debc21ce619527807d06acbe4285ceeb 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -2,15 +2,18 @@ #define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP #include -#include #include #include #include #include #include #include +#include +#include +#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -26,23 +29,29 @@ struct concat return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "concat"; } std::vector compute_offsets(const shape& output_shape, const std::vector& args) const { - auto n_dims = args[0].get_shape().lens().size(); - std::size_t axis_index = (axis < 0) ? axis + n_dims : axis; + auto n_dims = args[0].get_shape().lens().size(); std::vector offsets; std::vector offset(n_dims, 0); - offset[axis_index] = 0; + offset[axis] = 0; for(const auto& arg : args) { offsets.push_back(output_shape.index(offset)); - offset[axis_index] += arg.get_shape().lens()[axis_index]; + offset[axis] += arg.get_shape().lens()[axis]; } return offsets; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { if(inputs.empty()) { @@ -51,10 +60,9 @@ struct concat const auto& first_shape_lens = inputs.front().lens(); const auto& type = inputs.front().type(); - std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis; for(std::size_t l = 0; l < first_shape_lens.size(); l++) { - if(l != axis_index) + if(l != axis) { if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[l] == first_shape_lens[l]; @@ -68,12 +76,12 @@ struct concat for(const auto& input : inputs) { const auto& lens = input.lens(); - new_dim_axis += lens[axis_index]; + new_dim_axis += lens[axis]; } std::vector new_lens; std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); - new_lens[axis_index] = new_dim_axis; - return {type, new_lens}; + new_lens[axis] = new_dim_axis; + return shape::from_permutation(type, new_lens, find_permutation(inputs)); } argument compute(const shape& output_shape, std::vector args) const { @@ -81,17 +89,12 @@ struct concat std::vector coffsets = compute_offsets(output_shape, args); for(std::size_t l = 0; l < args.size(); l++) { - auto argl = args[l]; - std::size_t nelements = argl.get_shape().elements(); + auto argl = args[l]; visit_all(result, argl)([&](auto output, auto input) { auto slice_shape = shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; auto slice = make_view(slice_shape, output.data() + coffsets[l]); - // cppcheck-suppress useStlAlgorithm - for(std::size_t i = 0; i < nelements; i++) - { - slice[i] = input[i]; - } + std::copy(input.begin(), input.end(), slice.begin()); }); } return result; diff --git a/src/include/migraphx/op/contiguous.hpp b/src/include/migraphx/op/contiguous.hpp old mode 100644 new mode 100755 index 234fd2f702f2d568ff0752cc189de35c5c6f424e..f2f2e01c6698ffeadaeb6758884ea0dcadc6cb91 --- a/src/include/migraphx/op/contiguous.hpp +++ b/src/include/migraphx/op/contiguous.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP #include -#include #include #include #include @@ -28,6 +27,8 @@ struct contiguous shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1); + if(inputs.front().standard()) + return inputs.front(); auto lens = inputs.at(0).lens(); auto t = inputs.at(0).type(); return {t, lens}; @@ -43,6 +44,11 @@ struct contiguous }); return result; } + + auto apply() const + { + return [](auto x) { return x; }; + } }; } // namespace op diff --git a/src/include/migraphx/op/convert.hpp b/src/include/migraphx/op/convert.hpp index 9cbbaa622f28e462ce3baffa528523660b55fa98..4fbc025d3dfe9778fa915be996467a9722886515 100644 --- a/src/include/migraphx/op/convert.hpp +++ b/src/include/migraphx/op/convert.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -33,9 +32,19 @@ struct convert : unary return {target_type, inputs.at(0).lens(), inputs.at(0).strides()}; } + std::string point_op() const + { + return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})"; + } + auto apply() const { - return [](auto x) { return x; }; + auto type = target_type; + return [type](auto x) { + auto y = x; + shape::visit(type, [&](auto as) { y = std::min(std::max(as(x), as.min()), as.max()); }); + return y; + }; } convert(shape::type_t t) : target_type{t} {} diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp old mode 100644 new mode 100755 index b35117d4f425a42789325be30f925ac76dbbbbb2..7ac93d3619165595c836000c81c8f9b671fb398f --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -3,13 +3,14 @@ #include #include -#include #include #include #include #include #include #include +#include +#include #include #include @@ -19,12 +20,12 @@ namespace op { struct convolution { - std::array padding = {{0, 0}}; - std::array stride = {{1, 1}}; - std::array dilation = {{1, 1}}; + std::vector padding = {0, 0}; + std::vector stride = {1, 1}; + std::vector dilation = {1, 1}; - padding_mode_t padding_mode = default_; int group = 1; + padding_mode_t padding_mode = default_; template static auto reflect(Self& self, F f) @@ -32,36 +33,68 @@ struct convolution return pack(f(self.padding, "padding"), f(self.stride, "stride"), f(self.dilation, "dilation"), - f(self.padding_mode, "padding_mode"), - f(self.group, "group")); + f(self.group, "group"), + f(self.padding_mode, "padding_mode")); } std::string name() const { return "convolution"; } - shape compute_shape(std::vector inputs) const + + void check_attribute_size() const + { + if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and + stride.size() == dilation.size())) + { + MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes"); + } + } + + value attributes() const { return {{"normalize_padding", "padding"}}; } + + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4); + check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); + check_attribute_size(); + // dim num of input and attribute should match + auto input_size = inputs[0].lens().size(); + auto padding_size = padding.size(); + if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2)) + { + MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); + } const shape& input = inputs.at(0); const shape& weights = inputs.at(1); - auto t = input.type(); - - return {t, - { - input.lens()[0], - weights.lens()[0], - std::size_t(std::max( - 1, - (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + - 2 * padding[0]) / - stride[0] + - 1)), - std::size_t(std::max( - 1, - (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + - 2 * padding[1]) / - stride[1] + - 1)), - }}; + size_t kdims = input_size - 2; + if(kdims != this->kdims()) + { + MIGRAPHX_THROW("convolution: input k-dims does not match attribute size"); + } + + if(input.lens().at(1) != (weights.lens().at(1) * group)) + MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers"); + + std::vector output_lens{input.lens()[0], weights.lens()[0]}; + + for(size_t i = 0; i < kdims; i++) + { + auto padding_factor = 2 * padding[i]; + if(padding_size == 2 * kdims) + padding_factor = padding[i] + padding[i + kdims]; + output_lens.push_back(std::size_t(std::max( + 1, + (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + + padding_factor) / + stride[i] + + 1))); + } + + return inputs[0].with_lens(output_lens); + } + + size_t kdims() const + { + check_attribute_size(); + return stride.size(); } }; diff --git a/src/include/migraphx/op/cos.hpp b/src/include/migraphx/op/cos.hpp index 7d63fa86d66a6214ccda1f86597ded2021a69307..3b2d16dcb6e0a05bb19a7d0b7c7a502b0d4b3d98 100644 --- a/src/include/migraphx/op/cos.hpp +++ b/src/include/migraphx/op/cos.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/cosh.hpp b/src/include/migraphx/op/cosh.hpp index 9a8fa8c73c330269911e7b5af3855b01c7590b81..4681fca3ac51dde19a93a69003005fed1b09816f 100644 --- a/src/include/migraphx/op/cosh.hpp +++ b/src/include/migraphx/op/cosh.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/deconvolution.hpp b/src/include/migraphx/op/deconvolution.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b263e903a09e8cc68397b94357a11db14c5033a0 --- /dev/null +++ b/src/include/migraphx/op/deconvolution.hpp @@ -0,0 +1,158 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP +#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct deconvolution +{ + std::vector padding = {0, 0}; + std::vector stride = {1, 1}; + std::vector dilation = {1, 1}; + + padding_mode_t padding_mode = default_; + int group = 1; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.padding, "padding"), + f(self.stride, "stride"), + f(self.dilation, "dilation"), + f(self.padding_mode, "padding_mode"), + f(self.group, "group")); + } + + std::string name() const { return "deconvolution"; } + + void check_attribute_size() const + { + if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and + stride.size() == dilation.size())) + { + MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes"); + } + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); + + const shape& input = inputs.at(0); + const shape& weights = inputs.at(1); + size_t kdims = input.lens().size() - 2; + if(kdims != this->kdims()) + { + MIGRAPHX_THROW("deconvolution: input k-dims does not match attribute size"); + } + + std::vector output_lens{input.lens()[0], weights.lens()[1]}; + + for(size_t i = 0; i < kdims; i++) + { + output_lens.push_back(std::size_t(std::max( + 1, + stride[i] * (input.lens()[i + 2] - 1) + + ((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i]))); + } + return inputs[0].with_lens(output_lens); + } + + argument compute(shape output_shape, std::vector args) const + { + argument result{output_shape}; + auto kdims = this->kdims(); + visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { + using type = typename decltype(output)::value_type; + + std::fill(output.begin(), output.end(), type{0}); + + auto in_lens = input.get_shape().lens(); + auto in_n = in_lens[0]; + auto in_c = in_lens[1]; + + auto wei = weights.get_shape().lens(); + auto wei_n = wei[0]; + auto wei_c = wei[1]; + + auto out_lens = output_shape.lens(); + + std::vector win_size{in_c}; + std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size)); + std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size)); + shape win_shape{output_shape.type(), win_size}; + + par_dfor(in_n, wei_c)([&](int o, int k) { + shape_for_each(win_shape, [&](auto idx_win) { + const int w = idx_win[0]; + + auto input_dims_start = idx_win.begin() + 1; + auto wei_dims_start = idx_win.begin() + kdims + 1; + + std::vector win_start; + for(std::size_t n = 0; n < kdims; ++n) + { + win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) - + std::ptrdiff_t(padding[n])); + } + + const int group_id = w / (wei_n / group); + const int in_ch = group_id * wei_c + k; + + std::vector idx_out{o, in_ch}; + + for(size_t n = 0; n < kdims; n++) + { + idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]); + } + + std::vector idx_wei{w, k}; + std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei)); + + std::vector idx_in{o, w}; + std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in)); + + if(std::all_of( + idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and + std::equal(idx_out.begin() + 2, + idx_out.end(), + out_lens.begin() + 2, + out_lens.end(), + std::less{})) + { + output(idx_out.begin(), idx_out.end()) += + input(idx_in.begin(), idx_in.end()) * + weights(idx_wei.begin(), idx_wei.end()); + } + }); + }); + }); + return result; + } + + size_t kdims() const + { + check_attribute_size(); + return stride.size(); + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/dequantizelinear.hpp b/src/include/migraphx/op/dequantizelinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..44d56b3f05a7c0d7cc168663df60486daf9ae54d --- /dev/null +++ b/src/include/migraphx/op/dequantizelinear.hpp @@ -0,0 +1,62 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP +#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct dequantizelinear +{ + std::string name() const { return "dequantizelinear"; } + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.same_dims(); + return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + auto x = args.at(0); + auto x_scale = args.at(1); + std::vector zeros(output_shape.bytes(), 0); + argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()}; + if(args.size() == 3) + { + x_zero_point = args.at(2); + } + + argument result{output_shape}; + visit_all(x, x_zero_point)([&](auto input, auto zero_pts) { + visit_all(result, x_scale)([&](auto output, auto scales) { + par_for(output_shape.elements(), [&](auto i) { + output[i] = static_cast(static_cast(input[i]) - + static_cast(zero_pts[i])) * + scales[i]; + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/div.hpp b/src/include/migraphx/op/div.hpp old mode 100644 new mode 100755 index faae6353925dbb58e869b7f9af3ca38ba35d3d70..46f4707055160acf449f695cdf28670bfa4a09f1 --- a/src/include/migraphx/op/div.hpp +++ b/src/include/migraphx/op/div.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct div : binary
{ + std::string point_function() const { return "/"; } auto apply() const { return [](auto x, auto y) { return x / y; }; diff --git a/src/include/migraphx/op/dot.hpp b/src/include/migraphx/op/dot.hpp index c5a70412ae7f4c8270afc854becbe06384c3ad38..af338bd5f9e7693eb6067f20a5cefff6c24576f2 100644 --- a/src/include/migraphx/op/dot.hpp +++ b/src/include/migraphx/op/dot.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_DOT_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -18,19 +18,10 @@ namespace op { struct dot { - float alpha = 1.0; - float beta = 1.0; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.alpha, "alpha"), f(self.beta, "beta")); - } - std::string name() const { return "dot"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.same_type(); + check_shapes{inputs, *this}.same_type().has(2); const shape& a = inputs.at(0); const shape& b = inputs.at(1); auto t = a.type(); @@ -58,15 +49,16 @@ struct dot auto out_lens = a.lens(); out_lens[dim_1] = b.lens()[dim_1]; - if(inputs.size() == 3 && out_lens != inputs.at(2).lens()) - { - MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + - to_string_range(inputs.at(2).lens()) + - "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}"); - } - return {t, out_lens}; } + + argument compute(shape output_shape, std::vector args) const + { + argument result = argument{output_shape}; + visit_all(result, args[0], args[1])( + [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); }); + return result; + } }; } // namespace op diff --git a/src/include/migraphx/op/elu.hpp b/src/include/migraphx/op/elu.hpp old mode 100644 new mode 100755 index bbf73d8a70a1500a595ff2aa5589e7f6a82fbeff..0f33682daf7102af02783d547f9570e0df904be2 --- a/src/include/migraphx/op/elu.hpp +++ b/src/include/migraphx/op/elu.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_ELU_HPP #include -#include #include #include #include @@ -19,7 +18,7 @@ namespace op { struct elu { std::string name() const { return "elu"; } - float alpha; + float alpha = 1; shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1); diff --git a/src/include/migraphx/op/equal.hpp b/src/include/migraphx/op/equal.hpp new file mode 100755 index 0000000000000000000000000000000000000000..25a49954a191a796b1c97d41cfd1ccd492a43c0a --- /dev/null +++ b/src/include/migraphx/op/equal.hpp @@ -0,0 +1,33 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP +#define MIGRAPHX_GUARD_OPERATORS_EQUAL_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct equal : binary +{ + value attributes() const + { + auto a = base_attributes(); + a["commutative"] = true; + return a; + } + std::string point_function() const { return "=="; } + auto apply() const + { + return [](auto x, auto y) { return float_equal(x, y); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/exp.hpp b/src/include/migraphx/op/exp.hpp index dc1afae7b29899af72c63734e490828b6de56db6..e2ebb94031d98e07b55dde792b4f08eeac20bf86 100644 --- a/src/include/migraphx/op/exp.hpp +++ b/src/include/migraphx/op/exp.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/flatten.hpp b/src/include/migraphx/op/flatten.hpp old mode 100644 new mode 100755 index 16f135255c67daff90f103a5ca74c0b8bb3b0c0b..53daad5ef2dc50bb2ed63a6bac8966f43aa937cc --- a/src/include/migraphx/op/flatten.hpp +++ b/src/include/migraphx/op/flatten.hpp @@ -2,13 +2,15 @@ #define MIGRAPHX_GUARD_OPERATORS_FLATTEN_HPP #include -#include #include #include #include #include #include #include +#include +#include +#include #include #include @@ -18,7 +20,7 @@ namespace op { struct flatten { - uint64_t axis = 0; + int64_t axis = 1; template static auto reflect(Self& self, F f) @@ -26,16 +28,19 @@ struct flatten return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = + value::array{normalize_attribute::include_min, normalize_attribute::include_max}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "flatten"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1); + check_shapes{inputs, *this}.has(1).standard(); auto&& lens = inputs.front().lens(); - - if(axis > lens.size()) - { - MIGRAPHX_THROW("axis for flatten must be less than tensor rank"); - } auto x = std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); auto y = @@ -44,7 +49,7 @@ struct flatten } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/gather.hpp b/src/include/migraphx/op/gather.hpp index 998ca190ea8a3ee029f26b03ef79eea9a6647aae..3e979d445bad5a53d084778a40f47578c5567022 100644 --- a/src/include/migraphx/op/gather.hpp +++ b/src/include/migraphx/op/gather.hpp @@ -2,13 +2,14 @@ #define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP #include -#include #include #include #include #include #include #include +#include +#include #include #include @@ -18,7 +19,7 @@ namespace op { struct gather { - int axis = 0; + int64_t axis = 0; template static auto reflect(Self& self, F f) @@ -26,27 +27,25 @@ struct gather return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "gather"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(2).standard(); + check_shapes{inputs, *this}.has(2); auto lens = inputs[0].lens(); - int n_dim = static_cast(lens.size()); - if(axis >= n_dim || axis < -n_dim) - { - MIGRAPHX_THROW("Gather: axis is out of range."); - } - - // negative axis means counting dimensions from back - int axis_index = (axis < 0) ? (n_dim + axis) : axis; - auto type = inputs[0].type(); - lens.erase(lens.begin() + axis_index); + lens.erase(lens.begin() + axis); if(!inputs[1].scalar()) { auto ind_lens = inputs[1].lens(); - lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end()); + lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end()); } // for scalar output @@ -62,10 +61,8 @@ struct gather { argument result{output_shape}; // negative axis means counting dimensions from back - auto lens = args[0].get_shape().lens(); - int axis_index = (axis < 0) ? static_cast(lens.size() + axis) : axis; - - std::size_t axis_dim_size = lens[axis_index]; + auto lens = args[0].get_shape().lens(); + std::size_t axis_dim_size = lens[axis]; // max dimension in axis visit_all(result, args[0])([&](auto output, auto data) { args[1].visit([&](auto indices) { @@ -73,18 +70,18 @@ struct gather { auto in_index = indices.front(); in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; - output[0] = data[indices.front()]; + output[0] = data[in_index]; } else { - auto out_lens = data.get_shape().lens(); - out_lens[axis_index] = indices.get_shape().elements(); + auto out_lens = data.get_shape().lens(); + out_lens[axis] = indices.get_shape().elements(); migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; shape_for_each(out_comp_shape, [&](const auto& out_idx) { - auto data_idx = out_idx; - auto in_index = indices[data_idx[axis_index]]; - in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; - data_idx[axis_index] = in_index; + auto data_idx = out_idx; + auto in_index = indices[data_idx[axis]]; + in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; + data_idx[axis] = in_index; output[out_comp_shape.index(out_idx.begin(), out_idx.end())] = data(data_idx.begin(), data_idx.end()); }); diff --git a/src/include/migraphx/op/gathernd.hpp b/src/include/migraphx/op/gathernd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b954787836785f720ff226a2cb9a66a78666ab8 --- /dev/null +++ b/src/include/migraphx/op/gathernd.hpp @@ -0,0 +1,131 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP +#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct gathernd +{ + int batch_dims = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.batch_dims, "batch_dims")); + } + + std::string name() const { return "gathernd"; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + auto r = inputs.front().lens().size(); + auto q = inputs.back().lens().size(); + auto k = inputs.back().lens().back(); + if(k > r - batch_dims) + { + MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) + + " cannot be used to access data of rank " + + std::to_string(r - batch_dims)); + } + auto indices_lens_iter = inputs.back().lens().begin(); + auto output_lens_size = q + r - k - batch_dims - 1; + std::vector output_lens(output_lens_size); + std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin()); + if(k < r - batch_dims) + { + auto data_lens = inputs.front().lens(); + std::copy( + data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1); + } + shape output_shape{inputs.front().type(), output_lens}; + return output_shape; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[0])([&](auto output, auto data) { + args[1].visit([&](auto indices) { + auto indices_shape = indices.get_shape(); + auto indices_shape_lens = indices_shape.lens(); + auto data_shape = data.get_shape(); + auto data_shape_lens = data_shape.lens(); + auto k = indices_shape.lens().back(); + const auto num_slice_dims = k; + std::size_t num_slices = std::accumulate(indices_shape_lens.begin(), + indices_shape_lens.end() - 1, + 1, + std::multiplies()); + std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + std::size_t num_batches = std::accumulate(data_shape_lens.begin(), + data_shape_lens.begin() + batch_dims, + 1, + std::multiplies()); + std::size_t data_batch_stride = + std::accumulate(data_shape_lens.begin() + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + auto num_slices_per_batch = num_slices / num_batches; + + std::vector sizes_from_slice_dims(num_slice_dims); + { + auto running_product = slice_size; + for(std::size_t i = 0; i < num_slice_dims; ++i) + { + sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product; + running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i]; + } + } + + std::vector input_slice_offsets(num_slices); + par_for(num_slices, [&](const auto i) { + std::size_t batch_idx = i / num_slices_per_batch; + + auto slice_indices = indices.begin() + (i * num_slice_dims); + std::size_t relative_slice_offset = 0; + for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx) + { + int64_t index = *(slice_indices + dim_idx); + const std::size_t input_dim_idx = batch_dims + dim_idx; + const auto input_dim = data_shape_lens[input_dim_idx]; + if(index < -static_cast(input_dim) or + index >= static_cast(input_dim)) + MIGRAPHX_THROW("GatherND: index " + std::to_string(index) + + " is out of bounds for dim of len " + + std::to_string(input_dim)); + if(index < 0) + index += input_dim; + + relative_slice_offset += index * sizes_from_slice_dims[dim_idx]; + } + + input_slice_offsets[i] = + (batch_idx * data_batch_stride) + relative_slice_offset; + }); + + par_for(num_slices * slice_size, [&](const auto i) { + auto slice_offset = input_slice_offsets[i / slice_size]; + output[i] = data[slice_offset + i % slice_size]; + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/get_tuple_elem.hpp b/src/include/migraphx/op/get_tuple_elem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31565c0032588628dd4e27aa5b9b22582cb35473 --- /dev/null +++ b/src/include/migraphx/op/get_tuple_elem.hpp @@ -0,0 +1,56 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP +#define MIGRAPHX_GUARD_OPERATORS_GET_TUPLE_ELEM_HPP + +#include "migraphx/errors.hpp" +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct get_tuple_elem +{ + std::size_t index = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.index, "index")); + } + + std::string name() const { return "get_tuple_elem"; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(1).tuple_type(); + const auto& sub_shapes = inputs.at(0).sub_shapes(); + if(index >= sub_shapes.size()) + { + MIGRAPHX_THROW("GET_TUPLE_ELEM: index " + std::to_string(index) + " is out of range " + + std::to_string(sub_shapes.size())); + } + + return sub_shapes.at(index); + } + + argument compute(const shape&, std::vector args) const + { + assert(args.size() == 1); + auto vec_args = args.at(0).get_sub_objects(); + assert(index < vec_args.size()); + return vec_args.at(index); + } + + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/greater.hpp b/src/include/migraphx/op/greater.hpp new file mode 100755 index 0000000000000000000000000000000000000000..c0a55bdf34ffe5d50534d5d69eecced62abbfed1 --- /dev/null +++ b/src/include/migraphx/op/greater.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_GREATER_HPP +#define MIGRAPHX_GUARD_OPERATORS_GREATER_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct greater : binary +{ + std::string point_function() const { return ">"; } + auto apply() const + { + return [](auto x, auto y) { return x > y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/gru.hpp b/src/include/migraphx/op/gru.hpp index c27dcdd751ebc29c912d896888823f5f9802317f..c0fc5ea5f2efb5fdfac05fe2b26b828af7ce587d 100644 --- a/src/include/migraphx/op/gru.hpp +++ b/src/include/migraphx/op/gru.hpp @@ -3,9 +3,9 @@ #include #include +#include #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/identity.hpp b/src/include/migraphx/op/identity.hpp old mode 100644 new mode 100755 index 85b014d9a3906d01c1a2fc3c9bca12807c7ecc7a..816bdd2f81832168559337497d85689f4cab1e9f --- a/src/include/migraphx/op/identity.hpp +++ b/src/include/migraphx/op/identity.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_IDENTITY_HPP #include -#include #include #include #include @@ -20,10 +19,8 @@ struct identity { std::string name() const { return "identity"; } shape compute_shape(std::vector inputs) const { return inputs.at(0); } - argument compute(shape output_shape, std::vector args) const - { - return {std::move(output_shape), std::move(args.at(0).data)}; - } + argument compute(shape, std::vector args) const { return args[0]; } + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/if_op.hpp b/src/include/migraphx/op/if_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa9dff5b2d896402c6e33d2a7099a75a4d63e545 --- /dev/null +++ b/src/include/migraphx/op/if_op.hpp @@ -0,0 +1,74 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP +#define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct if_op +{ + std::string name() const { return "if"; } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + check_shapes{inputs, *this}.standard(); + if(mods.size() != 2) + { + MIGRAPHX_THROW("IF: operator should have two submodules."); + } + + auto out_shapes0 = mods[0]->get_output_shapes(); + auto out_shapes1 = mods[1]->get_output_shapes(); + if(not std::equal( + out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end())) + { + MIGRAPHX_THROW("IF: output shapes of submodules must be the same."); + } + + return {out_shapes0}; + } + + argument compute(const shape&, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + auto cond = args.front().at(); + module_ref mod = cond ? mods[0] : mods[1]; + std::unordered_map params; + + std::set pnames; + for(const auto& smod : mods) + { + auto names = smod->get_parameter_names(); + pnames.insert(names.begin(), names.end()); + } + + assert(pnames.size() < args.size()); + std::transform(pnames.begin(), + pnames.end(), + args.begin() + 1, + std::inserter(params, params.end()), + [](auto&& name, auto&& arg) { return std::make_pair(name, arg); }); + + auto results = run(mod, params); + return argument{results}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/im2col.hpp b/src/include/migraphx/op/im2col.hpp index a96d0142d2346d92b0813d3d9b5eb44bd18e9084..610b28b83a0ee143f9c15803a9de8be40ff3a42b 100644 --- a/src/include/migraphx/op/im2col.hpp +++ b/src/include/migraphx/op/im2col.hpp @@ -2,12 +2,8 @@ #define MIGRAPHX_GUARD_OPERATORS_IM2COL_HPP #include -#include #include -#include -#include -#include -#include +#include #include #include #include @@ -18,9 +14,9 @@ namespace op { struct im2col { - std::array padding = {{0, 0}}; - std::array stride = {{1, 1}}; - std::array dilation = {{1, 1}}; + std::vector padding{0, 0}; + std::vector stride{1, 1}; + std::vector dilation{1, 1}; padding_mode_t padding_mode = default_; @@ -35,7 +31,9 @@ struct im2col std::string name() const { return "im2col"; } - shape compute_shape(std::vector inputs) const + value attributes() const { return {{"normalize_padding", "padding"}}; } + + shape normalize_compute_shape(std::vector inputs) const { auto input = inputs[0]; auto weights = inputs[1]; @@ -46,17 +44,24 @@ struct im2col check_shapes{inputs, *this}.has(2); if(batch_size != 1) MIGRAPHX_THROW("im2col only support batch_size 1"); + + auto padding_h = 2 * padding[0]; + auto padding_w = 2 * padding[1]; + if(padding.size() == 2 * stride.size()) + { + padding_h = padding[0] + padding[2]; + padding_w = padding[1] + padding[3]; + } auto output_height = std::size_t(std::max( 1, - (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) / - stride[0] + + (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + padding_h) / stride[0] + 1)); auto output_width = std::size_t(std::max( 1, - (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) / - stride[1] + + (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + padding_w) / stride[1] + 1)); - auto channels_col = kernel_height * kernel_width * input_channels; + + auto channels_col = kernel_height * kernel_width * input_channels; return {input.type(), {output_height * output_width, channels_col}}; } }; diff --git a/src/include/migraphx/op/isnan.hpp b/src/include/migraphx/op/isnan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..28e1a5908755cdecb584be0f842b4e9d9c832a70 --- /dev/null +++ b/src/include/migraphx/op/isnan.hpp @@ -0,0 +1,30 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP +#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct isnan : unary +{ + auto apply() const + { + return [](auto x) { return std::isnan(x); }; + } + + std::string name() const { return "isnan"; } + + shape compute_shape(std::vector inputs) const + { + return unary::compute_shape(std::move(inputs)).with_type(shape::bool_type); + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/leaky_relu.hpp b/src/include/migraphx/op/leaky_relu.hpp index ef7b70359e425d42ce8dc14f2ee1fe53da523bef..e58ace62f13561d9d51d2512a76b0b9a6efe74a7 100644 --- a/src/include/migraphx/op/leaky_relu.hpp +++ b/src/include/migraphx/op/leaky_relu.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_LEAKY_RELU_HPP #include -#include #include #include #include @@ -18,7 +17,7 @@ namespace op { struct leaky_relu { - float alpha; + float alpha = 0.01; template static auto reflect(Self& self, F f) diff --git a/src/include/migraphx/op/less.hpp b/src/include/migraphx/op/less.hpp new file mode 100755 index 0000000000000000000000000000000000000000..59bf41d7445afa444fb6e79bf730ff8f34598375 --- /dev/null +++ b/src/include/migraphx/op/less.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_LESS_HPP +#define MIGRAPHX_GUARD_OPERATORS_LESS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct less : binary +{ + std::string point_function() const { return "<"; } + auto apply() const + { + return [](auto x, auto y) { return x < y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/load.hpp b/src/include/migraphx/op/load.hpp old mode 100644 new mode 100755 index 7569bb30ba65d9f16f54b41690a8cbfdd292eaac..5c2d61692b61cb9d42671bba7c4244b4af0fd55a --- a/src/include/migraphx/op/load.hpp +++ b/src/include/migraphx/op/load.hpp @@ -2,13 +2,11 @@ #define MIGRAPHX_GUARD_OPERATORS_LOAD_HPP #include -#include #include -#include -#include -#include -#include +#include +#include #include +#include #include #include @@ -30,15 +28,16 @@ struct load std::string name() const { return "load"; } shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs}.has(1); + check_shapes{inputs, *this}.has(1); return s; } argument compute(const shape&, const std::vector& args) const { if((offset + s.bytes()) > args[0].get_shape().bytes()) MIGRAPHX_THROW("Load access is out of bounds"); - return {s, args[0].data() + offset}; + return argument{s, args[0].data() + offset}; } + lifetime get_lifetime() const { return lifetime::borrow; } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } friend std::ostream& operator<<(std::ostream& os, const load& op) diff --git a/src/include/migraphx/op/log.hpp b/src/include/migraphx/op/log.hpp index 7a46ee46ebcd50d853ede1787543e7aa8681e088..fdd3b767fd4877860d2b92f8e08f98da812e87d8 100644 --- a/src/include/migraphx/op/log.hpp +++ b/src/include/migraphx/op/log.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/logical_and.hpp b/src/include/migraphx/op/logical_and.hpp new file mode 100755 index 0000000000000000000000000000000000000000..a8fe6fe0c90734d1dfe6e6b3b37de4a087d66d41 --- /dev/null +++ b/src/include/migraphx/op/logical_and.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_LOGICAL_AND_HPP +#define MIGRAPHX_GUARD_OPERATORS_LOGICAL_AND_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct logical_and : binary +{ + std::string point_function() const { return "&&"; } + auto apply() const + { + return [](auto x, auto y) { return static_cast(x) and static_cast(y); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/logical_or.hpp b/src/include/migraphx/op/logical_or.hpp new file mode 100755 index 0000000000000000000000000000000000000000..86c66aa9fdecd2c8511ed70501fd75444141728c --- /dev/null +++ b/src/include/migraphx/op/logical_or.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_LOGICAL_OR_HPP +#define MIGRAPHX_GUARD_OPERATORS_LOGICAL_OR_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct logical_or : binary +{ + std::string point_function() const { return "||"; } + auto apply() const + { + return [](auto x, auto y) { return static_cast(x) or static_cast(y); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/logical_xor.hpp b/src/include/migraphx/op/logical_xor.hpp new file mode 100755 index 0000000000000000000000000000000000000000..3b02c375d9832899bff62a7affdba7b009775a6d --- /dev/null +++ b/src/include/migraphx/op/logical_xor.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_LOGICAL_XOR_HPP +#define MIGRAPHX_GUARD_OPERATORS_LOGICAL_XOR_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct logical_xor : binary +{ + std::string point_function() const { return "^"; } + auto apply() const + { + return [](auto x, auto y) { return static_cast(x) xor static_cast(y); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/logsoftmax.hpp b/src/include/migraphx/op/logsoftmax.hpp index 0af3b5814af0a00e76c4f741a4709b61cdd16689..75bcdd4c0653cc22174b6a36fabb75f54319d1dc 100644 --- a/src/include/migraphx/op/logsoftmax.hpp +++ b/src/include/migraphx/op/logsoftmax.hpp @@ -1,8 +1,9 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP -#include #include +#include +#include #include namespace migraphx { @@ -11,7 +12,7 @@ namespace op { struct logsoftmax { - int axis = 1; + int64_t axis = 1; template static auto reflect(Self& self, F f) @@ -19,16 +20,25 @@ struct logsoftmax return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "logsoftmax"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1).standard(); - if(axis < 0 || axis >= inputs[0].lens().size()) + if(inputs.at(0).packed()) + { + return inputs.at(0); + } + else { - MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + - " is out of range"); + auto lens = inputs.at(0).lens(); + return {inputs.at(0).type(), lens}; } - return inputs.at(0); } auto output() const diff --git a/src/include/migraphx/op/loop.hpp b/src/include/migraphx/op/loop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93b02befe5c617fb281ede5be4230661abd27007 --- /dev/null +++ b/src/include/migraphx/op/loop.hpp @@ -0,0 +1,142 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP +#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct loop +{ + int64_t max_iterations = 10; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.max_iterations, "max_iterations")); + } + + std::string name() const { return "loop"; } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + check_shapes{inputs, *this}.standard(); + if(mods.size() != 1) + { + MIGRAPHX_THROW("LOOP: operator should have one submodule."); + } + + const auto& mod = mods.front(); + auto mod_out_shapes = mod->get_output_shapes(); + auto dep_param_num = inputs.size() - 2; + + // first item of the mod output shapes is condition used in loop, + // which is not needed to compute output shape + mod_out_shapes.erase(mod_out_shapes.begin()); + std::vector ins_out_shapes(mod_out_shapes.begin(), + mod_out_shapes.begin() + dep_param_num); + mod_out_shapes.erase(mod_out_shapes.begin(), mod_out_shapes.begin() + dep_param_num); + for(const auto& out_s : mod_out_shapes) + { + auto lens = out_s.lens(); + lens.insert(lens.begin(), max_iterations); + ins_out_shapes.push_back({out_s.type(), lens}); + } + + return {ins_out_shapes}; + } + + struct ref_loop + { + int64_t max_iterations = 0; + + template + void copy(context&, const argument& src, T& dst) const + { + dst = *src.cast(); + } + + template + void copy(context&, T src, const argument& dst) const + { + *dst.cast() = src; + } + + void append(const std::vector& iter_state, + const std::vector& concatenated_outputs, + int iter) const + { + assert(iter_state.size() == concatenated_outputs.size()); + for(auto i : range(iter_state.size())) + { + const auto& iter_stat = iter_state.at(i); + const auto& scan_out = concatenated_outputs.at(i); + + auto* in_data = iter_stat.data(); + auto* out_data = scan_out.data(); + std::size_t out_size = iter_stat.get_shape().bytes(); + assert((iter + 1) * out_size <= scan_out.get_shape().bytes()); + std::copy(in_data, in_data + out_size, out_data + iter * out_size); + } + } + + void set_zero(context&, const std::vector& concatenated_outputs, int iter) const + { + if(iter >= max_iterations) + return; + + for(const auto& out : concatenated_outputs) + { + auto s = out.get_shape(); + auto size = s.bytes() / max_iterations; + std::fill(out.data() + iter * size, out.data() + max_iterations * size, 0); + } + } + + std::unordered_map get_output_params(const module&) const { return {}; } + }; + + argument compute(context& ctx, + const shape& out_shape, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + // wrap up the arguments vector, so ref and gpu impl are the same + auto cpy_args = args; + bool in_cond = args.at(1).at(); + bool cond = in_cond; + int64_t iter = 0; + // insert iter and cond used in the loop + auto s_cond = args.at(1).get_shape(); + auto s_iter = args.at(0).get_shape(); + cpy_args.push_back({s_iter, &iter}); + cpy_args.push_back({s_cond, &cond}); + cpy_args.insert(cpy_args.end(), args.begin() + 2, args.end()); + + // add cond and mod outputs to the argument list + cpy_args.push_back(argument(s_cond)); + cpy_args.push_back(argument(out_shape)); + + // run loop + return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run); + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/lrn.hpp b/src/include/migraphx/op/lrn.hpp index 793fbe5a7d9d52418d52d52b278773b743cf586e..f3482e27d7c8944418a345b6a477b28e91b8b3b6 100644 --- a/src/include/migraphx/op/lrn.hpp +++ b/src/include/migraphx/op/lrn.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_LRN_HPP #include -#include #include #include #include diff --git a/src/include/migraphx/op/lstm.hpp b/src/include/migraphx/op/lstm.hpp old mode 100644 new mode 100755 index f954cfaf431924c9c1f75f3f076be49b2857b391..b87ee4cfeadabb936e78f9fc48ac8a35ee607302 --- a/src/include/migraphx/op/lstm.hpp +++ b/src/include/migraphx/op/lstm.hpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include #include @@ -31,6 +33,7 @@ struct lstm return pack(f(self.hidden_size, "hidden_size"), f(self.actv_funcs, "actv_func"), f(self.direction, "direction"), + f(self.clip, "clip"), f(self.input_forget, "input_forget")); } diff --git a/src/include/migraphx/op/max.hpp b/src/include/migraphx/op/max.hpp index 63fc9da34f6f358acdcccef0f16bce1c8697bea9..9d1d6f32d363392ce71f0cd4ce49aa236eab62ff 100644 --- a/src/include/migraphx/op/max.hpp +++ b/src/include/migraphx/op/max.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,12 @@ namespace op { struct max : binary { + value attributes() const + { + auto a = base_attributes(); + a["commutative"] = true; + return a; + } auto apply() const { return [](auto x, auto y) { return std::max(x, y); }; diff --git a/src/include/migraphx/op/min.hpp b/src/include/migraphx/op/min.hpp index 90d63dfcef26a979096106221637a5ed74f1e16d..ba344580b30bcf8ce60ce0041d53cc2d190048ea 100644 --- a/src/include/migraphx/op/min.hpp +++ b/src/include/migraphx/op/min.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,12 @@ namespace op { struct min : binary { + value attributes() const + { + auto a = base_attributes(); + a["commutative"] = true; + return a; + } auto apply() const { return [](auto x, auto y) { return std::min(x, y); }; diff --git a/src/include/migraphx/op/mul.hpp b/src/include/migraphx/op/mul.hpp old mode 100644 new mode 100755 index b281e46b3238cbfca58171876d47af7546fef017..5f5b5152f6b8d785a198e42b0c1c4041abcc3bd7 --- a/src/include/migraphx/op/mul.hpp +++ b/src/include/migraphx/op/mul.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,13 @@ namespace op { struct mul : binary { + value attributes() const + { + auto a = base_attributes(); + a["commutative"] = true; + return a; + } + std::string point_function() const { return "*"; } auto apply() const { return [](auto x, auto y) { return x * y; }; diff --git a/src/include/migraphx/op/multibroadcast.hpp b/src/include/migraphx/op/multibroadcast.hpp old mode 100644 new mode 100755 index 50a5d4b5300d4c286eb50a4232b5ad965dc2cdf3..0037d39218e725d5741fb3175f85f6a1bb858b89 --- a/src/include/migraphx/op/multibroadcast.hpp +++ b/src/include/migraphx/op/multibroadcast.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_MULTIBROADCAST_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -23,7 +23,7 @@ struct multibroadcast template static auto reflect(Self& self, F f) { - return pack(f(self.output_lens, "output_lens")); + return pack(f(self.output_lens, "out_lens")); } std::string name() const { return "multibroadcast"; } @@ -67,7 +67,7 @@ struct multibroadcast } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.at(0).data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/multinomial.hpp b/src/include/migraphx/op/multinomial.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d481a198a567379528bd40a51b8dc9b84e5904a6 --- /dev/null +++ b/src/include/migraphx/op/multinomial.hpp @@ -0,0 +1,64 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP +#define MIGRAPHX_GUARD_OPERATORS_MULTINOMIAL_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct multinomial +{ + shape::type_t dtype = shape::type_t::int32_type; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.dtype, "dtype")); + } + + std::string name() const { return "multinomial"; } + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2).only_dims(2); + size_t sample_size = inputs.back().lens().back(); + + if(not contains({shape::int32_type, shape::int64_type}, dtype)) + MIGRAPHX_THROW( + "Multinomial: Invalid output type. Valid types are int32_type and int64_type."); + + return {dtype, {inputs.front().lens().front(), sample_size}}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + size_t batch_size = output_shape.lens().front(); + size_t class_size = args[0].get_shape().lens().back(); + size_t sample_size = output_shape.lens().back(); + + visit_all(args[0], args[1])([&](auto cdf, auto dist) { + result.visit([&](auto output) { + par_for(batch_size * sample_size, [&](auto i) { + auto idx = args[1].get_shape().multi(i); + auto cdf_begin = cdf.begin() + (idx[0] * class_size); + auto cdf_end = cdf_begin + class_size; + auto sample_iter = + std::upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end))); + output[i] = std::distance(cdf_begin, sample_iter); + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/name.hpp b/src/include/migraphx/op/name.hpp index 214dbff00fa313fcff28f7f0eeedc689e4b22c2b..0630dfcf94388619146b692f9d5fbc4af5012c60 100644 --- a/src/include/migraphx/op/name.hpp +++ b/src/include/migraphx/op/name.hpp @@ -1,17 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_NAME_HPP #define MIGRAPHX_GUARD_RTGLIB_NAME_HPP -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/include/migraphx/op/neg.hpp b/src/include/migraphx/op/neg.hpp old mode 100644 new mode 100755 index d2fd97566887cf406db4824959efe3194cef60ef..3407225a87c1c7427fc104f51faab2cd3fdda9da --- a/src/include/migraphx/op/neg.hpp +++ b/src/include/migraphx/op/neg.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct neg : unary { + std::string point_function() const { return "-"; } auto apply() const { return [](auto x) { return -x; }; diff --git a/src/include/migraphx/op/nonmaxsuppression.hpp b/src/include/migraphx/op/nonmaxsuppression.hpp new file mode 100644 index 0000000000000000000000000000000000000000..225a8d08ce56c177b75fbaeaf7084ccddf78634f --- /dev/null +++ b/src/include/migraphx/op/nonmaxsuppression.hpp @@ -0,0 +1,235 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP +#define MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct nonmaxsuppression +{ + bool center_point_box = false; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.center_point_box, "center_point_box")); + } + + std::string name() const { return "nonmaxsuppression"; } + + shape compute_shape(std::vector inputs) const + { + // requires at least 2 inputs + check_shapes{inputs, *this}.standard(); + check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); + auto lens = inputs.front().lens(); + + // check input shape + if(lens[1] != inputs.at(1).lens()[2]) + { + MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!"); + } + + std::vector out_lens(2); + out_lens.at(0) = lens.at(1); + out_lens.at(1) = 3; + return {shape::int64_type, out_lens}; + } + + struct box + { + std::array x; + std::array y; + + void sort() + { + std::sort(x.begin(), x.end()); + std::sort(y.begin(), y.end()); + } + + std::array& operator[](std::size_t i) { return i == 0 ? x : y; } + + float area() const + { + assert(std::is_sorted(x.begin(), x.end())); + assert(std::is_sorted(y.begin(), y.end())); + return (x[1] - x[0]) * (y[1] - y[0]); + } + }; + + template + box batch_box(const T* boxes, std::size_t bidx) const + { + box result{}; + const T* start = boxes + 4 * bidx; + if(center_point_box) + { + float half_width = start[2] / 2.0f; + float half_height = start[3] / 2.0f; + float x_center = start[0]; + float y_center = start[1]; + result.x = {x_center - half_width, x_center + half_width}; + result.y = {y_center - half_height, y_center + half_height}; + } + else + { + result.x = {start[1], start[3]}; + result.y = {start[0], start[2]}; + } + + return result; + } + + inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const + { + b1.sort(); + b2.sort(); + + box intersection{}; + for(auto i : range(2)) + { + intersection[i][0] = std::max(b1[i][0], b2[i][0]); + intersection[i][1] = std::min(b1[i][1], b2[i][1]); + } + + std::vector> bbox = {intersection.x, intersection.y}; + if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) { + return not std::is_sorted(bx.begin(), bx.end()); + })) + { + return false; + } + + const float area1 = b1.area(); + const float area2 = b2.area(); + const float intersection_area = intersection.area(); + const float union_area = area1 + area2 - intersection_area; + + if(area1 <= .0f or area2 <= .0f or union_area <= .0f) + { + return false; + } + + const float intersection_over_union = intersection_area / union_area; + + return intersection_over_union > iou_threshold; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + + result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); }); + + std::size_t max_output_boxes_per_class = 0; + float iou_threshold = 0.0f; + float score_threshold = 0.0f; + + if(args.size() > 2) + { + max_output_boxes_per_class = args.at(2).at(); + } + // max_output_boxes_per_class is 0, no output + if(max_output_boxes_per_class == 0) + { + return result; + } + + if(args.size() > 3) + { + iou_threshold = args.at(3).at(); + } + + if(args.size() > 4) + { + score_threshold = args.at(4).at(); + } + + const auto& lens = args.at(1).get_shape().lens(); + auto batch_num = lens[0]; + auto class_num = lens[1]; + auto box_num = args.at(0).get_shape().lens()[1]; + + std::vector> selected_boxes_inside_class; + std::vector selected_indices; + selected_boxes_inside_class.reserve(output_shape.elements()); + + auto scores = make_view(args.at(1).get_shape(), args.at(1).cast()); + const float* boxes = args.at(0).cast(); + shape comp_s{shape::float_type, {batch_num, class_num}}; + shape_for_each(comp_s, [&](auto idx) { + auto bidx = idx[0]; + auto cidx = idx[1]; + + std::size_t score_offset = (bidx * class_num + cidx) * box_num; + const float* batch_boxes = boxes + bidx * box_num * 4; + std::priority_queue> sorted_boxes; + auto insert_to_sorted_boxes = + make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); }); + + int64_t box_idx = 0; + transform_if( + scores.begin() + score_offset, + scores.begin() + score_offset + box_num, + insert_to_sorted_boxes, + [&](auto sc) { + box_idx++; + return sc >= score_threshold; + }, + [&](auto sc) { return std::make_pair(sc, box_idx - 1); }); + + selected_boxes_inside_class.clear(); + // Get the next box with top score, filter by iou_threshold + while(!sorted_boxes.empty() && + selected_boxes_inside_class.size() < max_output_boxes_per_class) + { + const std::pair& next_top_score = sorted_boxes.top(); + + // Check with existing selected boxes for this class, suppress if exceed the IOU + // (Intersection Over Union) threshold + bool not_selected = std::any_of( + selected_boxes_inside_class.begin(), + selected_boxes_inside_class.end(), + [&](auto selected_index) { + return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second), + batch_box(batch_boxes, selected_index.second), + iou_threshold); + }); + + if(not not_selected) + { + selected_boxes_inside_class.push_back(next_top_score); + selected_indices.push_back(bidx); + selected_indices.push_back(cidx); + selected_indices.push_back(next_top_score.second); + } + sorted_boxes.pop(); + } + }); + + result.visit([&](auto out) { + std::copy(selected_indices.begin(), selected_indices.end(), out.begin()); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/nonzero.hpp b/src/include/migraphx/op/nonzero.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0d0d12128cecc68519a5d604fc8a5ce981a86dd --- /dev/null +++ b/src/include/migraphx/op/nonzero.hpp @@ -0,0 +1,62 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP +#define MIGRAPHX_GUARD_OPERATORS_NONZERO_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct nonzero +{ + std::string name() const { return "nonzero"; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(1).standard(); + auto elem_num = inputs[0].elements(); + auto dim_num = inputs[0].lens().size(); + std::vector out_lens = {dim_num, elem_num}; + + return {shape::int64_type, out_lens}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + std::vector> vec_idx; + auto s = args.front().get_shape(); + args.front().visit([&](auto v) { + shape_for_each(s, [&](auto idx) { + if(not float_equal(v[s.index(idx)], 0)) + { + vec_idx.push_back(idx); + } + }); + }); + + argument result{output_shape}; + result.visit([&](auto output) { + std::fill(output.begin(), output.end(), 0); + par_for(vec_idx.size(), [&](auto i) { + for(std::size_t j = 0; j < vec_idx.front().size(); ++j) + { + output[output_shape.index({j, i})] = vec_idx[i][j]; + } + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/normalize_attribute.hpp b/src/include/migraphx/op/normalize_attribute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee2f678a54519ba7b49c9163a4f127ff034f0865 --- /dev/null +++ b/src/include/migraphx/op/normalize_attribute.hpp @@ -0,0 +1,34 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP +#define MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +// different attributes +// 1) use_input(default)/use_output +// 2) use_rank(default)/use_len +// 3) clip_min(default)/not_clip_min +// 3.1) include_min(default)/exclude_min +// 4) clip_max(default)/not_clip_max +// 4.1) exclude_max(default)/include_max +// 5) normalize padding +enum class normalize_attribute +{ + use_len, + use_output, + clip_max, + clip_min, + include_max, + include_min, + normalize_padding +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/outline.hpp b/src/include/migraphx/op/outline.hpp index ccc4dabc2bd856fc628452841eb7b3625fb8d067..cc775248a9203fafa32c7babadc22b4d4008fc65 100644 --- a/src/include/migraphx/op/outline.hpp +++ b/src/include/migraphx/op/outline.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_OUTLINE_HPP #include -#include #include #include #include diff --git a/src/include/migraphx/op/pad.hpp b/src/include/migraphx/op/pad.hpp index 1fbc46c289acbd528674bd46e2b323d928271da5..c8bad3c05993ab00758e83f0a20504bb75731c41 100644 --- a/src/include/migraphx/op/pad.hpp +++ b/src/include/migraphx/op/pad.hpp @@ -2,7 +2,6 @@ #define MIGRAPHX_GUARD_OPERATORS_PAD_HPP #include -#include #include #include #include @@ -51,6 +50,12 @@ struct pad return s; } + std::size_t pad_ndims() const + { + assert(pads.size() % 2 == 0); + return pads.size() / 2; + } + bool symmetric() const { std::size_t num_dims = pads.size() / 2; diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c5e64d5d3b430b6e88fa48d7bf4951bf09a2ef2 --- /dev/null +++ b/src/include/migraphx/op/pointwise.hpp @@ -0,0 +1,75 @@ +#ifndef MIGRAPHX_GUARD_OP_POINTWISE_HPP +#define MIGRAPHX_GUARD_OP_POINTWISE_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct pointwise +{ + std::string name() const { return "pointwise"; } + + shape compute_shape(const std::vector& inputs, std::vector mods) const + { + if(mods.size() != 1) + { + MIGRAPHX_THROW("should have one submodule."); + } + auto* pm = mods.front(); + auto pnames = pm->get_parameter_names(); + std::sort(pnames.begin(), pnames.end()); + check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + + if(pm->get_output_shapes().size() != 1) + MIGRAPHX_THROW("submodule should have only one output."); + + auto type = pm->get_output_shapes().front().type(); + + // Scalar output if all inputs are scalar + if(inputs.front().elements() == 1 and + all_of(inputs, [](const auto& s) { return s.scalar(); })) + return shape{type}; + + return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs)); + } + + argument compute(const shape& output_shape, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + argument output{output_shape}; + auto* pm = mods.front(); + auto pnames = pm->get_parameter_names(); + std::sort(pnames.begin(), pnames.end()); + + par_for(output_shape.elements(), [&](auto i) { + std::unordered_map params; + + std::transform( + pnames.begin(), + pnames.end(), + args.begin(), + std::inserter(params, params.end()), + [&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); }); + + auto results = run(pm, params); + assert(results.size() == 1); + visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); }); + }); + return output; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_OP_POINTWISE_HPP diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index 0f9df423b52f66ed96a158594117c04f1fac0845..9d06f4a432410d6e38ad03f5a4c3259be38fc55d 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -3,11 +3,13 @@ #include #include -#include #include #include #include +#include #include +#include +#include #include #include #include @@ -15,54 +17,188 @@ #include namespace migraphx { + inline namespace MIGRAPHX_INLINE_NS { namespace op { struct pooling { - std::string mode = "average"; - std::array padding = {{0, 0}}; - std::array stride = {{1, 1}}; - std::array lengths = {{1, 1}}; - padding_mode_t padding_mode = default_; + pooling_mode mode = {pooling_mode::average}; + std::vector padding = {0, 0}; + std::vector stride = {1, 1}; + std::vector lengths = {1, 1}; + bool ceil_mode = false; + int lp_order = 2; template static auto reflect(Self& self, F f) { return pack(f(self.mode, "mode"), f(self.padding, "padding"), - f(self.padding_mode, "padding_mode"), f(self.stride, "stride"), - f(self.lengths, "lengths")); + f(self.lengths, "lengths"), + f(self.ceil_mode, "ceil_mode"), + f(self.lp_order, "lp_order")); } std::string name() const { return "pooling"; } - shape compute_shape(std::vector inputs) const + void check_attribute_size() const + { + if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and + stride.size() == lengths.size())) + { + MIGRAPHX_THROW("POOLING: inconsistent attribute sizes"); + } + } + + value attributes() const { return {{"normalize_padding", "padding"}}; } + + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1).only_dims(4); + check_shapes{inputs, *this}.has(1); const shape& input = inputs.at(0); - auto t = input.type(); - assert(lengths[0] <= (input.lens()[2] + 2 * padding[0])); - assert(lengths[1] <= (input.lens()[3] + 2 * padding[1])); + auto input_lens = input.lens(); + size_t kdims = input_lens.size() - 2; + auto input_size = inputs[0].lens().size(); + auto padding_size = padding.size(); + if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2)) + { + MIGRAPHX_THROW("POOLING: input and attribute size mismatch!"); + } + + std::vector output_lens(input_lens.begin(), input_lens.begin() + 2); + + for(size_t i = 0; i < kdims; i++) + { + std::ptrdiff_t dim_size; + auto padding_factor = 2 * padding[i]; + if(padding_size == 2 * kdims) + padding_factor = padding[i] + padding[i + kdims]; + dim_size = input_lens[i + 2] + padding_factor - lengths[i]; + assert(dim_size >= 0); + std::size_t len = (ceil_mode) ? ceil_divide(dim_size, stride[i]) + : floor_divide(dim_size, stride[i]); + + output_lens.push_back(std::size_t(std::max(1, len + 1))); + } + return inputs[0].with_lens(output_lens); + } + + size_t kdims() const + { + check_attribute_size(); + return stride.size(); + } + + struct lpnorm_pool + { + int p = 0; + + lpnorm_pool() = delete; + + explicit lpnorm_pool(int x) : p{x} {}; + + template + double init() const + { + return 0.0; + } - return {t, + double operator()(double x, double y) const { return x + std::pow(std::abs(y), p); } + + double final(double x, std::size_t) const { return std::pow(x, 1. / p); } + }; + + struct avg_pool + { + template + double init() const + { + return 0.0; + } + + double operator()(double x, double y) const { return x + y; } + + double final(double x, std::size_t y) const { return (y == 0) ? 0.0 : (x / y); } + }; + + struct max_pool + { + template + T init() const + { + return std::numeric_limits::lowest(); + } + + double operator()(double x, double y) const { return std::max(x, y); } + + double final(double x, std::size_t) const { return (x); } + }; + + template + void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const + { + auto in_s = input.get_shape(); + auto in_lens = in_s.lens(); + par_for(output_shape.elements(), [&](auto i) { + auto idx_o = output_shape.multi(i); + auto n_dim = idx_o.size(); + std::vector win_start; + std::vector win_size; + for(std::size_t dim = 2; dim < n_dim; ++dim) + { + auto d_2 = dim - 2; + int start = + static_cast(idx_o[dim] * stride[d_2]) - static_cast(padding[d_2]); + int end = std::min(start + lengths[d_2], in_lens[dim]); + start = std::max(start, 0); + win_start.push_back(start); + win_size.push_back(end - start); + } + + shape win_shape{output_shape.type(), win_size}; + auto pool_size = win_shape.elements(); + double output_val = op.template init(); + shape_for_each(win_shape, [&](auto idx_w) { + auto idx = idx_o; + std::transform(idx_w.begin(), + idx_w.end(), + win_start.begin(), + idx.begin() + 2, + [](auto ii, auto jj) { return ii + jj; }); + if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and + idx < in_lens) { - input.lens()[0], - input.lens()[1], - std::size_t(std::max( - 1, - floor_divide(input.lens()[2] + 2 * padding[0] - lengths[0], - stride[0]) + - 1)), - std::size_t(std::max( - 1, - floor_divide(input.lens()[3] + 2 * padding[1] - lengths[1], - stride[1]) + - 1)), - }}; + output_val = op(output_val, input[in_s.index(idx)]); + } + }); + output[i] = Type(op.final(output_val, pool_size)); + }); + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[0])([&](auto output, auto input) { + using type = typename decltype(output)::value_type; + switch(mode) + { + case migraphx::op::pooling_mode::average: + calc_pooling(output_shape, output, input, avg_pool{}); + break; + case migraphx::op::pooling_mode::max: + calc_pooling(output_shape, output, input, max_pool{}); + break; + case migraphx::op::pooling_mode::lpnorm: + calc_pooling(output_shape, output, input, lpnorm_pool{lp_order}); + break; + } + }); + + return result; } }; diff --git a/src/include/migraphx/op/prefix_scan_op.hpp b/src/include/migraphx/op/prefix_scan_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..24a28293478d1ae9c85c11ac7b906e58375426dc --- /dev/null +++ b/src/include/migraphx/op/prefix_scan_op.hpp @@ -0,0 +1,116 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +template +struct prefix_scan_op : op_name +{ + int64_t axis; + bool exclusive = false; + bool reverse = false; + + template + static auto reflect(Self& self, F f) + { + return pack( + f(self.axis, "axis"), f(self.exclusive, "exclusive"), f(self.reverse, "reverse")); + } + + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(1); + auto s = inputs.front(); + if(s.broadcasted()) + { + return {s.type(), s.lens()}; + } + else + { + return s.with_lens(s.lens()); + } + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto s = args[0].get_shape(); + if(s == output_shape) + { + result = args[0].copy(); + } + else + { + visit_all(result, args[0])([&](auto output, auto input) { + par_for(output_shape.elements(), + [&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; }); + }); + s = output_shape; + } + auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; + auto lens = s.lens(); + lens[axis] = 1; + auto batch = shape{s.type(), lens, s.strides()}; + auto& self = static_cast(*this); + result.visit([&](auto output) { + using type = decltype(output); + par_for(batch.elements(), [&](auto i) { + auto* start = output.data() + batch.index(i); + type x{slice, start}; + if(reverse) + { + if(exclusive) + { + std::copy(++x.begin(), x.end(), x.begin()); + x.back() = 0; + } + std::partial_sum(std::make_reverse_iterator(x.end()), + std::make_reverse_iterator(x.begin()), + std::make_reverse_iterator(x.end()), + self.op()); + } + else + { + if(exclusive) + { + std::copy_backward(x.begin(), --x.end(), x.end()); + x.front() = 0; + } + std::partial_sum(x.begin(), x.end(), x.begin(), self.op()); + } + }); + }); + + return result; + } + + auto init() const {} + prefix_scan_op() : axis(0) {} + prefix_scan_op(int64_t ax) : axis(ax) {} + prefix_scan_op(int64_t ax, bool excl) : axis(ax), exclusive(excl) {} + prefix_scan_op(int64_t ax, bool excl, bool rev) : axis(ax), exclusive(excl), reverse(rev) {} +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/prefix_scan_sum.hpp b/src/include/migraphx/op/prefix_scan_sum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f922a8d0ce43221303bd5944b8e258d28972e44 --- /dev/null +++ b/src/include/migraphx/op/prefix_scan_sum.hpp @@ -0,0 +1,32 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct prefix_scan_sum : prefix_scan_op +{ + prefix_scan_sum() {} + prefix_scan_sum(int64_t ax) : prefix_scan_op(ax) {} + prefix_scan_sum(int64_t ax, bool excl) : prefix_scan_op(ax, excl) {} + prefix_scan_sum(int64_t ax, bool excl, bool rev) : prefix_scan_op(ax, excl, rev) {} + + auto op() const + { + return [](auto x, auto y) { return x + y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/prelu.hpp b/src/include/migraphx/op/prelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..173ee97f1cfe8b9daae1cd261b89dd8268782c02 --- /dev/null +++ b/src/include/migraphx/op/prelu.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_PRELU_HPP +#define MIGRAPHX_GUARD_OPERATORS_PRELU_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct prelu : binary +{ + std::string point_op() const { return "(${0} < 0) ? (${0} * ${1}) : ${0}"; } + auto apply() const + { + return [](auto x, auto slope) { return ((x < 0) ? (x * slope) : x); }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/quant_convolution.hpp b/src/include/migraphx/op/quant_convolution.hpp index 3857a1ac27c4328ee53159854c4dd013a7972d3d..744ce27e8fdbd679ccdc066975fc4b2613064f2c 100644 --- a/src/include/migraphx/op/quant_convolution.hpp +++ b/src/include/migraphx/op/quant_convolution.hpp @@ -3,12 +3,12 @@ #include #include -#include #include #include #include #include #include +#include #include #include #include @@ -19,9 +19,9 @@ namespace op { struct quant_convolution { - std::array padding = {{0, 0}}; - std::array stride = {{1, 1}}; - std::array dilation = {{1, 1}}; + std::vector padding = {0, 0}; + std::vector stride = {1, 1}; + std::vector dilation = {1, 1}; padding_mode_t padding_mode = default_; int group = 1; @@ -36,14 +36,35 @@ struct quant_convolution f(self.group, "group")); } + value attributes() const + { + return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}}; + } + std::string name() const { return "quant_convolution"; } - shape compute_shape(std::vector inputs) const + + void check_attribute_size() const + { + if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and + stride.size() == dilation.size())) + { + MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes"); + } + } + + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4); + check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); + check_attribute_size(); const shape& input = inputs.at(0); const shape& weights = inputs.at(1); auto t = input.type(); + size_t kdims = input.lens().size() - 2; + if(kdims != this->kdims()) + { + MIGRAPHX_THROW("quant_convolution: input k-dims does not match attribute size"); + } // all input type must be int8_type and output is float_type if(t != shape::int8_type) @@ -52,23 +73,28 @@ struct quant_convolution } t = shape::int32_type; - return {t, - { - input.lens()[0], - weights.lens()[0], - std::size_t(std::max( - 1, - (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + - 2 * padding[0]) / - stride[0] + - 1)), - std::size_t(std::max( - 1, - (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + - 2 * padding[1]) / - stride[1] + - 1)), - }}; + std::vector output_lens{input.lens()[0], weights.lens()[0]}; + auto padding_size = padding.size(); + for(size_t i = 0; i < kdims; i++) + { + auto padding_factor = 2 * padding[i]; + if(padding_size == 2 * kdims) + padding_factor = padding[i] + padding[i + kdims]; + output_lens.push_back(std::size_t(std::max( + 1, + (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + + padding_factor) / + stride[i] + + 1))); + } + + return inputs[0].with_lens(t, output_lens); + } + + size_t kdims() const + { + check_attribute_size(); + return stride.size(); } }; diff --git a/src/include/migraphx/op/quant_dot.hpp b/src/include/migraphx/op/quant_dot.hpp old mode 100644 new mode 100755 index 552f71e0383d4af11286120491f585b006beb23f..d7a956ab34397ede216954389a93924c4d511d6c --- a/src/include/migraphx/op/quant_dot.hpp +++ b/src/include/migraphx/op/quant_dot.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -18,19 +18,12 @@ namespace op { struct quant_dot { - int32_t alpha = 1; - int32_t beta = 1; - - template - static auto reflect(Self& self, F f) - { - return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta")); - } + value attributes() const { return {{"general_data_type", "dot"}}; } std::string name() const { return "quant_dot"; } shape compute_shape(std::vector inputs) const { - check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type(); + check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2); const shape& a = inputs.at(0); const shape& b = inputs.at(1); auto t = a.type(); @@ -60,27 +53,8 @@ struct quant_dot to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}"); } - // k be multiple of 4 - if((a.lens()[dim_1] % 4) != 0) - { - MIGRAPHX_THROW("QUANT_DOT: size of A {" + to_string_range(a.lens()) + "} and B {" + - to_string_range(b.lens()) + "} must be multiple of 4 for int8 type"); - } - auto out_lens = a.lens(); out_lens[dim_1] = b.lens()[dim_1]; - if(inputs.size() == 3 && out_lens != inputs.at(2).lens()) - { - MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" + - to_string_range(inputs.at(2).lens()) + - "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}"); - } - - if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type) - { - MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32"); - } - return {shape::int32_type, out_lens}; } }; diff --git a/src/include/migraphx/op/quantizelinear.hpp b/src/include/migraphx/op/quantizelinear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df1dc7aabe51afcf7e87254d5cac29a49f6be783 --- /dev/null +++ b/src/include/migraphx/op/quantizelinear.hpp @@ -0,0 +1,72 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP +#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct quantizelinear +{ + std::string name() const { return "quantizelinear"; } + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.same_dims(); + if(inputs.size() == 3) + { + return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()}; + } + return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + auto x = args.at(0); + auto y_scale = args.at(1); + std::vector zeros(output_shape.bytes(), 0); + argument y_zero_point{output_shape, zeros.data()}; + if(args.size() == 3) + { + y_zero_point = args.at(2); + } + + argument result{output_shape}; + visit_all(result, y_zero_point)([&](auto output, auto zero_pts) { + x.visit([&](auto input) { + y_scale.visit([&](auto scales) { + using quant_type = typename decltype(output)::value_type; + auto min_value = std::numeric_limits::min(); + auto max_value = std::numeric_limits::max(); + par_for(output_shape.elements(), [&](auto i) { + int64_t quantized = static_cast(std::round(input[i] / scales[i])) + + static_cast(zero_pts[i]); + output[i] = std::max(static_cast(min_value), + std::min(static_cast(max_value), quantized)); + }); + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/recip.hpp b/src/include/migraphx/op/recip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8c2c15b12b6c03ed277e20ebbf4b2e9085d2c7b --- /dev/null +++ b/src/include/migraphx/op/recip.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_RECIP_HPP +#define MIGRAPHX_GUARD_OPERATORS_RECIP_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct recip : unary +{ + std::string point_op() const { return "1 / ${0}"; } + auto apply() const + { + return [](auto x) { return 1 / x; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/reduce_mean.hpp b/src/include/migraphx/op/reduce_mean.hpp old mode 100644 new mode 100755 index 3aff37fd80a8138c16d11883e60527e33f0b3b5e..2b6de7b77ecb4c27c0f1e5d60f2a23df0c9eebfe --- a/src/include/migraphx/op/reduce_mean.hpp +++ b/src/include/migraphx/op/reduce_mean.hpp @@ -14,12 +14,12 @@ struct reduce_mean : reduce_op auto op() const { - return [=](auto x, auto y) { return x + y; }; + return [](auto x, auto y) { return x + y; }; } auto output(const shape& s) const { - return [&](auto val) { return val / s.elements(); }; + return [&](auto val) { return val / static_cast(s.elements()); }; } }; diff --git a/src/include/migraphx/op/reduce_op.hpp b/src/include/migraphx/op/reduce_op.hpp old mode 100644 new mode 100755 index 7e5a4675f38fc02518a36f86e795fc3802f0865b..dad6b0f8bf62b5f0311044d959e87fb0557fe352 --- a/src/include/migraphx/op/reduce_op.hpp +++ b/src/include/migraphx/op/reduce_op.hpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include namespace migraphx { @@ -40,6 +42,15 @@ struct zero } }; +struct one +{ + template + operator T() const + { + return T{1}; + } +}; + template struct reduce_op : op_name { @@ -51,6 +62,13 @@ struct reduce_op : op_name return pack(f(self.axes, "axes")); } + value attributes() const + { + value normalize; + normalize["axes"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::vector tune_axes(std::size_t n_dim) const { auto tuned_axes = axes; @@ -59,26 +77,11 @@ struct reduce_op : op_name tuned_axes.resize(n_dim); std::iota(tuned_axes.begin(), tuned_axes.end(), 0); } - else - { - for(auto& axis : tuned_axes) - { - int64_t s_dim = static_cast(n_dim); - if(axis >= s_dim or axis < -s_dim) - { - MIGRAPHX_THROW("REDUCE_OP: axis out of range"); - } - if(axis < 0) - { - axis += n_dim; - } - } - } return tuned_axes; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1); auto s = inputs.at(0); @@ -89,7 +92,7 @@ struct reduce_op : op_name lens[axis] = 1; } - return {s.type(), lens}; + return inputs[0].with_lens(lens); } template @@ -110,13 +113,14 @@ struct reduce_op : op_name std::vector& out_idx, tensor_view& output) const { - auto data_idx = out_idx; - T val = static_cast(*this).init(); + using accumulator = accumulator_type; + auto& self = static_cast(*this); + auto data_idx = out_idx; + accumulator val = self.init(); shape_for_each(batch_shape, [&](auto b_idx) { this->tune_dims(tuned_axes, b_idx, data_idx); - val = static_cast(*this).op()( - static_cast(*this).input()(input(data_idx.begin(), data_idx.end())), - val); + accumulator x = input(data_idx.begin(), data_idx.end()); + val = self.op()(accumulator{self.input()(x)}, val); }); output(out_idx.begin(), out_idx.end()) = @@ -145,12 +149,12 @@ struct reduce_op : op_name auto input() const { - return [&](auto val) { return val; }; + return [](auto val) { return val; }; } auto output(const shape&) const { - return [&](auto val) { return val; }; + return [](auto val) { return val; }; } reduce_op() {} diff --git a/src/include/migraphx/op/reduce_prod.hpp b/src/include/migraphx/op/reduce_prod.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9db82d6b6baaee4d2508274375db82f90dd52fb --- /dev/null +++ b/src/include/migraphx/op/reduce_prod.hpp @@ -0,0 +1,27 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP +#define MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct reduce_prod : reduce_op +{ + reduce_prod() {} + reduce_prod(std::vector ax) : reduce_op(std::move(ax)) {} + + auto op() const + { + return [=](auto x, auto y) { return x * y; }; + } + + auto init() const { return one(); } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/relu.hpp b/src/include/migraphx/op/relu.hpp old mode 100644 new mode 100755 index 0ea391fca4c99576f7b7386d583d16a019ee5f44..d32bf2f36f5996a6dabf1210c62d45b6df61721b --- a/src/include/migraphx/op/relu.hpp +++ b/src/include/migraphx/op/relu.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct relu : unary { + std::string point_op() const { return "${function:max}(decltype(${0}){0}, ${0})"; } auto apply() const { return [](auto x) { return std::max(decltype(x){0}, x); }; diff --git a/src/include/migraphx/op/reshape.hpp b/src/include/migraphx/op/reshape.hpp index 50edc331ee089af4ac172e827394a3ad33f66bec..0ccc83321606f577ad2774f60e778bf3e9cc91e6 100644 --- a/src/include/migraphx/op/reshape.hpp +++ b/src/include/migraphx/op/reshape.hpp @@ -2,13 +2,14 @@ #define MIGRAPHX_GUARD_OPERATORS_RESHAPE_HPP #include -#include #include #include #include #include #include #include +#include +#include #include #include @@ -26,6 +27,8 @@ struct reshape return pack(f(self.dims, "dims")); } + value attributes() const { return {{"require_std_shape", true}}; } + std::string name() const { return "reshape"; } shape compute_shape(std::vector inputs) const { @@ -34,7 +37,8 @@ struct reshape std::vector rdims(dims.begin(), dims.end()); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); if(n_neg_dims > 1) - MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim"); + MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); + for(std::size_t i = 0; i < dims.size(); i++) { if(dims[i] == 0) @@ -45,6 +49,7 @@ struct reshape if(dims[i] == -1) rdims[i] = 1; } + if(n_neg_dims > 0) { size_t missing_dim = @@ -59,15 +64,17 @@ struct reshape shape s{inputs.front().type(), rdims}; if(s.elements() != inputs.front().elements()) - MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " + + MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements())); return s; } + argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args[0].reshape(output_shape); } + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/reverse.hpp b/src/include/migraphx/op/reverse.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ac79d05501b89d187af48d52a2ffc1c868f32c08 --- /dev/null +++ b/src/include/migraphx/op/reverse.hpp @@ -0,0 +1,68 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP +#define MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct reverse +{ + + std::vector axes; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axes, "axes")); + } + + std::string name() const { return "reverse"; } + + value attributes() const + { + value normalize; + normalize["axes"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + return inputs[0].with_lens(inputs[0].lens()); + } + + argument compute(const shape& s, std::vector args) const + { + argument result{s}; + auto lens = s.lens(); + visit_all(result, args.front())([&](auto output, auto input) { + shape_for_each(s, [&](const auto& out_idx) { + auto in_idx = out_idx; + for(const auto& axis : axes) + { + in_idx[axis] = lens[axis] - 1 - out_idx[axis]; + } + output[s.index(out_idx)] = input[s.index(in_idx)]; + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/rnn.hpp b/src/include/migraphx/op/rnn.hpp index bc0a51fc67f3b1198af39cc40097354facc91c64..f3d96b2c2965f857201e92becf0c146688b0652a 100644 --- a/src/include/migraphx/op/rnn.hpp +++ b/src/include/migraphx/op/rnn.hpp @@ -3,8 +3,8 @@ #include #include -#include #include +#include #include #include #include diff --git a/src/include/migraphx/op/rnn_last_cell_output.hpp b/src/include/migraphx/op/rnn_last_cell_output.hpp index aceba4907297d43013c6c21902ef0165b848d1e6..6f608b7b64b8ebaa0e3e895dbdd43eca7b1c7fa6 100644 --- a/src/include/migraphx/op/rnn_last_cell_output.hpp +++ b/src/include/migraphx/op/rnn_last_cell_output.hpp @@ -1,27 +1,21 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP #define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP -#include -#include #include #include #include -#include -#include #include -#include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { -struct lstm_last_cell_output +struct rnn_last_cell_output { - std::string name() const { return "lstm_last_cell_output"; } + std::string name() const { return "rnn_last_cell_output"; } + shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1); auto dims = inputs[0].lens(); // remove the first dimension, remaing are output shape diff --git a/src/include/migraphx/op/rnn_last_output.hpp b/src/include/migraphx/op/rnn_last_hs_output.hpp similarity index 58% rename from src/include/migraphx/op/rnn_last_output.hpp rename to src/include/migraphx/op/rnn_last_hs_output.hpp index 2e37999342cbfdf7af063cba796cdd7a7f626b2b..c52eabd73070290938eb657e222969b9cce178e5 100644 --- a/src/include/migraphx/op/rnn_last_output.hpp +++ b/src/include/migraphx/op/rnn_last_hs_output.hpp @@ -1,27 +1,21 @@ -#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP -#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP +#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP +#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP -#include -#include #include #include #include -#include -#include #include -#include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { -struct rnn_last_output +struct rnn_last_hs_output { - std::string name() const { return "rnn_last_output"; } + std::string name() const { return "rnn_last_hs_output"; } + shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1); auto dims = inputs[0].lens(); // remove the first dimension, remaing are output shape diff --git a/src/include/migraphx/op/rnn_var_sl_last_output.hpp b/src/include/migraphx/op/rnn_var_sl_last_output.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b0964176e9cd8e446f1f281fb8a2ac77f387aa5 --- /dev/null +++ b/src/include/migraphx/op/rnn_var_sl_last_output.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP +#define MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct rnn_var_sl_last_output +{ + rnn_direction direction = rnn_direction::forward; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.direction, "direction")); + } + + std::string name() const { return "rnn_var_sl_last_output"; } + + shape compute_shape(std::vector inputs) const + { + auto dims = inputs[0].lens(); + + // remove the first dimension, remaing are output shape + dims.erase(dims.begin()); + return {inputs[0].type(), dims}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/rnn_variable_seq_lens.hpp b/src/include/migraphx/op/rnn_variable_seq_lens.hpp new file mode 100644 index 0000000000000000000000000000000000000000..87a4bb802159e2ba846b33b566244de5515d0cfd --- /dev/null +++ b/src/include/migraphx/op/rnn_variable_seq_lens.hpp @@ -0,0 +1,108 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP +#define MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct rnn_var_sl_shift_output +{ + std::string output_name = "hidden_states"; + rnn_direction direction = rnn_direction::forward; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.output_name, "output_name"), f(self.direction, "direction")); + } + + std::string name() const { return "rnn_var_sl_shift_output"; } + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + return inputs[0]; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + int64_t max_len = output_shape.lens()[0]; + visit_all(result, args[0])([&](auto output, auto input) { + using value_type = typename decltype(output)::value_type; + args[1].visit([&](auto seq_lens) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = output_shape.multi(i); + auto batch_id = idx[2]; + auto d = idx[1]; + auto t = idx[0]; + auto sl = seq_lens[batch_id]; + value_type val = value_type{0}; + if(t < sl) + { + auto in_idx = idx; + int offset = (direction == rnn_direction::reverse or d == 1) ? 1 : 0; + in_idx[0] += offset * (max_len - sl); + val = input(in_idx.begin(), in_idx.end()); + } + output(idx.begin(), idx.end()) = val; + }); + }); + }); + + return result; + } +}; + +struct rnn_var_sl_shift_sequence +{ + std::string name() const { return "rnn_var_sl_shift_sequence"; } + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(2); + return inputs[0]; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + int64_t max_len = output_shape.lens()[0]; + visit_all(result, args[0])([&](auto output, auto input) { + using value_type = typename decltype(output)::value_type; + args[1].visit([&](auto seq_lens) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = output_shape.multi(i); + auto b = idx[1]; + auto t = idx[0]; + auto sl = seq_lens[b]; + value_type val = value_type{0}; + if(t >= max_len - sl) + { + auto in_idx = idx; + in_idx[0] -= (max_len - sl); + val = input(in_idx.begin(), in_idx.end()); + } + output(idx.begin(), idx.end()) = val; + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/roialign.hpp b/src/include/migraphx/op/roialign.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8fc2cabf4680363a4e700e31b85c2892066abde --- /dev/null +++ b/src/include/migraphx/op/roialign.hpp @@ -0,0 +1,269 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_ROIALIGN_HPP +#define MIGRAPHX_GUARD_OPERATORS_ROIALIGN_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct roialign +{ + std::string coord_trans_mode = "half_pixel"; + pooling_mode mode = {pooling_mode::average}; + int64_t output_height = 1; + int64_t output_width = 1; + int64_t sampling_ratio = 0; + float spatial_scale = 1.0f; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.coord_trans_mode, "coordinate_transformation_mode"), + f(self.mode, "mode"), + f(self.output_height, "output_height"), + f(self.output_width, "output_width"), + f(self.sampling_ratio, "sampling_ratio"), + f(self.spatial_scale, "spatial_scale")); + } + + std::string name() const { return "roialign"; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(3); + auto x_lens = inputs.at(0).lens(); + auto roi_lens = inputs.at(1).lens(); + auto bi_lens = inputs.at(2).lens(); + auto type = inputs.at(0).type(); + + // check input correct + if(bi_lens.size() != 1) + { + MIGRAPHX_THROW("ROIALIGN: batch indices should be 1 dimension!"); + } + + if(roi_lens.size() != 2 or roi_lens.at(1) != 4) + { + MIGRAPHX_THROW( + "ROIALIGN: rois should be 2 dimensions, and the second dim should be 4!"); + } + + if(roi_lens.front() != bi_lens.front()) + { + MIGRAPHX_THROW("ROIALIGN: rois and batch indices inputs should have the same number!"); + } + + std::vector out_lens = x_lens; + out_lens[0] = roi_lens[0]; + out_lens[2] = output_height; + out_lens[3] = output_width; + + return {type, out_lens}; + } + + struct pos_weight + { + // neighbor indices for the bilinear interpolation + std::array pos = {0, 0, 0, 0}; + // neighbor weights for the bilinear interpolation + std::array w = {0.0f, 0.0f, 0.0f, 0.0f}; + }; + + auto calc_pos_weight(const std::array& dims, + const shape& comp_s, + const std::array& roi_start, + const std::array& bin_size, + const std::array& bin_grid_size) const + { + std::vector results(bin_grid_size[0] * bin_grid_size[1] * output_height * + output_width); + shape_for_each(comp_s, [&](auto idx) { + std::array p = {idx[0], idx[1]}; + std::array i = {idx[2], idx[3]}; + auto index = comp_s.index(idx); + + std::array xy{}; + std::array low{}; + std::array high{}; + for(auto ii : range(p.size())) + { + xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + + (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; + xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; + if(xy[ii] < -1.0 or xy[ii] > dims[ii]) + { + results[index] = pos_weight{}; + return; + } + + xy[ii] = std::max(xy[ii], 0.0f); + low[ii] = xy[ii]; + high[ii] = low[ii] + 1; + if(low[ii] >= dims[ii] - 1) + { + xy[ii] = high[ii] = low[ii] = dims[ii] - 1; + } + } + + results[index].pos = {low[0] * dims[1] + low[1], + low[0] * dims[1] + high[1], + high[0] * dims[1] + low[1], + high[0] * dims[1] + high[1]}; + + float ly = xy[0] - low[0]; + float lx = xy[1] - low[1]; + float hy = 1.0f - ly; + float hx = 1.0f - lx; + + // save weights and indeces + results[index].w = {hy * hx, hy * lx, ly * hx, ly * lx}; + }); + + return results; + } + + struct max_pool + { + double init() { return std::numeric_limits::lowest(); } + + double operator()(double x, double y) { return std::max(x, y); } + + double final(double x, std::size_t) { return (x); } + }; + + struct avg_pool + { + double init() { return 0.0; } + + double operator()(double x, double y) { return x + y; } + + double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); } + }; + + template + std::tuple calc_pooling(const T& data, + const std::array& bin_grid_size, + const std::vector& pos_weights, + int64_t index, + Op op) const + { + double output_val = op.init(); + const int64_t count = bin_grid_size[0] * bin_grid_size[1]; + dfor(bin_grid_size[0], bin_grid_size[1])([&](auto, auto) { + const auto& pc = pos_weights[index]; + std::array wv; + std::transform( + pc.w.begin(), pc.w.end(), pc.pos.begin(), wv.begin(), [&](auto w, auto pos) { + return *(data + pos) * w; + }); + output_val = std::accumulate(wv.begin(), wv.end(), output_val, op); + index += 1; + }); + + output_val = op.final(output_val, count); + + return {output_val, index}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + const auto& out_lens = output_shape.lens(); + int64_t n_rois = out_lens[0]; + std::size_t channels = out_lens[1]; + // output dims of height and width, in all 2-dim arrays, the first dim + // is for height and second dim is for width + std::array out_dims = {out_lens[2], out_lens[3]}; + const auto& x_lens = args.at(0).get_shape().lens(); + // input dims of height and width + std::array in_dims = {x_lens[2], x_lens[3]}; + auto roi_s = args.at(1).get_shape(); + + visit_all(result, args.at(0), args.at(1))([&](auto output, auto x, auto roi) { + const auto* batch_indices = args.at(2).cast(); + par_for(n_rois, [&](auto n) { + const auto bottom_data = x.begin(); + const auto roi_batch_ind = batch_indices[n]; + // Do not using rounding; this implementation detail is critical + std::array roi_starts = { + static_cast(roi[roi_s.index({n, 1})] * spatial_scale), + static_cast(roi[roi_s.index({n, 0})] * spatial_scale)}; + std::array roi_ends = { + static_cast(roi[roi_s.index({n, 3})] * spatial_scale), + static_cast(roi[roi_s.index({n, 2})] * spatial_scale)}; + + // Force malformed ROIs to be 1x1 + std::array roi_size{}; + std::array bin_size{}; + std::array bin_grid_size{}; + + for(auto ii : range(roi_size.size())) + { + roi_size[ii] = roi_ends[ii] - roi_starts[ii]; + roi_size[ii] = std::max(roi_size[ii], 1.0f); + + bin_size[ii] = roi_size[ii] / out_dims[ii]; + bin_grid_size[ii] = (sampling_ratio > 0) + ? sampling_ratio + : std::ceil(roi_size[ii] / out_dims[ii]); + } + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector comp_lens = { + out_dims[0], out_dims[1], bin_grid_size[0], bin_grid_size[1]}; + shape comp_s{shape::float_type, comp_lens}; + auto pre_calc = + this->calc_pos_weight(in_dims, comp_s, roi_starts, bin_size, bin_grid_size); + + std::vector comp_lens1 = {channels, out_dims[0], out_dims[1]}; + shape comp_s1{migraphx::shape::float_type, comp_lens1}; + std::vector vec_index(channels, 0); + shape_for_each(comp_s1, [&](auto idx) { + auto c = idx[0]; + auto ph = idx[1]; + auto pw = idx[2]; + + const auto offset_bottom_data = + bottom_data + static_cast((roi_batch_ind * channels + c) * + in_dims[0] * in_dims[1]); + double output_val; + std::tie(output_val, vec_index[c]) = + (mode == migraphx::op::pooling_mode::average) + ? this->calc_pooling(offset_bottom_data, + bin_grid_size, + pre_calc, + vec_index[c], + avg_pool{}) + : this->calc_pooling(offset_bottom_data, + bin_grid_size, + pre_calc, + vec_index[c], + max_pool{}); + output(n, c, ph, pw) = output_val; + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scalar.hpp b/src/include/migraphx/op/scalar.hpp old mode 100644 new mode 100755 index 3baed0fc5421397d07ee5a104b399f36d315627e..d01b82899fe52259f562b21fd220fe8eda0600b0 --- a/src/include/migraphx/op/scalar.hpp +++ b/src/include/migraphx/op/scalar.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_SCALAR_HPP #include -#include #include #include #include #include #include #include +#include #include #include @@ -30,7 +30,7 @@ struct scalar shape compute_shape(std::vector inputs) const { - assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); + check_shapes{inputs, *this}.has(1).only_dims(1).nelements(1); auto t = inputs.at(0).type(); std::vector strides(scalar_bcast_lens.size(), 0); return {t, scalar_bcast_lens, strides}; @@ -38,7 +38,7 @@ struct scalar argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.at(0).data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/scatter.hpp b/src/include/migraphx/op/scatter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03a7253711716526da0e50ff58796ec3b0891415 --- /dev/null +++ b/src/include/migraphx/op/scatter.hpp @@ -0,0 +1,97 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +// The scatter operator fetches a subset of data given by an index array and then performs a +// reduction operation (add, multiply, or just set the data) on each element returned. We implement +// it as a separate derived struct for each of the three reduction methods. The related operator +// scatterND is a generalization that works on a set of 3 tensors of different ranks. The +// complementary operations are gather/gatherND. +// +// This is a template for deriving child structs from. Each child needs to define +// only a reduction() method. Names are automatically handled by the op_name template. + +template +struct scatter : op_name +{ + int64_t axis = 0; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axis, "axis")); + } + + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(3).standard(); + // If non-packed, this converts to a packed output while preserving permutation of tensor + return inputs.front().with_lens(inputs.front().lens()); + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto& self = static_cast(*this); + + // max dimension in each axis + auto axis_dim_size = output_shape.lens()[axis]; + // cast all arguments as correct type + visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) { + // copy all of data to output + std::copy(data.begin(), data.end(), output.begin()); + args[1].visit([&](auto indices) { + auto ind_s = indices.get_shape(); + // iterate through items in shape + shape_for_each(ind_s, [&](const auto& idx) { + auto out_idx = idx; + + // Overloaded tensor_view::() invokes indexing logic of + // std::size_t shape::index(std::size_t i) const + // which handles nonstandard shapes correctly + auto index = indices(idx.begin(), idx.end()); + + // normalize negative indexes (may be redundant after using + // normalize_compute_shape()) + index = (index < 0) ? index + axis_dim_size : index; + out_idx[axis] = index; + + // look up the appropriate locations in output, using idx and out_idx. + // call reduction() method of derived struct to copy and reduce that element + self.reduction()(output(out_idx.begin(), out_idx.end()), + update(idx.begin(), idx.end())); + }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatter_add.hpp b/src/include/migraphx/op/scatter_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91d287045a502d5e8b823e06e1bb236bdaad800b --- /dev/null +++ b/src/include/migraphx/op/scatter_add.hpp @@ -0,0 +1,38 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Scatter op. with "add" function as reduction. +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatter_add : scatter +{ + // reduction (pointwise operation) is called by the parent struct's compute() method. + // It works much like a virtual function overload. + // For the scatter methods, there are three different reduction functions. + auto reduction() const + { + return [](auto& x, const auto& y) { x += y; }; + } + + // name of this struct is automatically assigned by the op_name<> +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatter_mul.hpp b/src/include/migraphx/op/scatter_mul.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a24b0685dfd54f759b815f438826bce04e1efa2f --- /dev/null +++ b/src/include/migraphx/op/scatter_mul.hpp @@ -0,0 +1,36 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Scatter op. with "multiply" as the reduction function. +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatter_mul : scatter +{ + // reduction (pointwise operation) is called by the parent struct's compute() method. + // It works much like a virtual function overload. + // For the scatter operators, there are three different reduction functions. + auto reduction() const + { + return [](auto& x, const auto& y) { x *= y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatter_none.hpp b/src/include/migraphx/op/scatter_none.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c948a245aa845a37c2d4c64197ee8bc7dfac251 --- /dev/null +++ b/src/include/migraphx/op/scatter_none.hpp @@ -0,0 +1,37 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Scatter op. with "none" as the reduction function (just copies the value). This is identical to +// the previously existing Scatter op. +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatter_none : scatter +{ + // reduction (pointwise operation) is called by the parent struct's compute() method. + // It works much like a virtual function overload. + // For the scatter operators, there are three different reduction functions. + auto reduction() const + { + return [](auto& x, const auto& y) { x = y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatternd_add.hpp b/src/include/migraphx/op/scatternd_add.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ca6ccb5841a784a9e774892ed5f038eab148239b --- /dev/null +++ b/src/include/migraphx/op/scatternd_add.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatternd_add : scatternd_op +{ + scatternd_add() {} + + auto reduction() const + { + return [](auto& x, const auto& y) { x += y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatternd_mul.hpp b/src/include/migraphx/op/scatternd_mul.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5fbf6477e454f89dbe844299ab23abf81d5cdde7 --- /dev/null +++ b/src/include/migraphx/op/scatternd_mul.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatternd_mul : scatternd_op +{ + scatternd_mul() {} + + auto reduction() const + { + return [](auto& x, const auto& y) { x *= y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatternd_none.hpp b/src/include/migraphx/op/scatternd_none.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd89dc3db9a715a997ab20e633a3624a525c2547 --- /dev/null +++ b/src/include/migraphx/op/scatternd_none.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct scatternd_none : scatternd_op +{ + scatternd_none() {} + + auto reduction() const + { + return [](auto& x, const auto& y) { x = y; }; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/scatternd_op.hpp b/src/include/migraphx/op/scatternd_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0207f21e7c20584e93d0fab38932e7cf1afe4077 --- /dev/null +++ b/src/include/migraphx/op/scatternd_op.hpp @@ -0,0 +1,84 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP +#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +template +struct scatternd_op : op_name +{ + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(3); + auto r = inputs.front().lens().size(); + auto q = inputs.at(1).lens().size(); + auto k = inputs.at(1).lens().back(); + auto ind_lens = inputs.at(1).lens(); + auto upd_lens = inputs.back().lens(); + auto data_lens = inputs.front().lens(); + if(k > r) + MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) + + " is too large for tensor of rank " + std::to_string(r)); + if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and + std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1))) + MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] " + "++ data.lens[k:r-1]"); + auto s = inputs.front(); + if(s.broadcasted()) + { + return {s.type(), s.lens()}; + } + else + { + return s.with_lens(s.lens()); + } + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto& self = static_cast(*this); + visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) { + std::copy(data.begin(), data.end(), output.begin()); + args[1].visit([&](auto indices) { + auto updates_shape = updates.get_shape(); + auto updates_std = shape{updates_shape.type(), updates_shape.lens()}; + auto indices_shape = indices.get_shape(); + auto k = indices_shape.lens().back(); + auto q = indices_shape.lens().size(); + auto r = output_shape.lens().size(); + par_for(updates_shape.elements(), [&](const auto i) { + auto updates_idx = updates_std.multi(i); + std::vector indices_idx(q, 0); + std::copy( + updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin()); + auto index_start = indices.begin() + + indices_shape.index(indices_idx.begin(), indices_idx.end()); + auto index_end = index_start + k; + + std::vector out_idx(r, 0); + std::copy(index_start, index_end, out_idx.begin()); + std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); + + self.reduction()(output[output_shape.index(out_idx)], updates[i]); + }); + }); + }); + + return result; + } + + auto init() const {} + scatternd_op() {} +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/sigmoid.hpp b/src/include/migraphx/op/sigmoid.hpp index 29ab43a5aec83d7f38691e60245a73fa0d9de725..e9bad491deb614f75428d930357a99eeecf0a897 100644 --- a/src/include/migraphx/op/sigmoid.hpp +++ b/src/include/migraphx/op/sigmoid.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct sigmoid : unary { + std::string point_op() const { return "1.f / (1.f + ${function:exp}(-${0}))"; } auto apply() const { return [](auto x) { return 1.f / (1.f + std::exp(-x)); }; diff --git a/src/include/migraphx/op/sign.hpp b/src/include/migraphx/op/sign.hpp index b72c2e664efd53be620e4cc41726772b0786de8a..9b959a086653eaf4d41e1669c69b962314b3aef1 100644 --- a/src/include/migraphx/op/sign.hpp +++ b/src/include/migraphx/op/sign.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct sign : unary { + std::string point_op() const { return "(${0} > 0 ? 1 : ((${0} < 0) ? -1 : 0))"; } auto apply() const { return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); }; diff --git a/src/include/migraphx/op/sin.hpp b/src/include/migraphx/op/sin.hpp index f9640d95492fc2393abcfd78996fca590a541838..d9309355869699616506a88e86620d942fd272d0 100644 --- a/src/include/migraphx/op/sin.hpp +++ b/src/include/migraphx/op/sin.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/sinh.hpp b/src/include/migraphx/op/sinh.hpp index d6f340307c5303a2a527687f983aa71cba48612c..67ed31e224b7c0cd56277a58f179c34debcade4b 100644 --- a/src/include/migraphx/op/sinh.hpp +++ b/src/include/migraphx/op/sinh.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index 9ad5c338db0923e08301b3e36a8a48ad0e8fa12d..0cf44b246e54002387bd910c6a9123c72ade0c28 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -1,16 +1,15 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP #define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP -#include -#include #include #include #include -#include -#include #include +#include +#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -28,6 +27,23 @@ struct slice return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); } + value attributes() const + { + value normalize = value::object{}; + normalize["axes"] = value::array{normalize_attribute::include_min}; + normalize["starts"] = value::array{normalize_attribute::clip_max, + normalize_attribute::clip_min, + normalize_attribute::include_max, + normalize_attribute::use_len, + normalize_attribute::include_min}; + normalize["ends"] = value::array{normalize_attribute::clip_max, + normalize_attribute::clip_min, + normalize_attribute::include_max, + normalize_attribute::use_len, + normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "slice"; } auto fix_index(const std::vector& lens, std::size_t axis, int64_t index) const @@ -61,16 +77,24 @@ struct slice return offset; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { auto input_shape = inputs[0]; auto t = input_shape.type(); const auto& old_lens = input_shape.lens(); const auto& old_strides = input_shape.strides(); + + if(std::any_of( + axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); })) + { + MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range"); + } + if(starts.size() != axes.size() || axes.size() != ends.size()) { - MIGRAPHX_THROW("inconsistent sizes"); + MIGRAPHX_THROW("SLICE: inconsistent sizes"); } + std::vector new_lens = old_lens; for(std::size_t i = 0; i < axes.size(); i++) { @@ -80,6 +104,7 @@ struct slice } return shape{t, new_lens, old_strides}; } + argument compute(shape output_shape, std::vector args) const { auto input = args[0]; diff --git a/src/include/migraphx/op/softmax.hpp b/src/include/migraphx/op/softmax.hpp index 83875ffa57e1b28dafce398c295e7575a653bb37..91e904a2387d07f478fa96250008b3bc69bf43f3 100644 --- a/src/include/migraphx/op/softmax.hpp +++ b/src/include/migraphx/op/softmax.hpp @@ -1,8 +1,9 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP -#include #include +#include +#include #include namespace migraphx { @@ -11,7 +12,7 @@ namespace op { struct softmax { - int axis = 1; + int64_t axis = 1; template static auto reflect(Self& self, F f) @@ -19,16 +20,26 @@ struct softmax return pack(f(self.axis, "axis")); } + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "softmax"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1).standard(); - if(axis < 0 || axis >= inputs[0].lens().size()) + check_shapes{inputs, *this}.has(1); + if(inputs.at(0).packed()) + { + return inputs.at(0); + } + else { - MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) + - " is out of range"); + auto lens = inputs.at(0).lens(); + return {inputs.at(0).type(), lens}; } - return inputs.at(0); } auto output() const diff --git a/src/include/migraphx/op/sqdiff.hpp b/src/include/migraphx/op/sqdiff.hpp old mode 100644 new mode 100755 index 1dd74fe39342f18bafa2337fdc037b4cf6ce4e74..0821d6bbccd935d4a996eb8208f4b4503699abf2 --- a/src/include/migraphx/op/sqdiff.hpp +++ b/src/include/migraphx/op/sqdiff.hpp @@ -9,6 +9,7 @@ namespace op { struct sqdiff : binary { + std::string point_op() const { return "(${0} - ${1}) * (${0} - ${1})"; } auto apply() const { return [](auto x, auto y) { return (x - y) * (x - y); }; diff --git a/src/include/migraphx/op/squeeze.hpp b/src/include/migraphx/op/squeeze.hpp index 2808a5ceaf5848f31a88ef7647aeb4e85018c26d..57ebf8e550e4d656083e676db2eb52aac0a8f9ee 100644 --- a/src/include/migraphx/op/squeeze.hpp +++ b/src/include/migraphx/op/squeeze.hpp @@ -2,13 +2,14 @@ #define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP #include -#include #include #include #include -#include #include #include +#include +#include +#include #include #include @@ -26,49 +27,62 @@ struct squeeze return pack(f(self.axes, "axes")); } + value attributes() const + { + value normalize; + normalize["axes"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "squeeze"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1).standard(); + check_shapes{inputs, *this}.has(1); auto input_shape = inputs[0]; auto type = input_shape.type(); auto old_lens = input_shape.lens(); - if(std::any_of( - axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; })) + auto old_strides = input_shape.strides(); + if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) { MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); } std::vector new_lens; + std::vector new_strides; if(axes.empty()) { - std::copy_if(old_lens.begin(), - old_lens.end(), - std::back_inserter(new_lens), - [](auto len) { return len != 1; }); + for(auto i : range(old_lens.size())) + { + if(old_lens[i] != 1) + { + new_lens.push_back(old_lens[i]); + new_strides.push_back(old_strides[i]); + } + } } else { - for(std::size_t i = 0; i < old_lens.size(); i++) + for(auto i : range(old_lens.size())) { if(std::find(axes.begin(), axes.end(), i) == axes.end()) { new_lens.push_back(old_lens[i]); + new_strides.push_back(old_strides[i]); } } } - if(new_lens.empty()) { return shape{type}; } else { - return shape{type, new_lens}; + return shape{type, new_lens, new_strides}; } } + argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/step.hpp b/src/include/migraphx/op/step.hpp new file mode 100755 index 0000000000000000000000000000000000000000..9c20ce6ca58fbf0e74d80767dc421c73cd17870c --- /dev/null +++ b/src/include/migraphx/op/step.hpp @@ -0,0 +1,82 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP +#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP + +#include "migraphx/stringutils.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct step +{ + std::vector axes; + std::vector steps; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axes, "axes"), f(self.steps, "steps")); + } + + value attributes() const + { + value normalize; + normalize["axes"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + std::string name() const { return "step"; } + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(1); + auto input = inputs.at(0); + auto in_lens = input.lens(); + auto t = input.type(); + + if(axes.size() != steps.size()) + { + MIGRAPHX_THROW("STEP: attribute axes {" + to_string_range(axes) + + "} has different dimensions from step {" + to_string_range(steps) + + "}."); + } + + if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); })) + { + MIGRAPHX_THROW("STEP: axis value is out of range!"); + } + + auto lens = in_lens; + auto strides = input.strides(); + for(auto i : range(axes.size())) + { + auto axis = axes[i]; + auto step = steps[i]; + lens[axis] = (in_lens[axis] + step - 1) / step; + strides[axis] *= step; + } + + return {t, lens, strides}; + } + + argument compute(shape output_shape, std::vector args) const + { + return args[0].reshape(output_shape); + } + + std::ptrdiff_t output_alias(const std::vector&) const { return 0; } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/sub.hpp b/src/include/migraphx/op/sub.hpp old mode 100644 new mode 100755 index 0b3557ffef9d0bb894e1888e47ea6c8fbceab367..9c1742678a894d5f195e034e9d6ee72902e076c0 --- a/src/include/migraphx/op/sub.hpp +++ b/src/include/migraphx/op/sub.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -19,6 +18,7 @@ namespace op { struct sub : binary { + std::string point_function() const { return "-"; } auto apply() const { return [](auto x, auto y) { return x - y; }; diff --git a/src/include/migraphx/op/tan.hpp b/src/include/migraphx/op/tan.hpp index 5f78cc7184800632db6f11607af388a45aba0c79..94db2f41e219b7bd05f3a286313849fec0e6613f 100644 --- a/src/include/migraphx/op/tan.hpp +++ b/src/include/migraphx/op/tan.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/tanh.hpp b/src/include/migraphx/op/tanh.hpp index a9753e3cb388cb3413f09c4c41dca4f9fab4931f..fa3f69020072ac0421e2ef25154b2b7d4ccb9aa6 100644 --- a/src/include/migraphx/op/tanh.hpp +++ b/src/include/migraphx/op/tanh.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include diff --git a/src/include/migraphx/op/topk.hpp b/src/include/migraphx/op/topk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af9b4569e23f6fc8f12a4fc097d8d922180abb77 --- /dev/null +++ b/src/include/migraphx/op/topk.hpp @@ -0,0 +1,142 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_GATHER_HPP +#define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct topk +{ + int64_t k = 1; + int64_t axis = 0; + bool largest = true; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest")); + } + + value attributes() const + { + value normalize; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + std::string name() const { return "topk"; } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(1).standard(); + auto lens = inputs.at(0).lens(); + auto type = inputs.at(0).type(); + + lens[axis] = k; + + shape s_val{type, lens}; + shape s_ind{shape::int64_type, lens}; + + return {{s_val, s_ind}}; + } + + template + struct heap_vector + { + std::vector data; + Compare compare; + + heap_vector(const std::vector& val, Compare comp) : data(val), compare(std::move(comp)) + { + std::make_heap(data.begin(), data.end(), compare); + } + + void try_push(T val) + { + if(not compare(val, data.front())) + return; + + std::pop_heap(data.begin(), data.end(), compare); + data.back() = val; + std::push_heap(data.begin(), data.end(), compare); + } + + std::vector sort() + { + auto sorted_data = data; + std::sort_heap(sorted_data.begin(), sorted_data.end(), compare); + return sorted_data; + } + }; + + template + heap_vector make_heap(std::vector val, Compare compare) const + { + return {std::move(val), std::move(compare)}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + auto vec_ss = output_shape.sub_shapes(); + argument res_val{vec_ss.front()}; + argument res_ind{vec_ss.back()}; + auto in_s = args.front().get_shape(); + auto out_s = vec_ss.front(); + auto comp_lens = in_s.lens(); + auto axis_dim = comp_lens[axis]; + + // compute shape + comp_lens[axis] = 1; + shape comp_s{in_s.type(), comp_lens}; + visit_all(res_val, args.front())([&](auto out_val, auto input) { + auto* out_ind = res_ind.cast(); + par_for(comp_s.elements(), [&](auto i) { + auto idx = comp_s.multi(i); + std::vector indices(k); + std::iota(indices.begin(), indices.end(), 0); + + auto comp = [&](auto i1, auto i2) { + auto idx1 = idx; + auto idx2 = idx; + idx1[axis] = i1; + idx2[axis] = i2; + return this->largest + ? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]) + : std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]); + }; + + auto hp = this->make_heap(indices, comp); + for(std::size_t ii = indices.size(); ii < axis_dim; ++ii) + { + hp.try_push(ii); + } + auto sorted_indices = hp.sort(); + auto out_idx = idx; + auto in_idx = idx; + for(auto j : range(sorted_indices.size())) + { + out_idx[axis] = j; + in_idx[axis] = sorted_indices[j]; + out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)]; + out_ind[out_s.index(out_idx)] = sorted_indices[j]; + } + }); + }); + + return {{res_val, res_ind}}; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/transpose.hpp b/src/include/migraphx/op/transpose.hpp old mode 100644 new mode 100755 index c811eeb080240e0b3ba188cf8329dfa7f21cb3c3..aed48c4d7fb5a12f5471bc6371f0490ff43eca9e --- a/src/include/migraphx/op/transpose.hpp +++ b/src/include/migraphx/op/transpose.hpp @@ -2,13 +2,11 @@ #define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP #include -#include #include -#include -#include -#include -#include +#include +#include #include +#include #include #include @@ -23,7 +21,7 @@ struct transpose template static auto reflect(Self& self, F f) { - return pack(f(self.dims, "dims")); + return pack(f(self.dims, "permutation")); } std::string name() const { return "transpose"; } @@ -34,6 +32,7 @@ struct transpose auto input_lens = input.lens(); auto input_strides = input.strides(); auto t = input.type(); + if(dims.size() != input_lens.size()) { MIGRAPHX_THROW("Permutation has wrong number of axes"); @@ -42,7 +41,7 @@ struct transpose std::iota(axes.begin(), axes.end(), 0); if(!std::is_permutation(axes.begin(), axes.end(), dims.begin())) { - MIGRAPHX_THROW("Invalid permutation"); + MIGRAPHX_THROW("TRANSPOSE: Invalid permutation"); } std::vector output_lens(input_lens.size()); std::vector output_strides(input_lens.size()); @@ -55,7 +54,7 @@ struct transpose } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/unary.hpp b/src/include/migraphx/op/unary.hpp index 361e945eece4e2f8bd4cda40ead40415dbb3f48c..27f727946189020ee407c5c48aa8b6fa6c0a2673 100644 --- a/src/include/migraphx/op/unary.hpp +++ b/src/include/migraphx/op/unary.hpp @@ -2,6 +2,11 @@ #define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP #include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -10,52 +15,57 @@ namespace op { template struct unary : op_name { + std::string point_function() const { return this->name(); } + std::string point_op() const + { + const auto& self = static_cast(*this); + auto pf = self.point_function(); + if(pf.empty()) + return {}; + if(with_char(::ispunct)(pf.front())) + { + return pf + "${0}"; + } + else + { + return "${function:" + pf + "}(${0})"; + } + } + value base_attributes() const + { + const auto& self = static_cast(*this); + return {{"pointwise", true}, {"point_op", self.point_op()}}; + } + value attributes() const { return base_attributes(); } shape compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1); + check_shapes{inputs, static_cast(*this)}.has(1); auto s = inputs.at(0); - if(s.packed()) + if(s.scalar()) { return s; } - else + else if(s.broadcasted()) { return {s.type(), s.lens()}; } + else + { + return s.with_lens(s.lens()); + } } argument compute(const shape& output_shape, std::vector args) const { argument result{output_shape}; - auto in_shape = args[0].get_shape(); - if(in_shape.packed()) - { - shape std_in_shape{in_shape.type(), in_shape.lens()}; - shape std_out_shape{output_shape.type(), output_shape.lens()}; - argument arg_in{std_in_shape, args[0].data()}; - argument arg_out{std_out_shape, result.data()}; - arg_out.visit([&](auto output) { - arg_in.visit([&](auto input) { - std::transform(input.begin(), - input.end(), - output.begin(), - static_cast(*this).apply()); - - }); - }); - } - else - { - result.visit([&](auto output) { - args[0].visit([&](auto input) { - shape_for_each(output.get_shape(), [&](const auto& idx) { - output(idx.begin(), idx.end()) = static_cast(*this).apply()( - input(idx.begin(), idx.end())); - }); - }); + result.visit([&](auto output) { + args[0].visit([&](auto input) { + std::transform(input.begin(), + input.end(), + output.begin(), + static_cast(*this).apply()); }); - } - + }); return result; } }; diff --git a/src/include/migraphx/op/unary_not.hpp b/src/include/migraphx/op/unary_not.hpp new file mode 100755 index 0000000000000000000000000000000000000000..5ea3b6e5303e0d873b06094a99fa66b9ba4f74a3 --- /dev/null +++ b/src/include/migraphx/op/unary_not.hpp @@ -0,0 +1,28 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP +#define MIGRAPHX_GUARD_OPERATORS_UNARY_NOT_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct unary_not : unary +{ + std::string point_function() const { return "!"; } + auto apply() const + { + return [](auto x) { return not x; }; + } + + std::string name() const { return "not"; } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/undefined.hpp b/src/include/migraphx/op/undefined.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0d3fc084ef3db880611b0bb44ea4117956e2a143 --- /dev/null +++ b/src/include/migraphx/op/undefined.hpp @@ -0,0 +1,28 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP +#define MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct undefined +{ + std::string name() const { return "undefined"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(0); + return {}; + } + + argument compute(const shape&, const std::vector&) const { return {{}, nullptr}; } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/unknown.hpp b/src/include/migraphx/op/unknown.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6791e8dce7a237757a0900c04181f7d6d6dde1e4 --- /dev/null +++ b/src/include/migraphx/op/unknown.hpp @@ -0,0 +1,40 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP +#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct unknown +{ + std::string op; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op")); + } + std::string name() const { return "unknown:" + op; } + shape compute_shape(std::vector input) const + { + if(input.empty()) + return {}; + else + return input.front(); + } + + friend std::ostream& operator<<(std::ostream& os, const unknown& x) + { + os << x.name(); + return os; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/unsqueeze.hpp b/src/include/migraphx/op/unsqueeze.hpp index b7218a05d37771aacdc071dfe2651a4c4c011655..c881f19df0687296d8f08f938f63b0e03252462b 100644 --- a/src/include/migraphx/op/unsqueeze.hpp +++ b/src/include/migraphx/op/unsqueeze.hpp @@ -2,13 +2,13 @@ #define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP #include -#include #include #include #include -#include #include #include +#include +#include #include #include @@ -26,36 +26,60 @@ struct unsqueeze return pack(f(self.axes, "axes")); } + value attributes() const + { + value normalize; + normalize["axes"] = + value::array{normalize_attribute::include_min, normalize_attribute::use_output}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "unsqueeze"; } - shape compute_shape(std::vector inputs) const + shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1).standard_or_scalar(); + check_shapes{inputs, *this}.has(1); auto input_shape = inputs[0]; auto type = input_shape.type(); auto old_lens = input_shape.lens(); - + auto old_strides = input_shape.strides(); if(input_shape.scalar()) - return shape{type, old_lens}; + { + if(old_lens.size() == 1 and old_lens.front() == 1) + return shape{type, old_lens}; + else + MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); + } std::size_t new_size = old_lens.size() + axes.size(); + std::vector new_lens(new_size); + std::vector new_strides(new_size); std::size_t p = 0; - for(std::size_t i = 0; i < new_size; i++) + for(auto i : range(new_size)) { if(std::find(axes.begin(), axes.end(), i) != axes.end()) { new_lens[i] = 1; + if(p == 0) // unsqueeze on the first axes + { + new_strides[i] = old_lens[0] * old_strides[0]; + } + else // unsqueeze on middle or last axes + { + new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; + } } else { - new_lens[i] = old_lens[p++]; + new_lens[i] = old_lens[p]; + new_strides[i] = old_strides[p++]; } } - return shape{type, new_lens}; + return shape{type, new_lens, new_strides}; } argument compute(shape output_shape, std::vector args) const { - return {std::move(output_shape), std::move(args.front().data)}; + return args[0].reshape(output_shape); } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/op/where.hpp b/src/include/migraphx/op/where.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff072fd5a9404af4f39df913f6ad63730dbad555 --- /dev/null +++ b/src/include/migraphx/op/where.hpp @@ -0,0 +1,68 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP +#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct where +{ + std::string name() const { return "where"; } + + value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.has(3).same_dims(); + auto s1 = inputs.at(1); + auto s2 = inputs.at(2); + if(s1 == s2 and s1.packed()) + { + return s1; + } + else if(s1.packed() != s2.packed()) + { + return s1.packed() ? s1 : s2; + } + else if(s1.broadcasted() != s2.broadcasted()) + { + return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens()); + } + else + { + return {s1.type(), s1.lens()}; + } + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) { + args[0].visit([&](const auto condition) { + par_for(output_shape.elements(), + [&](auto i) { output[i] = condition[i] ? x[i] : y[i]; }); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operation.hpp b/src/include/migraphx/operation.hpp index e0fef369e949e1467e0a1b07edb84a8ac1ee7648..24d8bbe7e9d9eb28337d1e1f7ea09efb20498043 100644 --- a/src/include/migraphx/operation.hpp +++ b/src/include/migraphx/operation.hpp @@ -7,10 +7,15 @@ #include #include #include +#include #include #include +#include #include +#include +#include #include +#include #include namespace migraphx { @@ -57,6 +62,8 @@ struct operation /// Returns true if operation does not require a context to run compute bool is_context_free(const operation& x); +/// Returns true if operation needs normalization before running compute +bool need_normalization(const operation& x); /// Returns true if the operation has a finalize method bool has_finalize(const operation& x); @@ -96,7 +103,73 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) } // namespace operation_operators template -auto compute_op(rank<2>, +auto compute_shape_op(rank<3>, const T& x, const std::vector& inputs) + -> decltype(x.compute_shape(inputs)) +{ + return x.compute_shape(inputs); +} + +template +auto compute_shape_op(rank<2>, const T& x, const std::vector& inputs) + -> decltype(x.normalize_compute_shape(inputs)) +{ + dependent_type y = x; + normalize_attributes(y, inputs[0].lens()); + return any_cast(y).normalize_compute_shape(inputs); +} + +template +auto compute_shape_op(rank<1>, const T& x, const std::vector& inputs) + -> decltype(x.compute_shape(inputs, {})) +{ + return x.compute_shape(inputs, {}); +} + +template +shape compute_shape_op(rank<0>, const T& x, const std::vector&) +{ + std::string name = x.name(); + MIGRAPHX_THROW("Shape not computable: " + name); +} + +template +shape compute_shape_op(const T& x, const std::vector& inputs) +{ + return compute_shape_op(rank<3>{}, x, inputs); +} + +template +auto mod_compute_shape_op(rank<1>, + const T& x, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(x.compute_shape(inputs, mod_args)) +{ + return x.compute_shape(inputs, mod_args); +} + +template +shape mod_compute_shape_op(rank<0>, + const T& x, + const std::vector& inputs, + const std::vector& mod_args) +{ + if(mod_args.empty()) + return compute_shape_op(x, inputs); + std::string name = x.name(); + MIGRAPHX_THROW("Shape not computable: " + name); +} + +template +shape mod_compute_shape_op(const T& x, + const std::vector& inputs, + const std::vector& mod_args) +{ + return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args); +} + +template +auto compute_op(rank<1>, const T& x, context& ctx, const shape& output_shape, @@ -106,14 +179,6 @@ auto compute_op(rank<2>, return x.compute(auto_any_cast(ctx), output_shape, input); } -template -auto compute_op( - rank<1>, const T& x, context&, const shape& output_shape, const std::vector& input) - -> decltype(x.compute(output_shape, input)) -{ - return x.compute(output_shape, input); -} - template argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector&) { @@ -125,35 +190,132 @@ template argument compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector& input) { - return compute_op(rank<2>{}, x, ctx, output_shape, input); + return compute_op(rank<1>{}, x, ctx, output_shape, input); } template -auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector& input) +auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector& input) -> decltype(x.compute(output_shape, input)) { return x.compute(output_shape, input); } template -auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector& input) - -> decltype(x.compute(auto_any_cast(std::declval()), output_shape, input)) +argument compute_op(rank<0>, const T& x, const shape&, const std::vector&) { std::string name = x.name(); - MIGRAPHX_THROW("Not computable without a context: " + name); + MIGRAPHX_THROW("Not computable: " + name); } template -argument compute_op(rank<0>, const T& x, const shape&, const std::vector&) +argument compute_op(const T& x, const shape& output_shape, const std::vector& input) +{ + return compute_op(rank<1>{}, x, output_shape, input); +} + +template +auto compute_op(rank<1>, + const T& x, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(output, inputs, module_args, f)) +{ + return x.compute(output, inputs, module_args, f); +} + +template +argument compute_op(rank<0>, + const T& x, + const shape&, + const std::vector&, + const std::vector&, + F) { std::string name = x.name(); MIGRAPHX_THROW("Not computable: " + name); } -template -argument compute_op(const T& x, const shape& output_shape, const std::vector& input) +template +argument compute_op(const T& x, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) +{ + return compute_op(rank<1>{}, x, output, inputs, module_args, f); +} + +template +auto compute_op(rank<4>, + const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f)) +{ + return x.compute(auto_any_cast(ctx), output, inputs, module_args, f); +} + +template +auto compute_op(rank<3>, + const T& x, + context&, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(output, inputs, module_args, f)) { - return compute_op(rank<2>{}, x, output_shape, input); + return x.compute(output, inputs, module_args, f); +} + +template +auto compute_op(rank<2>, + const T& x, + context&, + const shape& output, + const std::vector& inputs, + const std::vector&, + F) -> decltype(x.compute(output, inputs)) +{ + return x.compute(output, inputs); +} + +template +auto compute_op(rank<1>, + const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector&, + F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) +{ + return x.compute(auto_any_cast(ctx), output, inputs); +} + +template +argument compute_op(rank<0>, + const T& x, + context&, + const shape&, + const std::vector&, + const std::vector&, + F) +{ + std::string name = x.name(); + MIGRAPHX_THROW("Not computable: " + name); +} + +template +argument compute_op(const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) +{ + return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f); } template @@ -174,6 +336,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( return {}; } +template +auto need_normalization_op(rank<1>, const T& x, const std::vector& inputs) + -> decltype(x.normalize_compute_shape(inputs), std::true_type{}); + +template +auto need_normalization_op(rank<0>, const T&, const std::vector&) -> std::false_type; + +template +auto need_normalization_op(const T& x) + -> decltype(need_normalization_op(rank<1>{}, x, std::declval>())) +{ + return {}; +} + template std::ptrdiff_t output_alias_op(const T&, const std::vector&) { @@ -218,26 +394,113 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, return {}; } +template +auto compile_op( + rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector& input) + -> decltype(x.compile(auto_any_cast(ctx), output_shape, input)) +{ + return x.compile(auto_any_cast(ctx), output_shape, input); +} + +template +value compile_op(rank<0>, T&, context&, const shape&, const std::vector&) +{ + return value::object{}; +} + +template +value compile_op(const T& x, + context& ctx, + const shape& output_shape, + const std::vector& input) +{ + return compile_op(rank<1>{}, x, ctx, output_shape, input); +} + +template +value attributes_op(const T&) +{ + return value::object{}; +} + +template +value to_value_op(const T& x) +{ + return migraphx::to_value(x); +} + +template +void from_value_op(T& x, const value& v) +{ + if(not(v.is_object() or (v.empty() and v.is_array()))) + MIGRAPHX_THROW("Value is not an object"); + return migraphx::from_value(v, x); +} + +template +lifetime get_lifetime_op(const T&) +{ + return lifetime::local; +} + } // namespace detail -/* - * Type-erased interface for: - * - * struct operation - * { - * std::string name() const; - * bool is_context_free() const; - * bool has_finalize() const; - * std::ptrdiff_t output_alias(const std::vector& input) const; - * void finalize(context& ctx,const shape& output,const std::vector& input) ; - * shape compute_shape(const std::vector& input) const; - * argument compute(context& ctx,const shape& output,const std::vector& input) const; - * argument compute(const shape& output,const std::vector& input) const; - * friend std::ostream & operator<<(std::ostream & os,const operation & op) ; - * friend bool operator==(const operation & x,const operation & y) ; - * }; - * - */ +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct operation +{ + // + std::string name() const; + // (optional) + bool is_context_free() const; + // (optional) + bool need_normalization() const; + // (optional) + bool has_finalize() const; + // (optional) + lifetime get_lifetime() const; + // (optional) + std::ptrdiff_t output_alias(const std::vector& input) const; + // (optional) + value compile(context& ctx, const shape& output, const std::vector& input); + // (optional) + void finalize(context& ctx, const shape& output, const std::vector& input); + // (optional) + shape compute_shape(const std::vector& input) const; + // (optional) + shape compute_shape(const std::vector& inputs, + const std::vector& mod_args) const; + // (optional) + argument compute(context& ctx, const shape& output, const std::vector& input) const; + // (optional) + argument compute(const shape& output, const std::vector& input) const; + // (optional) + argument compute(const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const; + // (optional) + argument compute(context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const; + // (optional) + value to_value() const; + // (optional) + void from_value(const value& v); + // (optional) + value attributes() const; + // + friend std::ostream& operator<<(std::ostream& os, const operation& op); + // + friend bool operator==(const operation& x, const operation& y); +}; + +#else struct operation { @@ -257,11 +520,17 @@ struct operation template operation& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + operation rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -269,7 +538,7 @@ struct operation template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -280,7 +549,7 @@ struct operation template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -308,18 +577,36 @@ struct operation return (*this).private_detail_te_get_handle().is_context_free(); } + bool need_normalization() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().need_normalization(); + } + bool has_finalize() const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().has_finalize(); } + lifetime get_lifetime() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_lifetime(); + } + std::ptrdiff_t output_alias(const std::vector& input) const { assert((*this).private_detail_te_handle_mem_var); return (*this).private_detail_te_get_handle().output_alias(input); } + value compile(context& ctx, const shape& output, const std::vector& input) + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().compile(ctx, output, input); + } + void finalize(context& ctx, const shape& output, const std::vector& input) { assert((*this).private_detail_te_handle_mem_var); @@ -332,6 +619,13 @@ struct operation return (*this).private_detail_te_get_handle().compute_shape(input); } + shape compute_shape(const std::vector& inputs, + const std::vector& mod_args) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().compute_shape(inputs, mod_args); + } + argument compute(context& ctx, const shape& output, const std::vector& input) const { assert((*this).private_detail_te_handle_mem_var); @@ -344,6 +638,47 @@ struct operation return (*this).private_detail_te_get_handle().compute(output, input); } + argument compute(const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().compute( + output, input, module_args, std::move(run)); + } + + argument compute(context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().compute( + ctx, output, input, module_args, std::move(run)); + } + + value to_value() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().to_value(); + } + + void from_value(const value& v) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().from_value(v); + } + + value attributes() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().attributes(); + } + friend std::ostream& operator<<(std::ostream& os, const operation& op) { assert(op.private_detail_te_handle_mem_var); @@ -371,16 +706,38 @@ struct operation virtual std::string name() const = 0; virtual bool is_context_free() const = 0; + virtual bool need_normalization() const = 0; virtual bool has_finalize() const = 0; + virtual lifetime get_lifetime() const = 0; virtual std::ptrdiff_t output_alias(const std::vector& input) const = 0; + virtual value + compile(context& ctx, const shape& output, const std::vector& input) = 0; virtual void finalize(context& ctx, const shape& output, const std::vector& input) = 0; virtual shape compute_shape(const std::vector& input) const = 0; + virtual shape compute_shape(const std::vector& inputs, + const std::vector& mod_args) const = 0; virtual argument compute(context& ctx, const shape& output, const std::vector& input) const = 0; virtual argument compute(const shape& output, const std::vector& input) const = 0; - virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; - virtual bool operator==(const operation& y) const = 0; + virtual argument + compute(const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const = 0; + virtual argument + compute(context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const = 0; + virtual value to_value() const = 0; + virtual void from_value(const value& v) = 0; + virtual value attributes() const = 0; + virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; + virtual bool operator==(const operation& y) const = 0; }; template @@ -396,6 +753,19 @@ struct operation return detail::is_context_free_op(private_detail_te_self); } + template + static auto private_detail_te_default_need_normalization(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.need_normalization()) + { + return private_detail_te_self.need_normalization(); + } + + template + static bool private_detail_te_default_need_normalization(float, T&& private_detail_te_self) + { + return detail::need_normalization_op(private_detail_te_self); + } + template static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self) -> decltype(private_detail_te_self.has_finalize()) @@ -409,6 +779,19 @@ struct operation return detail::has_finalize_op(private_detail_te_self); } + template + static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.get_lifetime()) + { + return private_detail_te_self.get_lifetime(); + } + + template + static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self) + { + return detail::get_lifetime_op(private_detail_te_self); + } + template static auto private_detail_te_default_output_alias(char, T&& private_detail_te_self, @@ -426,6 +809,27 @@ struct operation return detail::output_alias_op(private_detail_te_self, input); } + template + static auto private_detail_te_default_compile(char, + T&& private_detail_te_self, + context& ctx, + const shape& output, + const std::vector& input) + -> decltype(private_detail_te_self.compile(ctx, output, input)) + { + return private_detail_te_self.compile(ctx, output, input); + } + + template + static value private_detail_te_default_compile(float, + T&& private_detail_te_self, + context& ctx, + const shape& output, + const std::vector& input) + { + return detail::compile_op(private_detail_te_self, ctx, output, input); + } + template static auto private_detail_te_default_finalize(char, T&& private_detail_te_self, @@ -447,6 +851,42 @@ struct operation detail::finalize_op(private_detail_te_self, ctx, output, input); } + template + static auto private_detail_te_default_compute_shape(char, + T&& private_detail_te_self, + const std::vector& input) + -> decltype(private_detail_te_self.compute_shape(input)) + { + return private_detail_te_self.compute_shape(input); + } + + template + static shape private_detail_te_default_compute_shape(float, + T&& private_detail_te_self, + const std::vector& input) + { + return detail::compute_shape_op(private_detail_te_self, input); + } + + template + static auto private_detail_te_default_compute_shape(char, + T&& private_detail_te_self, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(private_detail_te_self.compute_shape(inputs, mod_args)) + { + return private_detail_te_self.compute_shape(inputs, mod_args); + } + + template + static shape private_detail_te_default_compute_shape(float, + T&& private_detail_te_self, + const std::vector& inputs, + const std::vector& mod_args) + { + return detail::mod_compute_shape_op(private_detail_te_self, inputs, mod_args); + } + template static auto private_detail_te_default_compute(char, T&& private_detail_te_self, @@ -487,6 +927,105 @@ struct operation return detail::compute_op(private_detail_te_self, output, input); } + template + static auto private_detail_te_default_compute( + char, + T&& private_detail_te_self, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function(module_ref&, + const std::unordered_map&)> run) + -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run))) + { + return private_detail_te_self.compute(output, input, module_args, std::move(run)); + } + + template + static argument private_detail_te_default_compute( + float, + T&& private_detail_te_self, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function(module_ref&, + const std::unordered_map&)> run) + { + return detail::compute_op( + private_detail_te_self, output, input, module_args, std::move(run)); + } + + template + static auto private_detail_te_default_compute( + char, + T&& private_detail_te_self, + context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function(module_ref&, + const std::unordered_map&)> run) + -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run))) + { + return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)); + } + + template + static argument private_detail_te_default_compute( + float, + T&& private_detail_te_self, + context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function(module_ref&, + const std::unordered_map&)> run) + { + return detail::compute_op( + private_detail_te_self, ctx, output, input, module_args, std::move(run)); + } + + template + static auto private_detail_te_default_to_value(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.to_value()) + { + return private_detail_te_self.to_value(); + } + + template + static value private_detail_te_default_to_value(float, T&& private_detail_te_self) + { + return detail::to_value_op(private_detail_te_self); + } + + template + static auto + private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v) + -> decltype(private_detail_te_self.from_value(v)) + { + private_detail_te_self.from_value(v); + } + + template + static void + private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v) + { + detail::from_value_op(private_detail_te_self, v); + } + + template + static auto private_detail_te_default_attributes(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.attributes()) + { + return private_detail_te_self.attributes(); + } + + template + static value private_detail_te_default_attributes(float, T&& private_detail_te_self) + { + return detail::attributes_op(private_detail_te_self); + } + template struct private_detail_te_handle_type : private_detail_te_handle_base_type { @@ -523,18 +1062,37 @@ struct operation return private_detail_te_default_is_context_free(char(0), private_detail_te_value); } + bool need_normalization() const override + { + + return private_detail_te_default_need_normalization(char(0), private_detail_te_value); + } + bool has_finalize() const override { return private_detail_te_default_has_finalize(char(0), private_detail_te_value); } + lifetime get_lifetime() const override + { + + return private_detail_te_default_get_lifetime(char(0), private_detail_te_value); + } + std::ptrdiff_t output_alias(const std::vector& input) const override { return private_detail_te_default_output_alias(char(0), private_detail_te_value, input); } + value compile(context& ctx, const shape& output, const std::vector& input) override + { + + return private_detail_te_default_compile( + char(0), private_detail_te_value, ctx, output, input); + } + void finalize(context& ctx, const shape& output, const std::vector& input) override { @@ -545,7 +1103,15 @@ struct operation shape compute_shape(const std::vector& input) const override { - return private_detail_te_value.compute_shape(input); + return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input); + } + + shape compute_shape(const std::vector& inputs, + const std::vector& mod_args) const override + { + + return private_detail_te_default_compute_shape( + char(0), private_detail_te_value, inputs, mod_args); } argument compute(context& ctx, @@ -564,6 +1130,49 @@ struct operation char(0), private_detail_te_value, output, input); } + argument compute( + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const override + { + + return private_detail_te_default_compute( + char(0), private_detail_te_value, output, input, module_args, std::move(run)); + } + + argument compute( + context& ctx, + const shape& output, + const std::vector& input, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const override + { + + return private_detail_te_default_compute( + char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run)); + } + + value to_value() const override + { + + return private_detail_te_default_to_value(char(0), private_detail_te_value); + } + + void from_value(const value& v) override + { + + private_detail_te_default_from_value(char(0), private_detail_te_value, v); + } + + value attributes() const override + { + + return private_detail_te_default_attributes(char(0), private_detail_te_value); + } + std::ostream& operator_shift_left(std::ostream& os) const override { using migraphx::detail::operation_operators::operator<<; @@ -640,9 +1249,72 @@ inline const ValueType& any_cast(const operation& x) throw std::bad_cast(); return *y; } +#endif inline bool operator!=(const operation& x, const operation& y) { return !(x == y); } +inline value +compile(operation& op, context& ctx, const shape& output_shape, const std::vector& input) +{ + return op.compile(ctx, output_shape, input); +} +template +inline value +compile(operation& op, Context& ctx, const shape& output_shape, const std::vector& input) +{ + dependent_type ctx2 = std::ref(ctx); + return compile(op, ctx2, output_shape, input); +} +template +inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector& input) + -> decltype(op.compile(ctx, ctx, output_shape, input)) +{ + return op.compile(ctx, ctx, output_shape, input); +} +inline shape compute_shape(const operation& op, const std::vector& inputs) +{ + return op.compute_shape(inputs); +} + +template +inline auto compute_shape(const T& op, const std::vector& inputs) + -> decltype(op.compute_shape(inputs)) +{ + return op.compute_shape(inputs); +} + +template +inline auto compute_shape(const T& op, const std::vector& inputs) + -> decltype(op.normalize_compute_shape(inputs)) +{ + return detail::compute_shape_op(op, inputs); +} + +inline shape compute_shape(const operation& op, + const std::vector& inputs, + const std::vector& mod_args) +{ + return op.compute_shape(inputs, mod_args); +} + +template +inline auto compute_shape(const T& op, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(op.compute_shape(inputs, mod_args)) +{ + return op.compute_shape(inputs, mod_args); +} + +template +inline auto compute_shape(const T& op, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(op.normalize_compute_shape(inputs, mod_args)) +{ + return detail::compute_shape_op(op, inputs, mod_args); +} + inline bool is_context_free(const operation& op) { return op.is_context_free(); } template @@ -651,6 +1323,14 @@ bool is_context_free(const T& x) return detail::is_context_free_op(x); } +inline bool need_normalization(const operation& op) { return op.need_normalization(); } + +template +bool need_normalization(const T& x) +{ + return detail::need_normalization_op(x); +} + inline bool has_finalize(const operation& op) { return op.has_finalize(); } template @@ -659,6 +1339,9 @@ bool has_finalize(const T& x) return detail::has_finalize_op(x); } +void migraphx_to_value(value& v, const operation& op); +void migraphx_from_value(const value& v, operation& op); + #endif } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp old mode 100644 new mode 100755 index 9954e52d867592404786209b3e72ffd9e4d7aa58..04f7a63fca796ddb779676f63e8023953f286b36 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -1,16 +1,18 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_HPP #define MIGRAPHX_GUARD_OPERATORS_HPP -#include #include #include +#include #include #include #include #include +#include #include #include -#include +#include +#include #include #include #include @@ -23,21 +25,33 @@ #include #include #include +#include #include #include #include +#include #include #include #include #include #include +#include +#include +#include #include #include +#include #include +#include #include +#include #include #include +#include +#include +#include #include +#include #include #include #include @@ -45,24 +59,40 @@ #include #include #include +#include +#include #include #include #include +#include +#include +#include #include #include -#include -#include +#include +#include #include #include -#include +#include +#include #include #include +#include #include #include -#include +#include +#include +#include +#include #include #include #include +#include +#include +#include +#include +#include +#include #include #include #include @@ -72,11 +102,17 @@ #include #include #include +#include #include #include #include +#include #include #include +#include +#include +#include #include +#include #endif diff --git a/src/include/migraphx/optional.hpp b/src/include/migraphx/optional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9ed8b1528a27da5add388c4fba166275a316b56 --- /dev/null +++ b/src/include/migraphx/optional.hpp @@ -0,0 +1,57 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP + +#include + +#if defined(CPPCHECK) +#define MIGRAPHX_HAS_OPTIONAL 1 +#define MIGRAPHX_HAS_OPTIONAL_TS 1 +#elif defined(__has_include) +#if __has_include() && __cplusplus >= 201703L +#define MIGRAPHX_HAS_OPTIONAL 1 +#else +#define MIGRAPHX_HAS_OPTIONAL 0 +#endif +#if __has_include() && __cplusplus >= 201103L +#define MIGRAPHX_HAS_OPTIONAL_TS 1 +#else +#define MIGRAPHX_HAS_OPTIONAL_TS 0 +#endif +#else +#define MIGRAPHX_HAS_OPTIONAL 0 +#define MIGRAPHX_HAS_OPTIONAL_TS 0 +#endif + +#if MIGRAPHX_HAS_OPTIONAL +#include +#elif MIGRAPHX_HAS_OPTIONAL_TS +#include +#else +#error "No optional include available" +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#if MIGRAPHX_HAS_OPTIONAL +template +using optional = std::optional; +using nullopt_t = std::nullopt_t; +constexpr auto nullopt = std::nullopt; +#elif MIGRAPHX_HAS_OPTIONAL_TS +template +using optional = std::experimental::optional; +using nullopt_t = std::experimental::nullopt_t; +constexpr auto nullopt = std::experimental::nullopt; +#endif + +template +bool has_value(const optional& x) +{ + return x != nullopt; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP diff --git a/src/include/migraphx/output_iterator.hpp b/src/include/migraphx/output_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5e0b69695bc4bab50a1630553ab0dc3a56aeb0f --- /dev/null +++ b/src/include/migraphx/output_iterator.hpp @@ -0,0 +1,46 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +struct function_output_iterator +{ + F f; + + using self = function_output_iterator; + using difference_type = void; + using reference = void; + using value_type = void; + using pointer = void; + using iterator_category = std::output_iterator_tag; + + struct output_proxy + { + template + output_proxy& operator=(const T& value) + { + assert(f); + (*f)(value); + return *this; + } + F* f; + }; + output_proxy operator*() { return output_proxy{&f}; } + self& operator++() { return *this; } + self& operator++(int) { return *this; } // NOLINT +}; + +template +function_output_iterator make_function_output_iterator(F f) +{ + return {std::move(f)}; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP diff --git a/src/include/migraphx/pad_calc.hpp b/src/include/migraphx/pad_calc.hpp index 41cb904838d00903de49dd4e6d532c16be197c34..2751740dfc15415ec080c53ec1a689e7b92ccd94 100644 --- a/src/include/migraphx/pad_calc.hpp +++ b/src/include/migraphx/pad_calc.hpp @@ -13,14 +13,25 @@ inline void calculate_padding(int64_t idx, int64_t input_dim, int64_t stride, int64_t dilation, - int64_t weight_dim) + int64_t weight_dim, + bool is_same_upper = true) { int64_t output_dim = (input_dim + stride - 1) / stride; // round up result int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1); int64_t pad = std::max(static_cast(0), (output_dim - 1) * stride + new_weight_dim - input_dim); - pads[idx] = pad / 2; - pads[idx + 2] = pad - pad / 2; + auto pad_ndims = pads.size() / 2; + + if(is_same_upper) + { + pads[idx] = pad / 2; + pads[idx + pad_ndims] = pad - pad / 2; + } + else + { + pads[idx + pad_ndims] = pad / 2; + pads[idx] = pad - pad / 2; + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/par_dfor.hpp b/src/include/migraphx/par_dfor.hpp index 41ad2ca7c00b3e8d2c314ff8906b6972599b80aa..6c2dbb520a5ae4b881154fca326c1ca89f8ec973 100644 --- a/src/include/migraphx/par_dfor.hpp +++ b/src/include/migraphx/par_dfor.hpp @@ -41,7 +41,6 @@ auto par_dfor(Ts... xs) { dfor(xs...)(f); } - }; } diff --git a/src/include/migraphx/par_for.hpp b/src/include/migraphx/par_for.hpp index 1d5654cf3db76009168f074d2edbc22e3578cba5..3f25e4548b1af1226373662cdd496a60eec7d7a5 100644 --- a/src/include/migraphx/par_for.hpp +++ b/src/include/migraphx/par_for.hpp @@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) template void par_for(std::size_t n, std::size_t min_grain, F f) { - const auto threadsize = - std::min(std::thread::hardware_concurrency(), n / min_grain); + const auto threadsize = std::min(std::thread::hardware_concurrency(), + n / std::max(1, min_grain)); par_for_impl(n, threadsize, f); } diff --git a/src/include/migraphx/pass.hpp b/src/include/migraphx/pass.hpp index da520c7cc85dd50463e70dc915980aeef31b39e3..ce076d2554f065afcf202815e0c0595b94364359 100644 --- a/src/include/migraphx/pass.hpp +++ b/src/include/migraphx/pass.hpp @@ -3,16 +3,19 @@ #include #include -#include #include #include #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { struct program; +struct module; +struct module_pass_manager; #ifdef DOXYGEN @@ -22,22 +25,53 @@ struct pass { /// A unique name used to identify the pass std::string name() const; + /// Run the pass on the module + void apply(module_pass_manager& mpm) const; + void apply(module& m) const; /// Run the pass on the program void apply(program& p) const; }; #else -/* - * Type-erased interface for: - * - * struct pass - * { - * std::string name() const; - * void apply(program & p) const; - * }; - * - */ +module& get_module(module_pass_manager& mpm); + +namespace detail { + +template +auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm) + -> decltype(x.apply(get_module(mpm))) +{ + return x.apply(get_module(mpm)); +} + +template +void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&) +{ +} + +template +void module_pass_manager_apply(const T& x, module_pass_manager& mpm) +{ + module_pass_manager_apply(rank<1>{}, x, mpm); +} + +} // namespace detail + +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct pass +{ + // + std::string name() const; + // (optional) + void apply(module_pass_manager& mpm) const; + // (optional) + void apply(program& p) const; +}; + +#else struct pass { @@ -57,11 +91,17 @@ struct pass template pass& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + pass rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -69,7 +109,7 @@ struct pass template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -80,7 +120,7 @@ struct pass template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -102,6 +142,12 @@ struct pass return (*this).private_detail_te_get_handle().name(); } + void apply(module_pass_manager& mpm) const + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().apply(mpm); + } + void apply(program& p) const { assert((*this).private_detail_te_handle_mem_var); @@ -121,10 +167,39 @@ struct pass virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; - virtual std::string name() const = 0; - virtual void apply(program& p) const = 0; + virtual std::string name() const = 0; + virtual void apply(module_pass_manager& mpm) const = 0; + virtual void apply(program& p) const = 0; }; + template + static auto + private_detail_te_default_apply(char, T&& private_detail_te_self, module_pass_manager& mpm) + -> decltype(private_detail_te_self.apply(mpm)) + { + private_detail_te_self.apply(mpm); + } + + template + static void + private_detail_te_default_apply(float, T&& private_detail_te_self, module_pass_manager& mpm) + { + migraphx::detail::module_pass_manager_apply(private_detail_te_self, mpm); + } + + template + static auto private_detail_te_default_apply(char, T&& private_detail_te_self, program& p) + -> decltype(private_detail_te_self.apply(p)) + { + private_detail_te_self.apply(p); + } + + template + static void private_detail_te_default_apply(float, T&& private_detail_te_self, program& p) + { + migraphx::nop(private_detail_te_self, p); + } + template struct private_detail_te_handle_type : private_detail_te_handle_base_type { @@ -155,7 +230,17 @@ struct pass std::string name() const override { return private_detail_te_value.name(); } - void apply(program& p) const override { private_detail_te_value.apply(p); } + void apply(module_pass_manager& mpm) const override + { + + private_detail_te_default_apply(char(0), private_detail_te_value, mpm); + } + + void apply(program& p) const override + { + + private_detail_te_default_apply(char(0), private_detail_te_value, p); + } PrivateDetailTypeErasedT private_detail_te_value; }; @@ -221,6 +306,7 @@ inline const ValueType& any_cast(const pass& x) throw std::bad_cast(); return *y; } +#endif #endif diff --git a/src/include/migraphx/pass_manager.hpp b/src/include/migraphx/pass_manager.hpp old mode 100644 new mode 100755 index e117d85b4eafbee958bf8acf21868050468fde71..1faca7a2b59afbc1b99de93e9e3df4fc3218a91d --- a/src/include/migraphx/pass_manager.hpp +++ b/src/include/migraphx/pass_manager.hpp @@ -1,22 +1,27 @@ #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct module_pass_manager +{ + module_pass_manager() = default; + module_pass_manager(const module_pass_manager&) = delete; + virtual module& get_module() = 0; + virtual module* create_module(const std::string& name) = 0; + virtual void run_pass(const pass& p) = 0; + + protected: + virtual ~module_pass_manager() {} +}; + +void run_passes(module& mod, const std::vector& passes, tracer trace = tracer{}); void run_passes(program& prog, const std::vector& passes, tracer trace = tracer{}); } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/permutation.hpp b/src/include/migraphx/permutation.hpp old mode 100644 new mode 100755 index e53376662c3b9c3e3c7045d60143b75cc17f6f9e..8a4e8d4eb6cbd869a6ba05eed566173c2771b3b3 --- a/src/include/migraphx/permutation.hpp +++ b/src/include/migraphx/permutation.hpp @@ -20,29 +20,22 @@ inline Vector reorder_dims(const Vector& dims, const std::vector& permu return result; } -inline shape reorder_shape(const shape& s, const std::vector& permutation) -{ - return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; -} +shape reorder_shape(const shape& s, const std::vector& permutation); template inline std::vector sort_permutation(const Vector& data, Op op) { std::vector result(data.size()); std::iota(result.begin(), result.end(), 0); - std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); }); + std::stable_sort( + result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); }); return result; } -inline std::vector invert_permutation(const std::vector& permutation) -{ - return sort_permutation(permutation, std::less<>{}); -} +std::vector invert_permutation(const std::vector& permutation); -inline std::vector find_permutation(const shape& s) -{ - return sort_permutation(s.strides(), std::greater<>{}); -} +std::vector find_permutation(const shape& s); +std::vector find_permutation(const std::vector& shapes); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/preallocate_param.hpp b/src/include/migraphx/preallocate_param.hpp new file mode 100755 index 0000000000000000000000000000000000000000..4d5082744a0770b5c9d89c27a5731582602aaa1a --- /dev/null +++ b/src/include/migraphx/preallocate_param.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct preallocate_param +{ + std::string param; + allocation_model model; + std::string name() const { return "preallocate_param"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_PREALLOCATE_PARAM_HPP diff --git a/src/include/migraphx/process.hpp b/src/include/migraphx/process.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ace6d77812424b43a58c9cdf63f920c247bef9cd --- /dev/null +++ b/src/include/migraphx/process.hpp @@ -0,0 +1,36 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct process_impl; + +struct process +{ + process(const std::string& cmd); + + // move constructor + process(process&&) noexcept; + + // copy assignment operator + process& operator=(process rhs); + + ~process() noexcept; + + process& cwd(const fs::path& p); + + void exec(); + + private: + std::unique_ptr impl; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_PROCESS_HPP diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 2b95316e9d13f6ea3184b983ddb1a6ef1017e21a..573467a0f9b77d1908c870930ea8d38668e66868 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -22,7 +23,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) struct program_impl; -const operation& get_operation(instruction_ref ins); +struct marker; /** * @brief Stores the instruction stream @@ -42,50 +43,7 @@ struct program ~program() noexcept; - using parameter_map = std::unordered_map; - - template - instruction_ref add_instruction(operation op, Ts... args) - { - return add_instruction(op, {args...}); - } - instruction_ref add_instruction(const operation& op, std::vector args); - - template - instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args) - { - return insert_instruction(ins, op, {args...}); - } - instruction_ref - insert_instruction(instruction_ref ins, const operation& op, std::vector args); - - template - instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args) - { - return replace_instruction(ins, op, {args...}); - } - instruction_ref replace_instruction(instruction_ref ins, - const operation& op, - std::vector args); - - instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep); - - instruction_ref remove_instruction(instruction_ref ins); - instruction_ref remove_instructions(instruction_ref first, instruction_ref last); - - instruction_ref move_instruction(instruction_ref src, instruction_ref dst); - - template - instruction_ref add_literal(Ts&&... xs) - { - return add_literal(literal{std::forward(xs)...}); - } - - instruction_ref add_literal(literal l); - - instruction_ref add_outline(const shape& s); - - instruction_ref add_parameter(std::string name, shape s); + std::vector get_parameter_names() const; shape get_parameter_shape(std::string name) const; @@ -93,15 +51,11 @@ struct program std::unordered_map get_parameter_shapes() const; - argument eval(parameter_map params) const; - - bool has_instruction(instruction_ref ins) const; + std::vector eval(parameter_map params) const; std::size_t size() const; - instruction_ref begin() const; - instruction_ref end() const; - shape get_shape() const; + std::vector get_output_shapes() const; context& get_context() const; @@ -109,27 +63,57 @@ struct program void compile(const target& t, compile_options options = compile_options{}); + bool is_compiled() const; + void finalize(); - void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; + void + perf_report(std::ostream& os, std::size_t n, parameter_map params, std::size_t batch = 1) const; + + void mark(const parameter_map& params, marker&& m); + + value to_value() const; + void from_value(const value& v); void debug_print() const; void debug_print(instruction_ref ins) const; - void debug_print(const std::vector& inss) const; + void print(std::unordered_map& names, + const std::function)>& + print_func) const; + void print(const std::function)>& + print_func) const; + void print_graph(std::ostream& os, bool brief = false) const; + void print_cpp(std::ostream& os) const; void dry_run(parameter_map params) const; - void annotate(std::ostream& os, std::function a) const; + void annotate(std::ostream& os, const std::function& a) const; + + program& sort(); friend std::ostream& operator<<(std::ostream& os, const program& p); friend bool operator==(const program& x, const program& y); friend bool operator!=(const program& x, const program& y) { return !(x == y); } - private: - void assign(const program& p); + // module related api + module* create_module(const std::string& name); + module* get_module(const std::string& name); + const module* get_module(const std::string& name) const; + + module* get_main_module(); + const module* get_main_module() const; + + std::vector get_modules() const; + std::vector get_modules(); + + void remove_module(const std::string& name); + void remove_unused_modules(); private: + void assign(const program& p); std::unique_ptr impl; }; diff --git a/src/include/migraphx/propagate_constant.hpp b/src/include/migraphx/propagate_constant.hpp index 6c3dea924185c0f95cefab966bc9f4487de9cef7..d88f28fe917b19dd9bae0be723883552cd1483ed 100644 --- a/src/include/migraphx/propagate_constant.hpp +++ b/src/include/migraphx/propagate_constant.hpp @@ -7,7 +7,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Replace instructions which take all literals with a literal of the computation. @@ -15,7 +15,7 @@ struct program; struct propagate_constant { std::string name() const { return "propagate_constant"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/quantization.hpp b/src/include/migraphx/quantization.hpp index b7a30b917b1acd485044dd464e997b673ed2ef7a..cbc1054e069dbb8becc1e25d454b77ef35945d05 100644 --- a/src/include/migraphx/quantization.hpp +++ b/src/include/migraphx/quantization.hpp @@ -17,32 +17,10 @@ struct program; void quantize_fp16(program& prog, const std::vector& ins_names = {"all"}); -// insert the capture operator for the inputs of each operator to be quantized -// to int8 -std::size_t capture_arguments(program& prog, - const std::vector& ins_names, - const std::function)>& func); - -std::shared_ptr>> -capture_arguments_impl(program& prog, const target& t, const std::vector& ins_names); - -template -std::shared_ptr>> -capture_arguments(program& prog, T&& t, const std::vector& ins_names) -{ - static_assert(std::is_same>, target>{} && - std::is_lvalue_reference{}, - "Dangling reference to target!"); - return capture_arguments_impl(prog, t, ins_names); -} - void quantize_int8(program& prog, const target& t, - const std::vector& calibration, + const std::vector& calibration, const std::vector& ins_names = {"dot", "convolution"}); -void quantize_int8_impl(program& prog, - const std::vector>& quant_params, - const std::vector& ins_names); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/quantize_fp16.hpp b/src/include/migraphx/quantize_fp16.hpp new file mode 100644 index 0000000000000000000000000000000000000000..050de984957e3c02ef3a9c1274fc09f169c0176e --- /dev/null +++ b/src/include/migraphx/quantize_fp16.hpp @@ -0,0 +1,27 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP +#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct program; +struct module; + +/** + * quantize a program to fp16 + */ +struct quantize_fp16_pass +{ + std::vector ins_names = {"all"}; + std::string name() const { return "quantize_fp16"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/quantize_int8.hpp b/src/include/migraphx/quantize_int8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd25ba46fb16a46b86a5ef1894b7fe3b989ca529 --- /dev/null +++ b/src/include/migraphx/quantize_int8.hpp @@ -0,0 +1,42 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP +#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct program; +struct module; + +/** + * capture inputs of operators to be quantized to int8 + */ +struct capture_arguments_pass +{ + std::vector ins_names = {"dot", "convolution"}; + std::function)> f{}; + std::size_t* param_index = nullptr; + std::string name() const { return "capture_arguments"; } + void apply(module& m) const; +}; + +/** + * quantize a program to int8 + */ +struct quantize_int8_pass +{ + std::vector ins_names = {"dot", "convolution"}; + std::vector> quant_params; + std::string name() const { return "quantize_int8"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/ranges.hpp b/src/include/migraphx/ranges.hpp old mode 100644 new mode 100755 index 85ee3b4a038fd0202969e3449f108997d49774af..4ab5ee6fab615157bd5a89a44f127bac5c016b79 --- a/src/include/migraphx/ranges.hpp +++ b/src/include/migraphx/ranges.hpp @@ -2,8 +2,13 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP #include +#include #include #include +#include +#include +#include +#include #include namespace migraphx { @@ -33,6 +38,33 @@ auto generic_find_impl(rank<0>, C&& c, const T& x) return std::find(c.begin(), c.end(), x); } +template +auto generic_find_at_impl(rank<1>, C&& c, const T& x) -> decltype(c.find(x)) +{ + return c.find(x); +} + +template +auto generic_find_at_impl(rank<0>, C&& c, const T& x) +{ + auto n = std::distance(c.begin(), c.end()); + if(x >= n) + return c.end(); + return std::next(c.begin(), x); +} + +template +decltype(auto) generic_at_impl(rank<1>, const C&, T&& it) +{ + return it->second; +} + +template +decltype(auto) generic_at_impl(rank<0>, const C&, T&& it) +{ + return *it; +} + struct empty { }; @@ -45,6 +77,20 @@ auto generic_find(C&& c, const T& x) return detail::generic_find_impl(rank<2>{}, c, x); } +template +decltype(auto) at(C&& c, const T& x, const std::string& msg = "") +{ + auto it = detail::generic_find_at_impl(rank<2>{}, c, x); + if(it == c.end()) + { + if(msg.empty()) + MIGRAPHX_THROW("At operator out of range for " + get_type_name(c)); + else + MIGRAPHX_THROW(msg); + } + return detail::generic_at_impl(rank<2>{}, c, it); +} + template bool contains(const C& c, const T& x) { @@ -123,12 +169,41 @@ void copy(Range&& r, Iterator it) std::copy(r.begin(), r.end(), it); } +template +void transform(Range&& r, Iterator it, F f) +{ + std::transform(r.begin(), r.end(), it, f); +} + +template +auto reverse(Range& r) +{ + return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin())); +} + template void replace(Range&& r, const T& old, const T& new_x) { std::replace(r.begin(), r.end(), old, new_x); } +template +bool equal(R1&& r1, R2&& r2) +{ + return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end()); +} + +template +using range_value = std::decay_t().begin())>; + +template +std::vector> find_all(Range&& r, Predicate p) +{ + std::vector> result; + std::copy_if(r.begin(), r.end(), std::back_inserter(result), p); + return result; +} + template struct iterator_range { @@ -140,12 +215,18 @@ struct iterator_range Iterator end() const { return last; } }; -template +template {})> iterator_range range(Iterator start, Iterator last) { return {start, last}; } +inline iterator_range range(std::ptrdiff_t start, std::ptrdiff_t last) +{ + return {{start, {}}, {last, {}}}; +} +inline iterator_range range(std::ptrdiff_t last) { return range(0, last); } + template iterator_range range(std::pair p) { diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index c8e255944fc182492a2fa047899214ca8022e1b5..4374b638ac3f1677d63db933a32d5f5dd4a45062 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -28,7 +29,15 @@ struct raw_data : raw_data_base friend Stream& operator<<(Stream& os, const Derived& d) { if(not d.empty()) - d.visit([&](auto x) { os << x; }); + d.visit([&](auto x) { os << x; }, + [&](auto&& xs) { + for(auto&& x : xs) + { + os << "{ "; + os << x; + os << " }, "; + } + }); return os; } @@ -44,9 +53,19 @@ struct raw_data : raw_data_base auto&& derived = static_cast(*this); if(derived.empty()) MIGRAPHX_THROW("Visiting empty data!"); - auto&& s = derived.get_shape(); - auto&& buffer = derived.data(); - s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); + auto&& s = derived.get_shape(); + s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); }); + } + + template + void visit(Visitor v, TupleVisitor tv) const + { + auto&& derived = static_cast(*this); + if(derived.empty()) + MIGRAPHX_THROW("Visiting empty data!"); + auto&& s = derived.get_shape(); + s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); }, + [&] { tv(derived.get_sub_objects()); }); } /** @@ -59,12 +78,7 @@ struct raw_data : raw_data_base template void visit(Visitor v) const { - auto&& derived = static_cast(*this); - if(derived.empty()) - MIGRAPHX_THROW("Visiting empty data!"); - auto&& s = derived.get_shape(); - auto&& buffer = derived.data(); - s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); }); + visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); }); } /// Returns true if the raw data is only one element @@ -141,50 +155,41 @@ struct raw_data : raw_data_base template T* cast() const { - auto&& s = static_cast(*this).get_shape(); auto&& buffer = static_cast(*this).data(); - assert(s.type() == migraphx::shape::get_type{}); + assert(static_cast(*this).get_shape().type() == + migraphx::shape::get_type{}); return reinterpret_cast(buffer); } -}; -template {} && - std::is_base_of{})> -bool operator==(const T& x, const U& y) -{ - auto&& xshape = x.get_shape(); - auto&& yshape = y.get_shape(); - bool result = x.empty() && y.empty(); - if(not result && xshape == yshape) + std::string to_string() const { - auto&& xbuffer = x.data(); - auto&& ybuffer = y.data(); - // TODO: Dont use tensor view for single values - xshape.visit_type([&](auto as) { - auto xview = make_view(xshape, as.from(xbuffer)); - auto yview = make_view(yshape, as.from(ybuffer)); - result = xview == yview; - }); + std::stringstream ss; + ss << static_cast(*this); + return ss.str(); } - return result; +}; + +namespace detail { +template +void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs) +{ + s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); }, + [&] { v2(xs.get_sub_objects()...); }); } -template {} && - std::is_base_of{})> -bool operator!=(const T& x, const U& y) +template +auto visit_all_pack(const shape& s, V1&& v1, V2&& v2) { - return !(x == y); + return [&](auto&&... xs) { + // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100 + visit_all_flatten(s, v1, v2, xs...); + }; } -namespace detail { -template -void visit_all_impl(const shape& s, V&& v, Ts&&... xs) +template +auto visit_all_pack(const shape& s, V1&& v1) { - s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); }); + return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); }); } } // namespace detail @@ -206,10 +211,7 @@ auto visit_all(T&& x, Ts&&... xs) std::initializer_list types = {xs.get_shape().type()...}; if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) MIGRAPHX_THROW("Types must be the same"); - return [&](auto v) { - // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100 - detail::visit_all_impl(s, v, x, xs...); - }; + return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); }; } template @@ -231,6 +233,34 @@ auto visit_all(const std::vector& x) }; } +template {} && + std::is_base_of{})> +bool operator==(const T& x, const U& y) +{ + auto&& xshape = x.get_shape(); + auto&& yshape = y.get_shape(); + bool result = x.empty() and y.empty(); + if(not result and xshape == yshape) + { + visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; }, + [&](auto&& xs, auto&& ys) { + result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end()); + }); + } + return result; +} + +template {} && + std::is_base_of{})> +bool operator!=(const T& x, const U& y) +{ + return !(x == y); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/reduce_dims.hpp b/src/include/migraphx/reduce_dims.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5d2a5e3bcc591671092f3dc6f26fde13dd18b072 --- /dev/null +++ b/src/include/migraphx/reduce_dims.hpp @@ -0,0 +1,16 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP +#define MIGRAPHX_GUARD_RTGLIB_REDUCE_DIMS_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::vector reduce_dims(const std::vector& shapes); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/reflect.hpp b/src/include/migraphx/reflect.hpp old mode 100644 new mode 100755 index e271b6cbf34a5164d44db559a5caa7ab052609fe..c87706c1c8dd1d453f432d994e591a1e733ee010 --- a/src/include/migraphx/reflect.hpp +++ b/src/include/migraphx/reflect.hpp @@ -102,6 +102,29 @@ void reflect_each(T& x, F f) }); } +template +struct reflect_equality +{ + friend bool operator==(const T& x, const T& y) { return reflect_tie(x) == reflect_tie(y); } + friend bool operator!=(const T& x, const T& y) { return !(x == y); } +}; + +template +struct reflect_stream +{ + template + friend Stream& operator<<(Stream& os, const T& x) + { + char d = '{'; + reflect_each(x, [&](const auto& y, const auto& name) { + os << d << name << "=" << y; + d = ','; + }); + os << "}"; + return os; + } +}; + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/register_op.hpp b/src/include/migraphx/register_op.hpp new file mode 100755 index 0000000000000000000000000000000000000000..0e211e8f1942f2681685a8d8e789e7fdd0e6ba9f --- /dev/null +++ b/src/include/migraphx/register_op.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP +#define MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void register_op(const operation& op); +operation load_op(const std::string& name); +bool has_op(const std::string& name); +std::vector get_operators(); + +template +void register_op() +{ + register_op(T{}); +} + +struct register_op_action +{ + template + static void apply() + { + register_op(); + } +}; + +template +using auto_register_op = auto_register; + +#define MIGRAPHX_REGISTER_OP(...) MIGRAPHX_AUTO_REGISTER(register_op_action, __VA_ARGS__) + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/register_target.hpp b/src/include/migraphx/register_target.hpp new file mode 100644 index 0000000000000000000000000000000000000000..011d1113dfb04268ca9d94126198aad2b6bd31a6 --- /dev/null +++ b/src/include/migraphx/register_target.hpp @@ -0,0 +1,40 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP +#define MIGRAPHX_GUARD_RTGLIB_REGISTER_TARGET_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void register_target(const target& t); +target make_target(const std::string& name); +std::vector get_targets(); + +template +void register_target() +{ + register_target(T{}); +} + +struct register_target_action +{ + template + static void apply() + { + register_target(); + } +}; + +template +using auto_register_target = auto_register; + +#define MIGRAPHX_REGISTER_TARGET(...) MIGRAPHX_AUTO_REGISTER(register_target_action, __VA_ARGS__) + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/replace_allocate.hpp b/src/include/migraphx/replace_allocate.hpp new file mode 100644 index 0000000000000000000000000000000000000000..02096dc48cd21f41bbb9cae1e0c6254e6fe4ef6b --- /dev/null +++ b/src/include/migraphx/replace_allocate.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP +#define MIGRAPHX_GUARD_RTGLIB_REPLACE_ALLOCATE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct replace_allocate +{ + allocation_model model; + bool offload_copy = false; + std::string name() const { return "replace_allocate"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/requires.hpp b/src/include/migraphx/requires.hpp index 006bba65910ed4d006d42c5caf31715ae5a2474f..1a7c07a2ab33471772306f177c6303f53132e64f 100644 --- a/src/include/migraphx/requires.hpp +++ b/src/include/migraphx/requires.hpp @@ -22,12 +22,14 @@ using bool_c = std::integral_constant; #ifdef CPPCHECK #define MIGRAPHX_REQUIRES(...) class = void +#define MIGRAPHX_CLASS_REQUIRES(...) void #else #define MIGRAPHX_REQUIRES(...) \ long MIGRAPHX_REQUIRES_VAR() = __LINE__, \ typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \ (migraphx::and_<__VA_ARGS__>{})), \ int>::type = 0 +#define MIGRAPHX_CLASS_REQUIRES(...) typename std::enable_if<(migraphx::and_<__VA_ARGS__>{})>::type #endif } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/rewrite_batchnorm.hpp b/src/include/migraphx/rewrite_batchnorm.hpp index 96d4ba374b88ce10bd0d98f29a5036abaac68db5..eb3f22cb46875d9c868195aabd8ae33de8f81dee 100644 --- a/src/include/migraphx/rewrite_batchnorm.hpp +++ b/src/include/migraphx/rewrite_batchnorm.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Rewrite batchnorm to a multiply and add. @@ -16,7 +16,7 @@ struct program; struct rewrite_batchnorm { std::string name() const { return "rewrite_batchnorm"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index de44ca64c3edbd3a9af72da1ee0ba297019fb3bb..58e4332a1f74639fd43491c19a9a9cbc8521910c 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -7,7 +7,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Rewrite pooling to reduce_mean @@ -15,7 +15,7 @@ struct program; struct rewrite_pooling { std::string name() const { return "rewrite_pooling"; } - void apply(program& prog) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/rewrite_quantization.hpp b/src/include/migraphx/rewrite_quantization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4f402da1a26ef7d3e332ef137e2c86279b26b2e --- /dev/null +++ b/src/include/migraphx/rewrite_quantization.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP +#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * Rewrite quantization ops to equivalent operators + */ +struct rewrite_quantization +{ + std::string name() const { return "rewrite_quantization"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/rewrite_rnn.hpp b/src/include/migraphx/rewrite_rnn.hpp index 49b897d1c9225314815d0f544364c9803c1ddec1..23ffe8138b29403c8d9f2b1d63db1f112482e2cf 100644 --- a/src/include/migraphx/rewrite_rnn.hpp +++ b/src/include/migraphx/rewrite_rnn.hpp @@ -6,11 +6,12 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Rewrite rnn to gemm and add. @@ -18,26 +19,22 @@ struct program; struct rewrite_rnn { std::string name() const { return "rewrite_rnn"; } - void apply(program& prog) const; + void apply(module& m) const; private: // for vanilla rnn operators - void apply_vanilla_rnn(program& prog, instruction_ref ins) const; + void apply_vanilla_rnn(module& m, instruction_ref ins) const; std::vector vanilla_rnn_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, - instruction_ref input, - instruction_ref w, - instruction_ref r, - instruction_ref bias, - instruction_ref ih, + std::vector inputs, operation& actv_func) const; std::vector vanilla_rnn_actv_funcs(instruction_ref ins) const; // for gru operators - void apply_gru(program& prog, instruction_ref ins) const; + void apply_gru(module& m, instruction_ref ins) const; std::vector gru_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, std::vector inputs, int linear_before_reset, @@ -47,9 +44,9 @@ struct rewrite_rnn std::vector gru_actv_funcs(instruction_ref ins) const; // for lstm operators - void apply_lstm(program& prog, instruction_ref ins) const; + void apply_lstm(module& m, instruction_ref ins) const; std::vector lstm_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, std::vector inputs, const operation& actv_func1, @@ -57,6 +54,27 @@ struct rewrite_rnn const operation& actv_func3) const; std::vector lstm_actv_funcs(instruction_ref ins) const; + + bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const; + instruction_ref replace_last_hs_output(module& m, + instruction_ref ins, + instruction_ref seq_lens, + instruction_ref last_hs_output, + op::rnn_direction dirct) const; + + void replace_last_cell_output(module& m, + instruction_ref ins, + instruction_ref seq_lens, + instruction_ref cell_outputs, + instruction_ref last_cell_output, + op::rnn_direction dirct) const; + + std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const; + + instruction_ref pad_hidden_states(module& m, + instruction_ref seq, + instruction_ref seq_lens, + instruction_ref hs) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/run_loop.hpp b/src/include/migraphx/run_loop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33429a11b10079e58f36f0cff45ba3211047ba15 --- /dev/null +++ b/src/include/migraphx/run_loop.hpp @@ -0,0 +1,115 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP +#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +argument run_loop(const LoopModel& model, + T& ctx, + std::vector args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) +{ + std::vector> results; + // process argu lists + auto iter_num = args.at(0).at(); + auto cond = args.at(1).at(); + + auto input_num = (args.size() - 2) / 2; + auto dep_num = input_num - 2; + + module_ref mod = mods.at(0); + auto param_name_shapes = mod->get_parameter_shapes(); + auto param_names = mod->get_parameter_names(); + + std::vector dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num); + std::vector dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1); + auto ins_outputs = args.back().get_sub_objects(); + dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num); + std::array, 2> loop_carry_deps = {dep0, dep1}; + + // loop iter argument + std::vector in_args = {args.at(input_num), dep1.at(0)}; + in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num); + + std::vector out_args = dep0; + out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end()); + std::vector scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end()); + + auto out_param_indices = model.get_output_params(*mod); + + int64_t iter = 0; + for(iter = 0; iter < iter_num and cond; ++iter) + { + // copy iter num and cond to device memory + model.copy(ctx, iter, in_args.at(0)); + model.copy(ctx, cond, in_args.at(1)); + + // wrap up the inputs and outputs + std::unordered_map params; + int input_index = 0; + for(const auto& name : param_names) + { + auto ps = mod->get_parameter_shape(name); + if(ps == shape{}) + { + continue; + } + + // it is an input parameter + if(not contains(out_param_indices, name)) + { + params[name] = in_args.at(input_index++); + } + else + { + auto output_index = out_param_indices[name]; + if(output_index > dep_num) + { + const auto& arg = out_args.at(output_index); + assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes()); + params[name] = argument(ps, arg.data() + iter * ps.bytes()); + } + else + { + params[name] = out_args.at(output_index); + } + } + } + + auto mod_args = run(mod, params); + + // copy back cond to be used next iteration + model.copy(ctx, mod_args.at(0), cond); + + // mod outputs are used as next loop input + std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1); + const auto& dep_out = loop_carry_deps[(iter + 1) % 2]; + std::copy(dep_out.begin(), dep_out.end(), out_args.begin()); + + std::vector mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end()); + model.append(mod_scan_outs, scan_outputs, iter); + } + + out_args.erase(out_args.begin()); + std::copy(in_args.begin() + 2, in_args.end(), out_args.begin()); + model.set_zero(ctx, scan_outputs, iter); + + return {out_args}; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/schedule.hpp b/src/include/migraphx/schedule.hpp index a35c237ae048f9c95820efecdf0b7a9bd00b0fb1..352ba0e51d416575d1614b27e987c65f812d4c58 100644 --- a/src/include/migraphx/schedule.hpp +++ b/src/include/migraphx/schedule.hpp @@ -9,7 +9,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Schedule instructions for concurrent execution @@ -19,7 +19,7 @@ struct schedule schedule_model model{}; bool enable = true; std::string name() const { return "schedule"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/schedule_model.hpp b/src/include/migraphx/schedule_model.hpp index 78788d69521d327f259b9ce19b64bbf91391e64a..14742707aa46244551e20a5bdb2a787083311c7e 100644 --- a/src/include/migraphx/schedule_model.hpp +++ b/src/include/migraphx/schedule_model.hpp @@ -15,7 +15,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; struct operation; #ifdef DOXYGEN @@ -26,30 +26,35 @@ struct schedule_model /// Get the number of concurrent instruction allowed std::size_t concurrency() const; /// Schedule a concurrent instruction - void sched(program& p, instruction_ref ins, std::size_t n) const; + void sched(module& m, instruction_ref ins, std::size_t n) const; // Insert necessary waits before an instruction - void wait(program& p, instruction_ref ins, std::size_t wait_id) const; + void wait(module& m, instruction_ref ins, std::size_t wait_id) const; // Insert necessary records after an instruction - void record(program& p, instruction_ref ins, std::size_t wait_id) const; + void record(module& m, instruction_ref ins, std::size_t wait_id) const; /// Compute weights for an operation std::size_t weight(const operation& op) const; }; #else -/* - * Type-erased interface for: - * - * struct schedule_model - * { - * std::size_t concurrency() const; - * void sched(program& p,instruction_ref ins,std::size_t n) const; - * void wait(program& p,instruction_ref ins,std::size_t wait_id) const; - * void record(program& p,instruction_ref ins,std::size_t wait_id) const; - * std::size_t weight(const operation& op) const; - * }; - * - */ +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct schedule_model +{ + // + std::size_t concurrency() const; + // + void sched(module& m, instruction_ref ins, std::size_t n) const; + // + void wait(module& m, instruction_ref ins, std::size_t wait_id) const; + // + void record(module& m, instruction_ref ins, std::size_t wait_id) const; + // + std::size_t weight(const operation& op) const; +}; + +#else struct schedule_model { @@ -69,11 +74,17 @@ struct schedule_model template schedule_model& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + schedule_model rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -81,7 +92,7 @@ struct schedule_model template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -92,7 +103,7 @@ struct schedule_model template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -114,22 +125,22 @@ struct schedule_model return (*this).private_detail_te_get_handle().concurrency(); } - void sched(program& p, instruction_ref ins, std::size_t n) const + void sched(module& m, instruction_ref ins, std::size_t n) const { assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().sched(p, ins, n); + (*this).private_detail_te_get_handle().sched(m, ins, n); } - void wait(program& p, instruction_ref ins, std::size_t wait_id) const + void wait(module& m, instruction_ref ins, std::size_t wait_id) const { assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().wait(p, ins, wait_id); + (*this).private_detail_te_get_handle().wait(m, ins, wait_id); } - void record(program& p, instruction_ref ins, std::size_t wait_id) const + void record(module& m, instruction_ref ins, std::size_t wait_id) const { assert((*this).private_detail_te_handle_mem_var); - (*this).private_detail_te_get_handle().record(p, ins, wait_id); + (*this).private_detail_te_get_handle().record(m, ins, wait_id); } std::size_t weight(const operation& op) const @@ -152,11 +163,11 @@ struct schedule_model virtual std::shared_ptr clone() const = 0; virtual const std::type_info& type() const = 0; - virtual std::size_t concurrency() const = 0; - virtual void sched(program& p, instruction_ref ins, std::size_t n) const = 0; - virtual void wait(program& p, instruction_ref ins, std::size_t wait_id) const = 0; - virtual void record(program& p, instruction_ref ins, std::size_t wait_id) const = 0; - virtual std::size_t weight(const operation& op) const = 0; + virtual std::size_t concurrency() const = 0; + virtual void sched(module& m, instruction_ref ins, std::size_t n) const = 0; + virtual void wait(module& m, instruction_ref ins, std::size_t wait_id) const = 0; + virtual void record(module& m, instruction_ref ins, std::size_t wait_id) const = 0; + virtual std::size_t weight(const operation& op) const = 0; }; template @@ -189,22 +200,22 @@ struct schedule_model std::size_t concurrency() const override { return private_detail_te_value.concurrency(); } - void sched(program& p, instruction_ref ins, std::size_t n) const override + void sched(module& m, instruction_ref ins, std::size_t n) const override { - private_detail_te_value.sched(p, ins, n); + private_detail_te_value.sched(m, ins, n); } - void wait(program& p, instruction_ref ins, std::size_t wait_id) const override + void wait(module& m, instruction_ref ins, std::size_t wait_id) const override { - private_detail_te_value.wait(p, ins, wait_id); + private_detail_te_value.wait(m, ins, wait_id); } - void record(program& p, instruction_ref ins, std::size_t wait_id) const override + void record(module& m, instruction_ref ins, std::size_t wait_id) const override { - private_detail_te_value.record(p, ins, wait_id); + private_detail_te_value.record(m, ins, wait_id); } std::size_t weight(const operation& op) const override @@ -277,6 +288,7 @@ inline const ValueType& any_cast(const schedule_model& x) throw std::bad_cast(); return *y; } +#endif #endif diff --git a/src/include/migraphx/serialize.hpp b/src/include/migraphx/serialize.hpp new file mode 100755 index 0000000000000000000000000000000000000000..66d1da2a1ce5e489e7dd5a08e0954b18209dda5b --- /dev/null +++ b/src/include/migraphx/serialize.hpp @@ -0,0 +1,210 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP +#define MIGRAPHX_GUARD_RTGLIB_SERIALIZE_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// Avoid implicit conversion with ADL lookup +template +void migraphx_to_value(value&, const T&) = delete; + +template +value to_value(const T& x); + +template +void from_value(const value& v, T& x); + +template +T from_value(const value& v) +{ + T x{}; + from_value(v, x); + return x; +} + +namespace detail { + +template {})> +value to_value_impl(rank<0>, const T&) +{ + return value::object{}; +} + +template +value to_value_impl(rank<1>, const std::pair& x) +{ + + return {x.first, x.second}; +} + +template +auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{}) +{ + value result = value::array{}; + for(auto&& y : x) + { + result.insert(to_value(y)); + } + return result; +} + +template {})> +value to_value_impl(rank<3>, const T& x) +{ + value result = value::object{}; + reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); }); + return result; +} + +template {})> +value to_value_impl(rank<4>, const T& x) +{ + return std::int64_t{x}; +} + +template {})> +value to_value_impl(rank<5>, const T& x) +{ + return std::uint64_t{x}; +} + +template {})> +value to_value_impl(rank<6>, const T& x) +{ + return double{x}; +} + +template {})> +value to_value_impl(rank<7>, const T& x) +{ + return x; +} + +inline value to_value_impl(rank<8>, const std::string& x) { return x; } + +template +auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x)) +{ + return migraphx_to_value(x); +} + +template +auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value()) +{ + return x.to_value(); +} + +template +auto to_value_impl(rank<11>, const T& x) + -> decltype(migraphx_to_value(std::declval(), x), value{}) +{ + value v; + migraphx_to_value(v, x); + return v; +} + +template {})> +void from_value_impl(rank<0>, const value& v, T& x) +{ + if(not v.is_object()) + MIGRAPHX_THROW("Expected an object"); + if(not v.get_object().empty()) + MIGRAPHX_THROW("Expected an empty object"); + x = T{}; +} + +template +auto from_value_impl(rank<1>, const value& v, T& x) + -> decltype(x.insert(x.end(), *x.begin()), void()) +{ + x.clear(); + for(auto&& e : v) + x.insert(x.end(), from_value(e)); +} + +template {})> +auto from_value_impl(rank<2>, const value& v, T& x) + -> decltype(x.insert(x.end(), *x.begin()), void()) +{ + x.clear(); + if(v.is_binary()) + { + for(auto&& e : v.get_binary()) + x.insert(x.end(), e); + } + else + { + for(auto&& e : v) + x.insert(x.end(), from_value(e)); + } +} + +template +auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) +{ + x.clear(); + for(auto&& e : v) + x.emplace(e.get_key(), from_value(e)); +} + +template {})> +void from_value_impl(rank<4>, const value& v, T& x) +{ + reflect_each(x, [&](auto& y, const std::string& name) { + using type = std::decay_t; + if(v.contains(name)) + y = from_value(v.at(name).without_key()); + }); +} + +template {})> +void from_value_impl(rank<5>, const value& v, T& x) +{ + x = v.to(); +} + +template {})> +void from_value_impl(rank<6>, const value& v, T& x) +{ + x = v.to(); +} + +inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to(); } + +template +auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void()) +{ + x.from_value(v); +} + +template +auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) +{ + migraphx_from_value(v, x); +} + +} // namespace detail + +template +value to_value(const T& x) +{ + return detail::to_value_impl(rank<11>{}, x); +} + +template +void from_value(const value& v, T& x) +{ + detail::from_value_impl(rank<9>{}, v, x); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 5785925255476553b54f0adac866f4c37d223d12..372a2efd820f745f3f5727a4b8cdf62d07ab2f3d 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -14,6 +14,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct value; struct shape_impl; struct shape @@ -22,6 +23,7 @@ struct shape // Add new types here // clang-format off #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ + m(bool_type, bool) \ m(half_type, half) \ m(float_type, float) \ m(double_type, double) \ @@ -33,12 +35,12 @@ struct shape m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ m(uint64_type, uint64_t) -// clang-format on + // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, enum type_t { - MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type }; #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES @@ -57,6 +59,11 @@ struct shape { }; + static const std::vector& types(); + + static std::string name(type_t t); + static std::string cpp_type(type_t t); + shape(); shape(type_t t); shape(type_t t, std::vector l); @@ -75,6 +82,10 @@ struct shape { } + shape(const std::vector& subs); + + static shape + from_permutation(type_t t, const std::vector& l, const std::vector& perm); type_t type() const; const std::vector& lens() const; const std::vector& strides() const; @@ -93,13 +104,14 @@ struct shape { assert(std::distance(start, last) <= this->lens().size()); assert(this->lens().size() == this->strides().size()); - return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); + return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT } /// Map element index to space index std::size_t index(std::size_t i) const; std::vector multi(std::size_t i) const; + void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; /// Returns true if the shape is packed with no padding bool packed() const; @@ -114,6 +126,13 @@ struct shape /// Returns true if all strides are equal to 0 (scalar tensor) bool scalar() const; + shape normalize_standard() const; + + shape with_lens(type_t t, const std::vector& l) const; + shape with_lens(const std::vector& l) const; + + shape with_type(type_t t) const; + friend bool operator==(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y); friend std::ostream& operator<<(std::ostream& os, const shape& x); @@ -121,50 +140,58 @@ struct shape template struct as { - using type = T; + using type = std::conditional_t{}, int8_t, T>; + + type max() const { return std::numeric_limits::max(); } + + type min() const { return std::numeric_limits::lowest(); } template - T operator()(U u) const + type operator()(U u) const { - return T(u); + return type(u); } template - T* operator()(U* u) const + type* operator()(U* u) const { - return static_cast(u); + return static_cast(u); } template - const T* operator()(const U* u) const + const type* operator()(const U* u) const { - return static_cast(u); + return static_cast(u); } - T operator()() const { return {}; } + type operator()() const { return {}; } - std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; } + std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; } template - T* from(U* buffer, std::size_t n = 0) const + type* from(U* buffer, std::size_t n = 0) const { - return reinterpret_cast(buffer) + n; + return reinterpret_cast(buffer) + n; } template - const T* from(const U* buffer, std::size_t n = 0) const + const type* from(const U* buffer, std::size_t n = 0) const { - return reinterpret_cast(buffer) + n; + return reinterpret_cast(buffer) + n; } - type_t type_enum() const { return get_type{}; } + type_t type_enum() const { return get_type{}; } }; - template - void visit_type(Visitor v) const + template + static void visit(type_t t, Visitor v, TupleVisitor tv) { - switch(this->type()) + switch(t) { + case tuple_type: { + tv(); + return; + } #define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \ case x: v(as()); return; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE) @@ -173,6 +200,18 @@ struct shape MIGRAPHX_THROW("Unknown type"); } + template + static void visit(type_t t, Visitor v) + { + return visit(t, v, [] { MIGRAPHX_THROW("Tuple cannot be visited."); }); + } + + template + void visit_type(Visitors... vs) const + { + visit(this->type(), vs...); + } + template static void visit_types(Visitor v) { @@ -181,13 +220,21 @@ struct shape #undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL } - private: - std::shared_ptr impl; + std::string type_string() const; + static type_t parse_type(const std::string& s); + + const std::vector& sub_shapes() const; std::size_t element_space() const; - std::string type_string() const; + + private: + shape(std::shared_ptr pimpl); + std::shared_ptr impl; }; +void migraphx_to_value(value& v, const shape& s); +void migraphx_from_value(const value& v, shape& s); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/simplify_algebra.hpp b/src/include/migraphx/simplify_algebra.hpp index cbd15804653e0a4fed71117a7458a90fc821e534..569e8439871ceac6e2e20dfc8b3c4a3152933bc0 100644 --- a/src/include/migraphx/simplify_algebra.hpp +++ b/src/include/migraphx/simplify_algebra.hpp @@ -7,7 +7,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Simplify many algebraic instructions to more efficient versions. @@ -15,7 +15,7 @@ struct program; struct simplify_algebra { std::string name() const { return "simplify_algebra"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/simplify_qdq.hpp b/src/include/migraphx/simplify_qdq.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fdb4fa5ac12d6e18d1e78638386938f8510ff54 --- /dev/null +++ b/src/include/migraphx/simplify_qdq.hpp @@ -0,0 +1,25 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP +#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** + * Inserts quantized operators in place of dq->quantizable_op->q + * then removes remaining fake quantization (q->dq pairs) + */ +struct simplify_qdq +{ + std::string name() const { return "simplify_qdq"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/simplify_reshapes.hpp b/src/include/migraphx/simplify_reshapes.hpp index c5c42d52a7761f7c3a8a5c0c46716b80e5584a23..369cdbccb0f59243b5ca9db5d42d797966ddf5f6 100644 --- a/src/include/migraphx/simplify_reshapes.hpp +++ b/src/include/migraphx/simplify_reshapes.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; /** * Eliminate redundant reshapes. @@ -16,7 +16,7 @@ struct program; struct simplify_reshapes { std::string name() const { return "simplify_reshapes"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/stream_model.hpp b/src/include/migraphx/stream_model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0358a9543c1a19d307077e49fdbf2dedbb49a51 --- /dev/null +++ b/src/include/migraphx/stream_model.hpp @@ -0,0 +1,312 @@ +#ifndef MIGRAPHX_GUARD_STREAM_MODEL_HPP +#define MIGRAPHX_GUARD_STREAM_MODEL_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// An interface for target-dependent model for the scheduler +struct stream_model +{ + /// Get the number of streams used in the program + std::size_t get_nstream() const; + /// Get stream for instruction + std::size_t get_stream(instruction_ref ins) const; + /// Get unique event id for instruction + std::size_t get_event_id(instruction_ref ins) const; + /// Returns true if instruction has a stream assignment + bool has_stream(instruction_ref ins) const; + /// Returns true if the instruction records the event + bool is_record(instruction_ref ins) const; + /// Returns true if the instruction wait on the event + bool is_wait(instruction_ref ins) const; +}; + +#else + +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct stream_model +{ + // + std::size_t get_nstream() const; + // + std::size_t get_stream(instruction_ref ins) const; + // + std::size_t get_event_id(instruction_ref ins) const; + // + bool has_stream(instruction_ref ins) const; + // + bool is_record(instruction_ref ins) const; + // + bool is_wait(instruction_ref ins) const; +}; + +#else + +struct stream_model +{ + // Constructors + stream_model() = default; + + template + stream_model(PrivateDetailTypeErasedT value) + : private_detail_te_handle_mem_var( + std::make_shared::type>>( + std::forward(value))) + { + } + + // Assignment + template + stream_model& operator=(PrivateDetailTypeErasedT value) + { + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + stream_model rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } + return *this; + } + + // Cast + template + PrivateDetailTypeErasedT* any_cast() + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + template + const typename std::remove_cv::type* any_cast() const + { + return this->type_id() == typeid(PrivateDetailTypeErasedT) + ? std::addressof(static_cast::type>&>( + private_detail_te_get_handle()) + .private_detail_te_value) + : nullptr; + } + + const std::type_info& type_id() const + { + if(private_detail_te_handle_empty()) + return typeid(std::nullptr_t); + else + return private_detail_te_get_handle().type(); + } + + std::size_t get_nstream() const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_nstream(); + } + + std::size_t get_stream(instruction_ref ins) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_stream(ins); + } + + std::size_t get_event_id(instruction_ref ins) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().get_event_id(ins); + } + + bool has_stream(instruction_ref ins) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().has_stream(ins); + } + + bool is_record(instruction_ref ins) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().is_record(ins); + } + + bool is_wait(instruction_ref ins) const + { + assert((*this).private_detail_te_handle_mem_var); + return (*this).private_detail_te_get_handle().is_wait(ins); + } + + friend bool is_shared(const stream_model& private_detail_x, + const stream_model& private_detail_y) + { + return private_detail_x.private_detail_te_handle_mem_var == + private_detail_y.private_detail_te_handle_mem_var; + } + + private: + struct private_detail_te_handle_base_type + { + virtual ~private_detail_te_handle_base_type() {} + virtual std::shared_ptr clone() const = 0; + virtual const std::type_info& type() const = 0; + + virtual std::size_t get_nstream() const = 0; + virtual std::size_t get_stream(instruction_ref ins) const = 0; + virtual std::size_t get_event_id(instruction_ref ins) const = 0; + virtual bool has_stream(instruction_ref ins) const = 0; + virtual bool is_record(instruction_ref ins) const = 0; + virtual bool is_wait(instruction_ref ins) const = 0; + }; + + template + struct private_detail_te_handle_type : private_detail_te_handle_base_type + { + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value>::type* = + nullptr) + : private_detail_te_value(value) + { + } + + template + private_detail_te_handle_type( + PrivateDetailTypeErasedT value, + typename std::enable_if::value, + int>::type* = nullptr) noexcept + : private_detail_te_value(std::move(value)) + { + } + + std::shared_ptr clone() const override + { + return std::make_shared(private_detail_te_value); + } + + const std::type_info& type() const override { return typeid(private_detail_te_value); } + + std::size_t get_nstream() const override { return private_detail_te_value.get_nstream(); } + + std::size_t get_stream(instruction_ref ins) const override + { + + return private_detail_te_value.get_stream(ins); + } + + std::size_t get_event_id(instruction_ref ins) const override + { + + return private_detail_te_value.get_event_id(ins); + } + + bool has_stream(instruction_ref ins) const override + { + + return private_detail_te_value.has_stream(ins); + } + + bool is_record(instruction_ref ins) const override + { + + return private_detail_te_value.is_record(ins); + } + + bool is_wait(instruction_ref ins) const override + { + + return private_detail_te_value.is_wait(ins); + } + + PrivateDetailTypeErasedT private_detail_te_value; + }; + + template + struct private_detail_te_handle_type> + : private_detail_te_handle_type + { + private_detail_te_handle_type(std::reference_wrapper ref) + : private_detail_te_handle_type(ref.get()) + { + } + }; + + bool private_detail_te_handle_empty() const + { + return private_detail_te_handle_mem_var == nullptr; + } + + const private_detail_te_handle_base_type& private_detail_te_get_handle() const + { + assert(private_detail_te_handle_mem_var != nullptr); + return *private_detail_te_handle_mem_var; + } + + private_detail_te_handle_base_type& private_detail_te_get_handle() + { + assert(private_detail_te_handle_mem_var != nullptr); + if(!private_detail_te_handle_mem_var.unique()) + private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); + return *private_detail_te_handle_mem_var; + } + + std::shared_ptr private_detail_te_handle_mem_var; +}; + +template +inline const ValueType* any_cast(const stream_model* x) +{ + return x->any_cast(); +} + +template +inline ValueType* any_cast(stream_model* x) +{ + return x->any_cast(); +} + +template +inline ValueType& any_cast(stream_model& x) +{ + auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} + +template +inline const ValueType& any_cast(const stream_model& x) +{ + const auto* y = x.any_cast::type>(); + if(y == nullptr) + throw std::bad_cast(); + return *y; +} +#endif + +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/stringutils.hpp b/src/include/migraphx/stringutils.hpp index 05ed70c07007769bfc41675470d380f4966955a8..2f0ab08f0927698a4a3923459a4718420309085c 100644 --- a/src/include/migraphx/stringutils.hpp +++ b/src/include/migraphx/stringutils.hpp @@ -5,11 +5,22 @@ #include #include #include +#include +#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__ +#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__) + +template +auto with_char(F f) +{ + return [=](unsigned char c) -> bool { return f(c); }; +} + inline std::string replace_string(std::string subject, const std::string& search, const std::string& replace) { @@ -43,17 +54,29 @@ inline std::string join_strings(Strings strings, const std::string& delim) }); } +inline std::vector split_string(const std::string& s, char delim) +{ + std::vector elems; + std::stringstream ss(s + ' '); + std::string item; + while(std::getline(ss, item, delim)) + { + elems.push_back(item); + } + return elems; +} + template std::string trim(const std::string& s, F f) { auto start = std::find_if_not(s.begin(), s.end(), f); auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base(); - return std::string(start, last); + return {start, last}; } inline std::string trim(const std::string& s) { - return trim(s, [](int c) { return std::isspace(c); }); + return trim(s, [](unsigned char c) { return std::isspace(c); }); } template @@ -83,6 +106,44 @@ inline std::string remove_prefix(std::string s, const std::string& prefix) return s; } +template +inline std::string +interpolate_string(const std::string& input, F f, std::string start = "${", std::string end = "}") +{ + std::string result = ""; + result.reserve(input.size()); + auto it = input.begin(); + while(it != input.end()) + { + auto next_start = std::search(it, input.end(), start.begin(), start.end()); + auto next_end = std::search(next_start, input.end(), end.begin(), end.end()); + result.append(it, next_start); + if(next_start == input.end()) + break; + auto r = f(next_start + start.size(), next_end); + result.append(r.begin(), r.end()); + it = next_end + end.size(); + } + return result; +} +inline std::string interpolate_string(const std::string& input, + const std::unordered_map& vars, + std::string start = "${", + std::string end = "}") +{ + return interpolate_string( + input, + [&](auto start_it, auto last_it) { + auto key = trim({start_it, last_it}); + auto it = vars.find(key); + if(it == vars.end()) + throw std::runtime_error("Unknown key: " + key); + return it->second; + }, + std::move(start), + std::move(end)); +} + template inline std::string to_string_range(Iterator start, Iterator last) { @@ -108,7 +169,8 @@ inline std::string to_string_range(const std::initializer_list& r) } template -inline std::string to_string(const T& x) +inline auto to_string(const T& x) + -> decltype((std::declval() << x), std::string{}) { std::stringstream ss; ss << x; diff --git a/src/include/migraphx/target.hpp b/src/include/migraphx/target.hpp index 8a0a7b65a776ea20dbaed63d8f60293a7629da75..7bd854192a5f333b5598ec599bf9d86d14fcd40c 100644 --- a/src/include/migraphx/target.hpp +++ b/src/include/migraphx/target.hpp @@ -82,20 +82,26 @@ argument copy_from_target(T&, const argument& arg) return arg; } -/* - * Type-erased interface for: - * - * struct target - * { - * std::string name() const; - * std::vector get_passes(context& ctx,const compile_options& options) const; - * context get_context() const; - * argument copy_to(const argument& input) const; - * argument copy_from(const argument& input) const; - * argument allocate(const shape& s) const; - * }; - * - */ +#ifdef TYPE_ERASED_DECLARATION + +// Type-erased interface for: +struct target +{ + // + std::string name() const; + // + std::vector get_passes(context& ctx, const compile_options& options) const; + // + context get_context() const; + // (optional) + argument copy_to(const argument& input) const; + // (optional) + argument copy_from(const argument& input) const; + // (optional) + argument allocate(const shape& s) const; +}; + +#else struct target { @@ -115,11 +121,17 @@ struct target template target& operator=(PrivateDetailTypeErasedT value) { - if(private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if(!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared( - std::forward(value)); + using std::swap; + auto* derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + target rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -127,7 +139,7 @@ struct target template PrivateDetailTypeErasedT* any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -138,7 +150,7 @@ struct target template const typename std::remove_cv::type* any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type>&>( private_detail_te_get_handle()) @@ -376,6 +388,7 @@ inline const ValueType& any_cast(const target& x) throw std::bad_cast(); return *y; } +#endif #endif diff --git a/src/include/migraphx/tensor_view.hpp b/src/include/migraphx/tensor_view.hpp index 27b8be7f9c34e3526a13f42d56bc877c8925f753..c7d63f7d4fdf0bc571c9fe5547abb7f4ecc16072 100644 --- a/src/include/migraphx/tensor_view.hpp +++ b/src/include/migraphx/tensor_view.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -20,10 +21,24 @@ T as_number(T x) inline int32_t as_number(int8_t x) { return static_cast(x); } inline uint32_t as_number(uint8_t x) { return static_cast(x); } +template +struct tensor_view_iterator_read +{ + T* view; + auto& operator()(std::size_t n) const + { + assert(view != nullptr); + return (*view)[n]; + } +}; + template struct tensor_view { using value_type = T; + using iterator = basic_iota_iterator>, std::size_t>; + using const_iterator = + basic_iota_iterator>, std::size_t>; tensor_view() : m_data(nullptr) {} tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {} @@ -56,12 +71,16 @@ struct tensor_view template {})> const T& operator()(Iterator start, Iterator last) const { + assert(std::distance(start, last) > 0); + assert(std::all_of(start, last, [](auto x) { return x >= 0; })); return m_data[m_shape.index(start, last)]; } template {})> T& operator()(Iterator start, Iterator last) { + assert(std::distance(start, last) > 0); + assert(std::all_of(start, last, [](auto x) { return x >= 0; })); return m_data[m_shape.index(start, last)]; } @@ -101,36 +120,13 @@ struct tensor_view return m_data[m_shape.index(this->size() - 1)]; } - // TODO: Add iterators so it can handle nonstandard tensors - T* begin() - { - assert(this->m_shape.standard() or this->empty()); - return m_data; - } + iterator begin() { return {0, {this}}; } - T* end() - { - assert(this->m_shape.standard() or this->empty()); - if(this->empty()) - return m_data; - else - return m_data + this->size(); - } + iterator end() { return {this->size(), {this}}; } - const T* begin() const - { - assert(this->m_shape.standard() or this->empty()); - return m_data; - } + const_iterator begin() const { return {0, {this}}; } - const T* end() const - { - assert(this->m_shape.standard() or this->empty()); - if(this->empty()) - return m_data; - else - return m_data + this->size(); - } + const_iterator end() const { return {this->size(), {this}}; } template std::vector to_vector() const diff --git a/src/include/migraphx/tf.hpp b/src/include/migraphx/tf.hpp index b74057e0751b355d58467b98146b7fb88e90fccb..0fb07176a9002e96bc7c473df3ceb7cfd5599dda 100644 --- a/src/include/migraphx/tf.hpp +++ b/src/include/migraphx/tf.hpp @@ -7,8 +7,18 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +/// struct to pass in tf options to parser +struct tf_options +{ + bool is_nhwc = false; + unsigned int batch_size = 1; + /// Explicitly specify the dims of an input + std::unordered_map> map_input_dims = {}; + std::vector output_node_names = {}; +}; + /// Create a program from a tf pb file (default is nhwc format) -program parse_tf(const std::string& name, bool is_nhwc); +program parse_tf(const std::string& name, const tf_options& options = tf_options{}); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/time.hpp b/src/include/migraphx/time.hpp index 637253c2a14ed73d249e96151aa84d896e6fb4ef..4769e3dac92d306980461b3534ecb870f1dff916 100644 --- a/src/include/migraphx/time.hpp +++ b/src/include/migraphx/time.hpp @@ -7,13 +7,23 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct timer +{ + std::chrono::time_point start = std::chrono::steady_clock::now(); + template + auto record() const + { + auto finish = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(finish - start).count(); + } +}; + template auto time(F f) { - auto start = std::chrono::steady_clock::now(); + timer t{}; f(); - auto finish = std::chrono::steady_clock::now(); - return std::chrono::duration_cast(finish - start).count(); + return t.record(); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/include/migraphx/tmp_dir.hpp b/src/include/migraphx/tmp_dir.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4e43b5f6b80a9217c27d071fed855d3f76ff32de --- /dev/null +++ b/src/include/migraphx/tmp_dir.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_TMP_DIR_HPP +#define MIGRAPHX_GUARD_RTGLIB_TMP_DIR_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct tmp_dir +{ + fs::path path; + tmp_dir(const std::string& prefix = ""); + + void execute(const std::string& exe, const std::string& args) const; + + tmp_dir(tmp_dir const&) = delete; + tmp_dir& operator=(tmp_dir const&) = delete; + + ~tmp_dir(); +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/tune_axis.hpp b/src/include/migraphx/tune_axis.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21e59cf71a776e6aeb131945c9c4a825f656fcf7 --- /dev/null +++ b/src/include/migraphx/tune_axis.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP +#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR") +{ + if(axis >= n_dim || std::abs(axis) > n_dim) + { + MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range."); + } + return (axis < 0) ? axis + n_dim : axis; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/type_name.hpp b/src/include/migraphx/type_name.hpp index e170deb05e36b53a352017323ab68d38fa527585..5e31437c1a106a87fa992d093918c1935692a0a8 100644 --- a/src/include/migraphx/type_name.hpp +++ b/src/include/migraphx/type_name.hpp @@ -8,30 +8,32 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { template -const std::string& get_type_name() +std::string compute_type_name() { - static std::string name; - - if(name.empty()) - { + std::string name; #ifdef _MSC_VER - name = typeid(PrivateMigraphTypeNameProbe).name(); - name = name.substr(7); + name = typeid(PrivateMigraphTypeNameProbe).name(); + name = name.substr(7); #else - const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT + const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT - name = __PRETTY_FUNCTION__; + name = __PRETTY_FUNCTION__; - auto begin = name.find(parameter_name) + sizeof(parameter_name); + auto begin = name.find(parameter_name) + sizeof(parameter_name); #if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7) - auto length = name.find_last_of(",") - begin; + auto length = name.find_last_of(",") - begin; #else - auto length = name.find_first_of("];", begin) - begin; + auto length = name.find_first_of("];", begin) - begin; #endif - name = name.substr(begin, length); + name = name.substr(begin, length); #endif - } + return name; +} +template +const std::string& get_type_name() +{ + static const std::string name = compute_type_name(); return name; } diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index c5b789c6eaf731096c6118503aa8e9331f8160b6..618ccd446345ca495d4a46f602ac220ad4814d36 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -30,6 +30,12 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) +template +using accumulator_type = + std::conditional_t{}, + double, + std::conditional_t{}, std::int64_t, std::uint64_t>>; + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/include/migraphx/value.hpp b/src/include/migraphx/value.hpp new file mode 100644 index 0000000000000000000000000000000000000000..27051d1b49aa28c86ff92cbe83861da7678561da --- /dev/null +++ b/src/include/migraphx/value.hpp @@ -0,0 +1,455 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_VALUE_HPP +#define MIGRAPHX_GUARD_RTGLIB_VALUE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct value_base_impl; + +template +struct value_converter +{ + template + static auto apply(const std::string& x) + -> decltype((std::declval() >> std::declval()), To{}) + { + To result; + std::stringstream ss; + ss.str(x); + ss >> result; + if(ss.fail()) + throw std::runtime_error("Failed to parse: " + x); + return result; + } + + template {})> + static To apply(const From& x) + { + return To(x); + } +}; + +template +struct value_converter{})> +{ + template + static auto apply(const From& x) + -> decltype(static_cast(value_converter>::apply(x))) + { + return static_cast(value_converter>::apply(x)); + } +}; + +template <> +struct value_converter +{ + static const std::string& apply(const std::string& x) { return x; } + + template + static auto apply(const From& x) + -> decltype(std::declval() << x, std::string()) + { + std::stringstream ss; + ss << x; + if(ss.fail()) + throw std::runtime_error("Failed to parse"); + return ss.str(); + } +}; + +template +struct value_converter> +{ + template + static auto apply(const std::pair& x) + -> decltype(std::pair(x.first, value_converter::apply(x.second))) + { + return std::pair(x.first, value_converter::apply(x.second)); + } +}; + +template +To try_convert_value(const From& x); + +namespace detail { +template +To try_convert_value_impl(rank<1>, const std::pair& x) +{ + return try_convert_value(x.second); +} + +template +auto try_convert_value_impl(rank<2>, const From& x) -> decltype(value_converter::apply(x)) +{ + return value_converter::apply(x); +} + +template {})> +To try_convert_value_impl(rank<3>, std::nullptr_t) +{ + MIGRAPHX_THROW("Incompatible values: null -> " + get_type_name()); +} + +template +To try_convert_value_impl(rank<0>, const From& x) +{ + MIGRAPHX_THROW("Incompatible values: " + get_type_name(x) + " -> " + get_type_name()); +} +} // namespace detail + +template +To try_convert_value(const From& x) +{ + return detail::try_convert_value_impl(rank<3>{}, x); +} + +struct value +{ +// clang-format off +#define MIGRAPHX_VISIT_VALUE_TYPES(m) \ + m(int64, std::int64_t) \ + m(uint64, std::uint64_t) \ + m(float, double) \ + m(string, std::string) \ + m(bool, bool) \ + m(binary, value::binary) + // clang-format on + enum type_t + { +#define MIGRAPHX_VALUE_GENERATE_ENUM_TYPE(vt, cpp_type) vt##_type, + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_ENUM_TYPE) object_type, + array_type, + null_type +#undef MIGRAPHX_VALUE_GENERATE_ENUM_TYPE + }; + using iterator = value*; + using const_iterator = const value*; + using value_type = value; + using key_type = std::string; + using mapped_type = value; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using array = std::vector; + using object = std::unordered_map; + struct binary : std::vector + { + using base = std::vector; + binary() {} + template ().begin()) == 1)> + explicit binary(const Container& c) : base(c.begin(), c.end()) + { + } + template + binary(T* data, std::size_t s) : base(data, data + s) + { + } + explicit binary(std::size_t s) : base(s) {} + }; + + value() = default; + + value(const value& rhs); + value& operator=(value rhs); + value(const std::string& pkey, const value& rhs); + + value(const std::initializer_list& i); + value(const std::vector& v, bool array_on_empty = true); + value(const std::unordered_map& m); + value(const std::string& pkey, const std::vector& v, bool array_on_empty = true); + value(const std::string& pkey, const std::unordered_map& m); + value(const std::string& pkey, std::nullptr_t); + value(std::nullptr_t); + + value(const char* i); + value(const std::string& pkey, const char* i); + +#define MIGRAPHX_VALUE_GENERATE_DECL_METHODS(vt, cpp_type) \ + value(cpp_type i); \ + value(const std::string& pkey, cpp_type i); \ + value& operator=(cpp_type rhs); \ + bool is_##vt() const; \ + const cpp_type& get_##vt() const; \ + const cpp_type* if_##vt() const; + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DECL_METHODS) + + template + using literal_to_string = std::conditional_t<(std::is_convertible{} and + std::is_convertible{}), + std::string, + T>; + + template + using pick_numeric = std::conditional_t< + std::is_floating_point{}, + double, + std::conditional_t{}, + std::int64_t, + std::conditional_t{}, std::uint64_t, T>>>; + + template + using pick = pick_numeric{}, + std::underlying_type, + std::enable_if>::type>; + + template + using is_pickable = + bool_c<((std::is_arithmetic{} or std::is_enum{}) and not std::is_pointer{})>; + + template + using range_value = std::decay_t().end(), *std::declval().begin())>; + + template + using is_generic_range = + bool_c<(std::is_convertible, value>{} and + not std::is_convertible{} and not std::is_convertible{})>; + + template {})> + value(const T& r) : value(from_values(r)) + { + } + + template {})> + value(const std::string& pkey, const T& r) : value(pkey, from_values(r)) + { + } + + template {})> + value(T i) : value(static_cast>(i)) + { + } + template {})> + value(const std::string& pkey, T i) : value(pkey, static_cast>(i)) + { + } + template + value(const std::pair& p) : value(p.first, p.second) + { + } + template {})> + value& operator=(T rhs) + { + return *this = static_cast>(rhs); // NOLINT + } + template {})> + value& operator=(T rhs) + { + return *this = from_values(rhs); // NOLINT + } + + value& operator=(const char* c); + value& operator=(std::nullptr_t); + value& operator=(const std::initializer_list& i); + + bool is_array() const; + const std::vector& get_array() const; + const std::vector* if_array() const; + + bool is_object() const; + const std::vector& get_object() const; + const std::vector* if_object() const; + + bool is_null() const; + + const std::string& get_key() const; + value* find(const std::string& pkey); + const value* find(const std::string& pkey) const; + bool contains(const std::string& pkey) const; + std::size_t size() const; + bool empty() const; + const value* data() const; + value* data(); + value* begin(); + const value* begin() const; + value* end(); + const value* end() const; + + value& front(); + const value& front() const; + value& back(); + const value& back() const; + value& at(std::size_t i); + const value& at(std::size_t i) const; + value& at(const std::string& pkey); + const value& at(const std::string& pkey) const; + value& operator[](std::size_t i); + const value& operator[](std::size_t i) const; + value& operator[](const std::string& pkey); + + void clear(); + void resize(std::size_t n); + void resize(std::size_t n, const value& v); + + std::pair insert(const value& v); + value* insert(const value* pos, const value& v); + + template + std::pair emplace(Ts&&... xs) + { + return insert(value(std::forward(xs)...)); + } + + template + value* emplace(const value* pos, Ts&&... xs) + { + return insert(pos, value(std::forward(xs)...)); + } + + void push_back(const value& v) { insert(end(), v); } + + void push_front(const value& v) { insert(begin(), v); } + + value with_key(const std::string& pkey) const; + value without_key() const; + + template + void visit(Visitor v) const + { + switch(this->get_type()) + { + case null_type: { + std::nullptr_t null{}; + if(this->key.empty()) + v(null); + else + v(std::make_pair(this->get_key(), std::ref(null))); + return; + } +#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \ + case vt##_type: { \ + if(this->key.empty()) \ + v(this->get_##vt()); \ + else \ + v(std::make_pair(this->get_key(), std::ref(this->get_##vt()))); \ + return; \ + } + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE) + MIGRAPHX_VALUE_GENERATE_CASE(array, ) + MIGRAPHX_VALUE_GENERATE_CASE(object, ) + } + MIGRAPHX_THROW("Unknown type"); + } + + // Visit value without key + template + void visit_value(Visitor v) const + { + switch(this->get_type()) + { + case null_type: { + std::nullptr_t null{}; + v(null); + return; + } +#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \ + case vt##_type: { \ + v(this->get_##vt()); \ + return; \ + } + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE) + MIGRAPHX_VALUE_GENERATE_CASE(array, ) + MIGRAPHX_VALUE_GENERATE_CASE(object, ) + } + MIGRAPHX_THROW("Unknown type"); + } + + template + To to() const + { + To result; + this->visit([&](auto y) { result = try_convert_value(y); }); + return result; + } + + template + literal_to_string value_or(const To& default_value) const + { + if(this->is_null()) + return default_value; + return to>(); + } + + template + std::vector to_vector() const + { + std::vector result; + const auto& values = is_object() ? get_object() : get_array(); + result.reserve(values.size()); + std::transform(values.begin(), values.end(), std::back_inserter(result), [&](auto v) { + return v.template to(); + }); + return result; + } + + template + literal_to_string get(const std::string& pkey, const To& default_value) const + { + const auto* v = find(pkey); + if(v == this->end()) + return default_value; + return v->to>(); + } + + template + std::vector get(const std::string& pkey, const std::vector& default_value) const + { + const auto* v = find(pkey); + if(v == this->end()) + return default_value; + return v->to_vector(); + } + + template + std::vector> get(const std::string& pkey, + const std::initializer_list& default_value) const + { + return get(pkey, + std::vector>{default_value.begin(), default_value.end()}); + } + + friend bool operator==(const value& x, const value& y); + friend bool operator!=(const value& x, const value& y); + friend bool operator<(const value& x, const value& y); + friend bool operator<=(const value& x, const value& y); + friend bool operator>(const value& x, const value& y); + friend bool operator>=(const value& x, const value& y); + + friend std::ostream& operator<<(std::ostream& os, const value& d); + + void debug_print(bool show_type = false) const; + + type_t get_type() const; + + private: + template + std::vector from_values(const T& r) + { + std::vector v; + std::transform( + r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); }); + return v; + } + std::shared_ptr x; + std::string key; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/verify.hpp b/src/include/migraphx/verify.hpp old mode 100644 new mode 100755 index 9266be165b208d298fc6def1cce5fc18b65a7018..5d508f822a674990a3234c3af91cb11a2bc1120c --- a/src/include/migraphx/verify.hpp +++ b/src/include/migraphx/verify.hpp @@ -147,7 +147,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff) } template -double rms_range(R1&& r1, R2&& r2) +double rms_range(const R1& r1, const R2& r2) { std::size_t n = range_distance(r1); if(n == range_distance(r2)) @@ -164,11 +164,10 @@ double rms_range(R1&& r1, R2&& r2) } template -bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr) +bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) { double threshold = std::numeric_limits>::epsilon() * tolerance; auto error = rms_range(r1, r2); - // cppcheck-suppress uninitvar if(out_error != nullptr) *out_error = error; return error <= threshold; diff --git a/src/include/migraphx/verify_args.hpp b/src/include/migraphx/verify_args.hpp index 3187e4e3b605cc25c38719eb984fd0ee02e726e3..4c3a96fb5e70435d5a81c90608b93a3fc112a8c7 100644 --- a/src/include/migraphx/verify_args.hpp +++ b/src/include/migraphx/verify_args.hpp @@ -8,81 +8,10 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -inline bool verify_args(const std::string& name, - const argument& cpu_arg, - const argument& gpu_arg, - double tolerance = 80) -{ - bool passed = true; - visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) { - double error; - passed = verify_range(cpu, gpu, tolerance, &error); - if(not passed) - { - // TODO: Check for nans - std::cout << "FAILED: " << name << std::endl; - std::cout << "error: " << error << std::endl; - if(cpu.size() < 32) - std::cout << "cpu:" << cpu << std::endl; - if(gpu.size() < 32) - std::cout << "gpu:" << gpu << std::endl; - if(range_zero(cpu)) - std::cout << "Cpu data is all zeros" << std::endl; - if(range_zero(gpu)) - std::cout << "Gpu data is all zeros" << std::endl; - - auto mxdiff = max_diff(cpu, gpu); - std::cout << "Max diff: " << mxdiff << std::endl; - - auto idx = mismatch_idx(cpu, gpu, float_equal); - if(idx < range_distance(cpu)) - { - std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx] - << std::endl; - } - - auto cpu_nan_idx = find_idx(cpu, not_finite); - if(cpu_nan_idx >= 0) - std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": " - << cpu[cpu_nan_idx] << std::endl; - - auto gpu_nan_idx = find_idx(gpu, not_finite); - if(gpu_nan_idx >= 0) - std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": " - << gpu[gpu_nan_idx] << std::endl; - std::cout << std::endl; - } - else - { - if(range_zero(cpu)) - std::cout << "Cpu data is all zeros" << std::endl; - if(range_zero(gpu)) - std::cout << "Gpu data is all zeros" << std::endl; - - // auto mxdiff = max_diff(cpu, gpu); - // std::cout << "Max diff: " << mxdiff << std::endl; - - // auto idx = mismatch_idx(cpu, gpu, float_equal); - // if(idx < range_distance(cpu)) - // { - // std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx] - // << std::endl; - // } - - auto cpu_nan_idx = find_idx(cpu, not_finite); - if(cpu_nan_idx >= 0) - std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": " - << cpu[cpu_nan_idx] << std::endl; - - auto gpu_nan_idx = find_idx(gpu, not_finite); - if(gpu_nan_idx >= 0) - std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": " - << gpu[gpu_nan_idx] << std::endl; - // std::cout << std::endl; - } - }); - return passed; -} +bool verify_args(const std::string& name, + const argument& ref_arg, + const argument& target_arg, + double tolerance = 80); } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/inline_module.cpp b/src/inline_module.cpp new file mode 100755 index 0000000000000000000000000000000000000000..01c88f6447b11353cd3fe2dd4d4bc802346150aa --- /dev/null +++ b/src/inline_module.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +static void inline_submodule(module& m, instruction_ref ins, bool cond) +{ + const auto& mod_inputs = ins->module_inputs(); + module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1); + auto mod_outputs = m.insert_module_instructions(ins, smod); + + auto ins_outputs = ins->outputs(); + assert(mod_outputs.size() >= ins_outputs.size()); + for(const auto& out : ins_outputs) + { + auto val = out->get_operator().to_value(); + assert(val.contains("index")); + auto index = val.at("index").to(); + m.replace_instruction(out, mod_outputs.at(index)); + } +} + +void inline_module::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + if(ins->name() != "if") + continue; + + auto arg_cond = ins->inputs().front()->eval(); + if(not arg_cond.empty()) + { + bool cond = arg_cond.at(); + inline_submodule(m, ins, cond); + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/insert_pad.cpp b/src/insert_pad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e37a5e4b403659d312b1c3ee4a3075fd87402179 --- /dev/null +++ b/src/insert_pad.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m) +{ + auto op = ins->get_operator(); + auto val = op.to_value(); + auto op_padding = val.at("padding").to_vector(); + + auto kdims = input->get_shape().lens().size() - 2; + if(std::equal(op_padding.begin(), + op_padding.begin() + kdims, + op_padding.begin() + kdims, + op_padding.end())) + return; + + std::vector padding(input->get_shape().lens().size() * 2, 0); + std::vector pads_l(op_padding.begin(), op_padding.begin() + kdims); + std::vector pads_r(op_padding.begin() + kdims, op_padding.end()); + op_padding = std::vector(kdims * 2, 0); + op.from_value({{"padding", op_padding}}); + + std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2); + std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2); + + auto pad_op = m.insert_instruction(ins, op::pad{padding}, input); + + auto new_inputs = ins->inputs(); + new_inputs.front() = pad_op; + + m.replace_instruction(ins, op, new_inputs); +} + +static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) +{ + auto op = any_cast(ins->get_operator()); + if(op.mode == op::pooling_mode::average) + { + return; + } + auto kdims = input->get_shape().lens().size() - 2; + if(std::equal(op.padding.begin(), + op.padding.begin() + kdims, + op.padding.begin() + kdims, + op.padding.end())) + return; + + std::vector padding(input->get_shape().lens().size() * 2, 0); + std::vector pads_l(op.padding.begin(), op.padding.begin() + kdims); + std::vector pads_r(op.padding.begin() + kdims, op.padding.end()); + op.padding = std::vector(kdims * 2, 0); + std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2); + std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2); + + // maxpool uses lowest value for padding + float pad_val = std::numeric_limits::lowest(); + auto pad_op = m.insert_instruction(ins, op::pad{padding, pad_val}, input); + + auto new_inputs = ins->inputs(); + new_inputs.front() = pad_op; + + m.replace_instruction(ins, op, new_inputs); +} + +void insert_pad::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + const std::string& op_name = ins->name(); + if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling") + continue; + auto input = ins->inputs().front(); + if(op_name == "convolution" or op_name == "im2col") + update_op(input, ins, m); + else if(op_name == "pooling") + update_pooling(input, ins, m); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/instruction.cpp b/src/instruction.cpp index d515957317eca8a59fa86fe367550088564f1e88..eda5a4794c1470635752cfba0a3847a7ec9ddd7c 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -1,15 +1,34 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +template +auto equal_to(const T& x) +{ + return [&](const T& y) { return std::equal_to{}(x, y); }; +} + instruction::instruction(operation o, shape r, std::vector args) : op(std::move(o)), result(std::move(r)), arguments(std::move(args)) { } +instruction::instruction(operation o, + shape r, + std::vector args, + std::vector modules) + : op(std::move(o)), + result(std::move(r)), + arguments(std::move(args)), + module_args(std::move(modules)) +{ +} + instruction::instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) { @@ -22,6 +41,9 @@ void instruction::replace(const shape& r) result = r; for(auto&& ins : output) { + if(ins->name() == "@return") + continue; + assert(ins->name().front() != '@'); ins->recompute_shape(); } @@ -30,11 +52,12 @@ void instruction::replace(const shape& r) void instruction::replace(operation o) { - op = std::move(o); + normalized = false; + op = std::move(o); recompute_shape(); } -void instruction::recompute_shape() { replace(compute_shape(op, arguments)); } +void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); } void instruction::clear_arguments() { @@ -43,6 +66,7 @@ void instruction::clear_arguments() arg->remove_output(*this); } arguments.clear(); + module_args.clear(); } bool operator==(const instruction& i, instruction_ref ref) @@ -50,12 +74,17 @@ bool operator==(const instruction& i, instruction_ref ref) return std::addressof(i) == std::addressof(*ref); } -bool instruction::valid(instruction_ref start) const +bool instruction::valid(instruction_ref start, bool check_order) const { return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) { auto self = std::find(i->outputs().begin(), i->outputs().end(), *this); - return self != i->outputs().end() && - std::distance(start, i) < std::distance(start, *self); + bool ret = self != i->outputs().end(); + if(check_order) + { + // check arguments for this instruction before this instruction + ret = ret and (std::distance(start, i) < std::distance(start, *self)); + } + return ret; }); } @@ -70,18 +99,24 @@ bool instruction::valid() const { computed = result; } + else if(op.name() == "@return") + { + computed = {}; + } else { try { - computed = compute_shape(op, arguments); + computed = compute_shape(op, arguments, module_args); } catch(migraphx::exception&) { return false; } } - return result == computed && std::all_of(output.begin(), output.end(), [&](instruction_ref i) { + + return (result == computed) && + std::all_of(output.begin(), output.end(), [&](instruction_ref i) { return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end(); }); } @@ -99,11 +134,19 @@ std::string instruction::name() const { return op.name(); } const std::vector& instruction::inputs() const { return arguments; } +const std::vector& instruction::module_inputs() const { return module_args; } + const std::vector& instruction::outputs() const { return output; } bool operator==(const instruction& x, const instruction& y) { - if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments)) + if(not std::equal(x.arguments.begin(), + x.arguments.end(), + y.arguments.begin(), + y.arguments.end(), + std::equal_to{})) + return false; + if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args)) return false; if(x.name() == "@literal") return x.lit == y.lit; @@ -120,7 +163,7 @@ bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); void instruction::add_output(instruction_ref ins) { - if(std::find(output.begin(), output.end(), ins) == output.end()) + if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end()) output.push_back(ins); } @@ -139,6 +182,13 @@ void instruction::replace_argument(instruction_ref ins, ins->recompute_shape(); } +void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod) +{ + ins->replace_mod_argument(old, new_mod); + backreference(ins); + ins->recompute_shape(); +} + void instruction::replace(instruction_ref ins, operation o, const shape& r, @@ -148,26 +198,87 @@ void instruction::replace(instruction_ref ins, backreference(ins); } +void instruction::replace(instruction_ref ins, + operation o, + const shape& r, + std::vector args, + std::vector module_args) +{ + ins->replace(std::move(o), r, std::move(args), std::move(module_args)); + backreference(ins); +} + void instruction::replace(operation o, const shape& r, std::vector args) { - op = std::move(o); + normalized = false; + op = std::move(o); replace(r); replace(std::move(args)); } +void instruction::replace(operation o, + const shape& r, + std::vector args, + std::vector mdl_args) +{ + op = std::move(o); + replace(r); + replace(std::move(args), std::move(mdl_args)); +} + +void instruction::replace_refs( + instruction_ref ins, + const std::unordered_map& map_insts, + const std::unordered_map& map_mods) +{ + const auto& args = ins->inputs(); + for(const auto& arg : args) + { + if(contains(map_insts, arg)) + { + instruction::replace_argument(ins, arg, map_insts.at(arg)); + } + } + + const auto& module_args = ins->module_inputs(); + if(module_args.empty()) + return; + + for(const auto& mod : module_args) + { + if(contains(map_mods, mod)) + { + instruction::replace_mod_argument(ins, mod, map_mods.at(mod)); + } + } +} + void instruction::replace(std::vector args) { clear_arguments(); arguments = std::move(args); } +void instruction::replace(std::vector args, std::vector mdl_args) +{ + clear_arguments(); + arguments = std::move(args); + module_args = std::move(mdl_args); +} + void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) { - assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; })); - std::replace(arguments.begin(), arguments.end(), old, new_ins); + assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old))); + std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins); old->remove_output(*this); } +void instruction::replace_mod_argument(module_ref old, module_ref new_mod) +{ + assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; })); + std::replace(module_args.begin(), module_args.end(), old, new_mod); +} + bool instruction::can_eval() const { if(op.name() == "@literal") @@ -200,7 +311,7 @@ argument instruction::eval(bool check_eval) const this->inputs().end(), std::back_inserter(args), [](auto arg) { return arg->eval(false); }); - return op.compute(result, args); + return normalized_operator().compute(result, args); } return {}; } @@ -211,6 +322,82 @@ void instruction::finalize(context& ctx) this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs())); } +void instruction::print(std::ostream& os, + instruction_ref ins, + const std::unordered_map& names) +{ + os << names.at(ins) << " = "; + + os << ins->get_operator(); + + if(ins->name() == "@literal") + { + if(ins->get_literal().get_shape().elements() > 10) + os << "{ ... }"; + else + os << "{" << ins->get_literal() << "}"; + } + + if(!ins->inputs().empty()) + { + char delim = '('; + for(auto&& arg : ins->inputs()) + { + std::string arg_name = contains(names, arg) ? names.at(arg) : "?"; + os << delim << arg_name; + delim = ','; + } + os << ")"; + } + + // print module inputs + if(!ins->module_inputs().empty()) + { + std::string delim = ", ["; + for(auto&& mod_arg : ins->module_inputs()) + { + os << delim << mod_arg->name(); + delim = ", "; + } + os << "]"; + } + + // skip return instruction shape + if(ins->name() != "@return") + os << " -> " << ins->get_shape(); +} + +static void debug_name(std::ostream& os, const instruction& ins) +{ + if(ins.name() == "@literal") + { + os << "@literal"; + if(ins.get_literal().get_shape().elements() > 10) + os << "{ ... }"; + else + os << "{" << ins.get_literal() << "}"; + } + else + { + os << ins.get_operator(); + } +} + +void instruction::debug_print() const +{ + debug_name(std::cout, *this); + std::string delim = "("; + for(auto arg : this->inputs()) + { + std::cout << delim; + debug_name(std::cout, *arg); + delim = ", "; + } + if(not this->inputs().empty()) + std::cout << ")"; + std::cout << " -> " << this->get_shape() << std::endl; +} + instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) { auto i = ins->get_operator().output_alias(to_shapes(ins->inputs())); @@ -221,6 +408,27 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) return get_output_alias(ins->inputs().at(i)); } +void instruction::set_normalized(bool value) { normalized = value; } + +bool instruction::is_normalized() const { return normalized; } + +bool instruction::need_normalization() const +{ + return this->get_operator().need_normalization() and not normalized; +} + +operation instruction::normalized_operator() const +{ + operation o = this->get_operator(); + if(this->need_normalization()) + { + auto lens = this->inputs().front()->get_shape().lens(); + if(!normalize_attributes(o, lens)) + return this->get_operator(); + } + return o; +} + std::vector to_shapes(const std::vector& args) { std::vector shapes(args.size()); @@ -234,5 +442,38 @@ shape compute_shape(const operation& op, const std::vector& arg return op.compute_shape(to_shapes(args)); } +shape compute_shape(const operation& op, + const std::vector& args, + const std::vector& mods) +{ + if(mods.empty()) + { + return op.compute_shape(to_shapes(args)); + } + else + { + return op.compute_shape(to_shapes(args), mods); + } +} + +std::vector try_compute_shape(const operation& op, const std::vector& inputs) +{ + shape new_shape; + try + { + new_shape = op.compute_shape(inputs); + } + catch(...) + { + return {}; + } + return {new_shape}; +} + +migraphx::instruction* as_address(const instruction_ref& ins) noexcept +{ + return std::addressof(*ins); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/json.cpp b/src/json.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2f2bf3fd22135c632bc1dc3beac0417de70c949 --- /dev/null +++ b/src/json.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using json = nlohmann::json; + +void value_to_json(const value& val, json& j); +migraphx::value value_from_json(const json& j); + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +namespace nlohmann { +template <> +struct adl_serializer +{ + static void to_json(json& j, const migraphx::value& val) { migraphx::value_to_json(val, j); } + + static void from_json(const json& j, migraphx::value& val) + { + val = migraphx::value_from_json(j); + } +}; +} // namespace nlohmann + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using json = nlohmann::json; + +template +void value_to_json(const T& x, json& j) +{ + j = x; +} + +void value_to_json(const value::binary& x, json& j) +{ + j = json::object(); + j["bytes"] = std::vector(x.begin(), x.end()); +} + +void value_to_json(const std::vector& x, json& j) +{ + for(const auto& v : x) + { + if(v.get_key().empty()) + { + j.push_back(v); + } + else + { + j[v.get_key()] = v.without_key(); + } + } +} + +void value_to_json(std::nullptr_t&, json& j) { j = {}; } + +void value_to_json(const value& val, json& j) +{ + if(val.is_array()) + { + j = json::array(); + } + + if(val.is_object()) + { + j = json::object(); + } + + val.visit([&](auto v) { value_to_json(v, j); }); +} + +migraphx::value value_from_json(const json& j) +{ + migraphx::value val; + json::value_t type = j.type(); + switch(type) + { + case json::value_t::null: val = nullptr; break; + + case json::value_t::boolean: val = j.get(); break; + + case json::value_t::number_float: val = j.get(); break; + + case json::value_t::number_integer: val = j.get(); break; + + case json::value_t::number_unsigned: val = j.get(); break; + + case json::value_t::string: val = j.get(); break; + + case json::value_t::array: + val = migraphx::value::array{}; + std::transform(j.begin(), j.end(), std::back_inserter(val), [&](const json& jj) { + return jj.get(); + }); + break; + + case json::value_t::object: + if(j.contains("bytes") and j.size() == 1) + { + val = migraphx::value::binary{j["bytes"].get>()}; + } + else + { + val = migraphx::value::object{}; + for(const auto& item : j.items()) + { + const auto& key = item.key(); + const json& jv = item.value(); + val[key] = jv.get(); + } + } + break; + + case json::value_t::binary: MIGRAPHX_THROW("Convert JSON to Value: binary type not supported!"); + case json::value_t::discarded: + MIGRAPHX_THROW("Convert JSON to Value: discarded type not supported!"); + } + + return val; +} + +std::string to_json_string(const value& val) +{ + json j = val; + return j.dump(); +} + +std::string to_pretty_json_string(const value& val, std::size_t indent) +{ + json j = val; + return j.dump(indent); +} + +migraphx::value from_json_string(const char* str, std::size_t size) +{ + json j = json::parse(str, str + size); + return j.get(); +} +migraphx::value from_json_string(const std::string& str) +{ + json j = json::parse(str); + return j.get(); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/load_save.cpp b/src/load_save.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8dc6117317d80fbfaf76b9731e87f97966365d4 --- /dev/null +++ b/src/load_save.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +program load(const std::string& filename, const file_options& options) +{ + return load_buffer(read_buffer(filename), options); +} +program load_buffer(const std::vector& buffer, const file_options& options) +{ + return load_buffer(buffer.data(), buffer.size(), options); +} +program load_buffer(const char* buffer, std::size_t size, const file_options& options) +{ + program p; + if(options.format == "msgpack") + { + p.from_value(from_msgpack(buffer, size)); + } + else if(options.format == "json") + { + p.from_value(from_json_string(buffer, size)); + } + else + { + MIGRAPHX_THROW("Unknown format: " + options.format); + } + return p; +} + +void save(const program& p, const std::string& filename, const file_options& options) +{ + write_buffer(filename, save_buffer(p, options)); +} +std::vector save_buffer(const program& p, const file_options& options) +{ + value v = p.to_value(); + std::vector buffer; + if(options.format == "msgpack") + { + buffer = to_msgpack(v); + } + else if(options.format == "json") + { + std::string s = to_json_string(v); + buffer = std::vector(s.begin(), s.end()); + } + else + { + MIGRAPHX_THROW("Unknown format: " + options.format); + } + return buffer; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/make_op.cpp b/src/make_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5f0bc75a7b8aaf68341ba9d64025d1a6d4dfb41 --- /dev/null +++ b/src/make_op.cpp @@ -0,0 +1,45 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +operation make_op(const std::string& name) { return load_op(name); } + +template +operation make_op_generic(const std::string& name, F for_each) +{ + auto op = load_op(name); + // Merge values + value w = op.to_value(); + for_each([&](const auto& key, const auto& x) { + if(not w.contains(key)) + // NOLINTNEXTLINE(performance-inefficient-string-concatenation) + MIGRAPHX_THROW("No key '" + key + "' in " + name); + w.at(key) = x; + }); + op.from_value(w); + return op; +} + +operation make_op(const std::string& name, + const std::initializer_list>& v) +{ + return make_op_generic(name, [&](auto f) { + for(auto&& [key, x] : v) + f(key, x); + }); +} + +operation make_op_from_value(const std::string& name, const value& v) +{ + if(not(v.is_object() or (v.empty() and v.is_array()))) + MIGRAPHX_THROW("Value is not an object for make_op: " + name); + return make_op_generic(name, [&](auto f) { + for(auto&& x : v) + f(x.get_key(), x.without_key()); + }); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/module.cpp b/src/module.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d12d4ef6d528d47dc1269dd4c860579011c7b399 --- /dev/null +++ b/src/module.cpp @@ -0,0 +1,895 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE) + +struct module_impl +{ + // A list is used to keep references to an instruction stable + std::list instructions; + std::unordered_set instruction_set; + std::string name; + uint32_t nparams = 0; + bool bypass = false; + + bool contains(instruction_ref ins) const + { + if(is_end(ins, instructions.end())) + return false; + return instruction_set.count(std::addressof(*ins)) > 0; + } + + template + instruction_ref emplace(instruction_ref pos, Ts&&... xs) + { + // cppcheck-suppress redundantInitialization + auto r = instructions.emplace(pos, std::forward(xs)...); + instruction_set.insert(std::addressof(*r)); + return r; + } + instruction_ref insert(instruction_ref pos, const instruction& ins) + { + return emplace(pos, ins); + } + + void clear() + { + instructions.clear(); + instruction_set.clear(); + nparams = 0; + } + + void push_front(const instruction& ins) { insert(instructions.begin(), ins); } + + void push_back(const instruction& ins) { insert(instructions.end(), ins); } + + template + void emplace_front(Ts&&... xs) + { + emplace(instructions.begin(), std::forward(xs)...); + } + + template + void emplace_back(Ts&&... xs) + { + emplace(instructions.end(), std::forward(xs)...); + } + + instruction_ref erase(instruction_ref pos) + { + instruction_set.erase(std::addressof(*pos)); + return instructions.erase(pos); + } + + instruction_ref erase(instruction_ref start, instruction_ref last) + { + std::for_each(start, last, [&](auto& ins) { instruction_set.erase(std::addressof(ins)); }); + return instructions.erase(start, last); + } +}; + +const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } + +module::module(const std::string& name) : impl(std::make_unique()) +{ + impl->name = name; +} + +module::module(module&&) noexcept = default; +module::~module() noexcept = default; + +// copy constructor +module::module(const module& m) { assign(m); } + +// copy assignment operator +module& module::operator=(module m) +{ + std::swap(m.impl, this->impl); + return *this; +} + +std::string module::name() const { return impl->name; } + +bool module::bypass() const { return impl->bypass; } +void module::set_bypass(bool b) { impl->bypass = b; } + +void module::assign(const module& m) +{ + // copy the impl + if(!impl) + impl = std::make_unique(); + *impl = *m.impl; + + // clear instructions + if(!impl->instructions.empty()) + { + impl->clear(); + } + + std::unordered_map ins_map; + for(auto ins : iterator_for(m)) + { + instruction_ref copy_ins{}; + if(ins->name() == "@literal") + { + auto l = ins->get_literal(); + copy_ins = impl->insert(impl->instructions.end(), instruction{l}); + } + else if(ins->name() == "@param") + { + auto&& name = any_cast(ins->get_operator()).parameter; + auto order = any_cast(ins->get_operator()).order; + auto s = ins->get_shape(); + copy_ins = impl->insert(impl->instructions.end(), + {builtin::param{name, order}, std::move(s), {}}); + } + else if(ins->name() == "@outline") + { + auto s = ins->get_shape(); + copy_ins = impl->insert(impl->instructions.end(), {builtin::outline{s}, s, {}}); + } + else + { + // if there are sub_module inputs, need to make a copy of the submodule + auto module_args = ins->module_inputs(); + // retrieve its mapped input + auto inputs = ins->inputs(); + std::vector copy_inputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) { + return contains(ins_map, i) ? ins_map[i] : i; + }); + if(ins->name() == "@return") + { + copy_ins = add_return(copy_inputs); + } + else + { + copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args); + } + } + + ins_map[ins] = copy_ins; + } +} + +instruction_ref module::add_instruction(const operation& op, std::vector args) +{ + return insert_instruction(impl->instructions.end(), op, std::move(args)); +} +instruction_ref module::insert_instruction(instruction_ref ins, + const operation& op, + std::vector args) +{ + assert(has_instruction(ins) or is_end(ins, this->end())); + assert(not starts_with(op.name(), "@")); + shape r = compute_shape(op, args); + auto result = impl->insert(ins, {op, r, std::move(args)}); + instruction::backreference(result); + assert(result->valid(begin())); + return result; +} + +instruction_ref module::add_instruction(const operation& op, + std::vector args, + std::vector module_args) +{ + return insert_instruction( + impl->instructions.end(), op, std::move(args), std::move(module_args)); +} + +instruction_ref module::insert_instruction(instruction_ref ins, + const operation& op, + std::vector args, + std::vector module_args) +{ + assert(has_instruction(ins) or is_end(ins, this->end())); + assert(not starts_with(op.name(), "@")); + auto out_shape = compute_shape(op, args, module_args); + auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); + instruction::backreference(result); + assert(result->valid(begin())); + return result; +} + +instruction_ref module::replace_instruction(instruction_ref ins, + const operation& op, + std::vector args) MIGRAPHX_TIDY_CONST +{ + assert(has_instruction(ins)); + assert(not starts_with(op.name(), "@")); + + shape r = compute_shape(op, args); + instruction::replace(ins, op, r, std::move(args)); + assert(ins->valid(begin())); + return ins; +} + +instruction_ref module::replace_instruction(instruction_ref ins, + const operation& op, + std::vector args, + std::vector module_args) MIGRAPHX_TIDY_CONST +{ + assert(has_instruction(ins)); + assert(not starts_with(op.name(), "@")); + auto out_shape = compute_shape(op, args, module_args); + instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); + assert(ins->valid(begin())); + return ins; +} + +instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep) +{ + assert(has_instruction(ins)); + assert(has_instruction(rep)); + assert(ins != rep); + + if(ins == std::prev(this->end())) + { + return replace_instruction(ins, make_op("identity"), rep); + } + + // TODO: Should it be an error if the output is empty? + if(ins->outputs().empty()) + { + return rep; + } + // Make a copy of outputs which can be changed when calling replace_argument + auto outputs = ins->outputs(); + for(auto out : outputs) + { + // TODO: Check for possible cycles + if(out != rep) + { + instruction::replace_argument(out, ins, rep); + } + assert(out->valid(begin())); + } + // Replacement should not be dead code unless its the last instruction + assert(!rep->outputs().empty() or rep == std::prev(end())); + // Output of the original instruction should only be the replacement or empty + assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(), + ins->outputs().end(), + [&](auto i) { return i == rep; })); + assert(ins->valid(begin())); + assert(rep->valid(begin())); + return rep; +} + +instruction_ref module::remove_instruction(instruction_ref ins) +{ + assert(has_instruction(ins)); + assert(ins->outputs().empty()); + ins->clear_arguments(); + return impl->erase(ins); +} + +instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last) +{ + if(first == last) + return first; + // TODO: Check every element + assert(has_instruction(first)); + std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); }); + assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); })); + return impl->erase(first, last); +} + +instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst) +{ + assert(has_instruction(src)); + assert(has_instruction(dst) or is_end(dst, this->end())); + impl->instructions.splice(dst, impl->instructions, src); + return src; +} + +instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst) +{ + this->move_instruction(src, dst); + for(auto ins : src->inputs()) + this->move_instruction(ins, src); + return src; +} + +std::vector module::insert_module_instructions( + instruction_ref ins, module_ref m, std::unordered_map map_ins) +{ + std::vector mod_outputs; + for(auto sins : iterator_for(*m)) + { + if(contains(map_ins, sins)) + continue; + instruction_ref copy_ins; + if(sins->name() == "@literal") + { + auto l = sins->get_literal(); + copy_ins = this->add_literal(l); + } + else if(sins->name() == "@param") + { + auto&& name = any_cast(sins->get_operator()).parameter; + auto s = sins->get_shape(); + copy_ins = this->add_parameter(name, s); + } + else if(sins->name() == "@outline") + { + auto s = sins->get_shape(); + copy_ins = this->add_outline(s); + } + else + { + auto mod_args = sins->module_inputs(); + auto inputs = sins->inputs(); + std::vector copy_inputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) { + return contains(map_ins, i) ? map_ins[i] : i; + }); + + if(sins->name() == "@return") + { + mod_outputs = copy_inputs; + break; + } + + copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args); + } + map_ins[sins] = copy_ins; + } + if(mod_outputs.empty()) + mod_outputs = {map_ins.at(std::prev(m->end()))}; + return mod_outputs; +} + +instruction_ref module::add_literal(literal l) +{ + impl->emplace_front(std::move(l)); + return impl->instructions.begin(); +} + +instruction_ref module::add_outline(const shape& s) +{ + impl->push_front({builtin::outline{s}, s, {}}); + return impl->instructions.begin(); +} + +instruction_ref module::add_parameter(std::string name, shape s) +{ + assert(get_parameter_shape(name) == shape{}); + impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}}); + impl->nparams++; + return impl->instructions.begin(); +} + +instruction_ref module::add_return(std::vector args) +{ + impl->push_back({builtin::returns{}, {}, std::move(args)}); + auto result = std::prev(impl->instructions.end()); + instruction::backreference(result); + assert(result->valid(begin())); + + return result; +} + +instruction_ref module::replace_return(std::vector args) +{ + auto last = std::prev(this->end()); + // If there is no return then add a return + if(last->name() != "@return") + return this->add_return(args); + + shape r = compute_shape(last->get_operator(), args); + instruction::replace(last, last->get_operator(), r, std::move(args)); + assert(last->valid(begin())); + + return last; +} + +shape module::get_parameter_shape(std::string name) const +{ + auto ins = std::find_if( + impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { + if(x.name() == "@param") + { + return any_cast(x.get_operator()).parameter == name; + } + else + { + return false; + } + }); + if(ins != this->end()) + return ins->get_shape(); + else + return {}; +} + +std::vector module::get_parameter_names() const +{ + std::vector result; + std::vector params; + for(auto&& ins : impl->instructions) + { + if(ins.name() == "@param") + { + auto&& param = any_cast(ins.get_operator()); + params.push_back(param); + } + } + std::stable_sort( + params.begin(), params.end(), by(std::less<>{}, [](auto&& p) { return p.order; })); + std::transform(params.begin(), params.end(), std::back_inserter(result), [&](auto&& p) { + return p.parameter; + }); + return result; +} + +instruction_ref module::get_parameter(std::string name) const +{ + auto ins = std::find_if( + impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { + if(x.name() == "@param") + { + return any_cast(x.get_operator()).parameter == name; + } + else + { + return false; + } + }); + if(ins != this->end()) + return ins; + else + return this->end(); +} + +std::unordered_map module::get_parameter_shapes() const +{ + std::unordered_map result; + for(auto&& ins : impl->instructions) + { + if(ins.name() == "@param") + { + auto&& name = any_cast(ins.get_operator()).parameter; + result[name] = ins.get_shape(); + } + } + return result; +} + +bool module::has_instruction(instruction_ref ins) const { return impl->contains(ins); } + +std::size_t module::size() const { return impl->instructions.size(); } +instruction_ref module::begin() const { return impl->instructions.begin(); } +instruction_ref module::end() const { return impl->instructions.end(); } + +std::vector module::get_output_shapes() const +{ + if(impl->instructions.empty()) + return {}; + auto last_ins = impl->instructions.back(); + if(last_ins.name() == "@return") + { + const auto& output_ins = last_ins.inputs(); + std::vector output_shapes; + std::transform(output_ins.begin(), + output_ins.end(), + std::back_inserter(output_shapes), + [](auto& ins) { return ins->get_shape(); }); + + return output_shapes; + } + // The else branch is to provide backward compatibility + else + { + return {last_ins.get_shape()}; + } +} + +instruction_ref module::validate() const +{ + return std::find_if( + impl->instructions.begin(), impl->instructions.end(), [&](const instruction& i) { + auto inputs = i.inputs(); + bool check_order = std::all_of( + inputs.begin(), inputs.end(), [&](auto in) { return has_instruction(in); }); + return !i.valid(impl->instructions.begin(), check_order); + }); +} + +bool is_borrowed(instruction_ref ins) +{ + auto alias = instruction::get_output_alias(ins, true); + if(alias == ins) + return false; + lifetime l = alias->get_operator().get_lifetime(); + if(l == lifetime::borrow) + return true; + return is_borrowed(alias); +} + +bool is_global(instruction_ref ins) +{ + const auto& op = instruction::get_output_alias(ins)->get_operator(); + return op.name() == "@param" or op.get_lifetime() == lifetime::global; +} + +bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); } + +instruction_ref module::find_dangling_reference() const +{ + auto last = std::prev(end()); + if(last->name() == "@return") + { + auto dangling = std::find_if( + last->inputs().begin(), last->inputs().end(), [](auto x) { return is_dangling(x); }); + if(dangling != last->inputs().end()) + return *dangling; + } + else if(is_dangling(last)) + { + return last; + } + return end(); +} + +void module::finalize(context& ctx) +{ + const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{}); + for(auto ins : iterator_for(*this)) + { + if(trace) + { + std::cout << "Finalize: "; + this->debug_print(ins); + } + ins->finalize(ctx); + for(const auto& smod : ins->module_inputs()) + { + smod->finalize(ctx); + } + } + + // Warn when an instruction is not normalized + auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); }); + if(ins != end()) + std::cerr << "WARNING: Instruction needs normalization, performance may be affected." + << std::endl; +} + +void module::debug_print() const { std::cout << *this << std::endl; } + +void module::debug_print(instruction_ref ins, + std::unordered_map& names) const +{ + if(is_end(ins, this->end())) + { + std::cout << "End instruction" << std::endl; + return; + } + if(not has_instruction(ins)) + { + std::cout << "Instruction not part of module" << std::endl; + return; + } + std::stringstream ss; + names = this->print( + [&](auto x, auto ins_names) { + if(x == ins) + { + instruction::print(std::cout, x, ins_names); + std::cout << std::endl; + } + }, + names); +} + +void module::debug_print(instruction_ref ins) const +{ + std::unordered_map names; + this->debug_print(ins, names); +} + +void module::debug_print(const std::vector& inss) const +{ + for(auto ins : inss) + this->debug_print(ins); + std::cout << std::endl; +} + +std::unordered_map module::print( + const std::function&)>& print_func, + std::unordered_map names) const +{ + int count = 0; + for(auto ins : iterator_for(*this)) + { + std::string var_name; + if(ins->name() == "@param") + { + var_name = any_cast(ins->get_operator()).parameter; + } + else + { + var_name = this->name(); + var_name.append((this->name().empty() ? "@" : ":@")); + var_name.append(std::to_string(count)); + } + // count every instruction so index matches loc in the printout program + count++; + names.emplace(ins, var_name); + + print_func(ins, names); + } + return names; +} + +void module::print(const std::function< + void(instruction_ref, const std::unordered_map&)>& + print_func) const +{ + this->print(print_func, {}); +} + +static std::string enclose_name(const std::string& name) +{ + return '"' + replace_string(name, "\"", "\\\"") + '"'; +} + +void module::print_graph(std::ostream& os, bool brief) const +{ + os << "digraph {" << std::endl; + os << "\trankdir=LR;" << std::endl; + this->print([&](auto ins, auto ins_names) { + std::string label; + if(brief) + label = ins->name(); + else + label = to_string(ins->get_operator()); + os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]"; + os << ";" << std::endl; + if(!ins->inputs().empty()) + { + for(auto&& arg : ins->inputs()) + { + os << "\t" << enclose_name(ins_names.at(arg)) << " -> " + << enclose_name(ins_names.at(ins)); + if(not brief) + os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]"; + os << ";" << std::endl; + } + } + }); + os << "}" << std::endl; +} + +static std::string cpp_var_name(const std::string& name) +{ + return "m" + replace_string(name, "@", "x"); +} + +static std::string cpp_op_var(const std::string& name, instruction_ref ins) +{ + return replace_string(name, "@", ins->name()); +} + +static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op) +{ + std::string x = to_string(op); + if(contains(x, "[")) + { + auto start = x.find('['); + auto end = x.find(']'); + std::string attribute_text = x.substr(start + 1, end - start - 1); + std::vector attributes; + for(auto&& attribute : split_string(attribute_text, ',')) + { + if(contains(attribute, '=')) + attributes.push_back(attribute); + else + attributes.back() += "," + attribute; + } + for(auto&& attribute : attributes) + { + auto p = split_string(attribute, '='); + auto key = p.front(); + auto value = p.back(); + if(contains({"bn_mode", "padding_mode"}, key)) + continue; + if(key == "mode") + value = enclose_name(trim(value)); + os << name << "." << key << " = " << value << ";" << std::endl; + } + } +} + +static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) +{ + os << "migraphx::shape{migraphx::shape::" << s.type_string(); + os << ", {" << to_string_range(s.lens()) << "}"; + if(not s.standard()) + os << ", {" << to_string_range(s.strides()) << "}"; + os << "}"; +} + +std::unordered_map +module::print_cpp(std::ostream& os, std::unordered_map names) const +{ + os << "migraphx::module p;" << std::endl; + unsigned long seed = 0; + names = this->print( + [&](auto ins, auto ins_names) { + auto op = cpp_op_var(ins_names.at(ins), ins); + if(ins->name().front() != '@') + { + os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl; + print_op_attributes(os, op, ins->get_operator()); + } + os << "auto " << cpp_var_name(ins_names.at(ins)) << " = "; + if(ins->name() == "@literal") + { + os << "p.add_literal("; + bool use_abs = false; + ins->get_literal().visit([&](auto v) { + use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; }); + }); + if(use_abs) + os << "migraphx::abs("; + os << "migraphx::generate_literal("; + print_cpp_shape(os, ins->get_shape()); + os << ", " << seed << ")"; + if(use_abs) + os << ")"; + os << ");" << std::endl; + seed++; + } + else if(ins->name() == "@param") + { + std::string name = any_cast(ins->get_operator()).parameter; + os << "p.add_parameter(" << enclose_name(name) << ","; + print_cpp_shape(os, ins->get_shape()); + os << ");" << std::endl; + } + else + { + os << "p.add_instruction(" << op; + for(auto input : ins->inputs()) + { + os << ", " << cpp_var_name(ins_names.at(input)); + } + os << ");" << std::endl; + } + }, + names); + + return names; +} + +void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); } + +void module::annotate(std::ostream& os, std::function a) const +{ + this->print([&](auto ins, auto ins_names) { + instruction::print(os, ins, ins_names); + a(ins); + os << std::endl; + }); +} + +std::vector module::get_sub_modules() const +{ + std::vector vec_modules; + for(auto ins : iterator_for(*this)) + { + const auto& mod_args = ins->module_inputs(); + vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end()); + for(const auto& smod : mod_args) + { + auto sub_mods = smod->get_sub_modules(); + vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end()); + } + } + + return vec_modules; +} + +module& module::sort() +{ + fix([&](auto self, auto ins) { + this->move_instruction(ins, this->begin()); + for(auto child : ins->inputs()) + { + if(!contains(this->impl->instructions, child)) + { + continue; + } + self(child); + } + })(std::prev(this->end())); + assert(this->validate() == this->end()); + return *this; +} + +void module::calc_implicit_deps(const module& smod, + const module& pmod, + instruction_ref ins, + ins_dep_map& deps) const +{ + const auto& ins_inputs = ins->inputs(); + for(auto ii : iterator_for(smod)) + { + const auto& ii_inputs = ii->inputs(); + for(auto iii : ii_inputs) + { + if(pmod.has_instruction(iii)) + { + if(not contains(ins_inputs, iii)) + deps[ins].insert(iii); + } + } + + const auto& mod_args = ii->module_inputs(); + if(not mod_args.empty()) + { + for(const auto* ssmod : mod_args) + { + calc_implicit_deps(*ssmod, pmod, ins, deps); + } + } + } +} + +ins_dep_map module::calc_implicit_deps() const +{ + ins_dep_map mod_implicit_deps; + for(auto ins : iterator_for(*this)) + { + const auto& mod_args = ins->module_inputs(); + if(mod_args.empty()) + { + continue; + } + + for(const auto* mod : mod_args) + { + calc_implicit_deps(*mod, *this, ins, mod_implicit_deps); + } + } + + return mod_implicit_deps; +} + +bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); } + +std::ostream& operator<<(std::ostream& os, const module& m) +{ + m.print([&](auto ins, auto ins_names) { + instruction::print(os, ins, ins_names); + os << std::endl; + }); + + return os; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/msgpack.cpp b/src/msgpack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79fb819bf9ca9a5aa2487c84951d4c744334df54 --- /dev/null +++ b/src/msgpack.cpp @@ -0,0 +1,169 @@ +#include +#include +#include + +namespace msgpack { +MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) +{ + namespace adaptor { + + template <> + struct convert + { + const msgpack::object& operator()(const msgpack::object& o, migraphx::value& v) const + { + switch(o.type) + { + case msgpack::type::NIL: { + v = nullptr; + break; + } + case msgpack::type::BOOLEAN: { + v = o.as(); + break; + } + case msgpack::type::POSITIVE_INTEGER: { + v = o.as(); + break; + } + case msgpack::type::NEGATIVE_INTEGER: { + v = o.as(); + break; + } + case msgpack::type::FLOAT32: + case msgpack::type::FLOAT64: { + v = o.as(); + break; + } + case msgpack::type::STR: { + v = o.as(); + break; + } + case msgpack::type::BIN: { + v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; + break; + } + case msgpack::type::ARRAY: { + migraphx::value r = migraphx::value::array{}; + std::for_each( + o.via.array.ptr, + o.via.array.ptr + o.via.array.size, + [&](const msgpack::object& so) { r.push_back(so.as()); }); + v = r; + break; + } + case msgpack::type::MAP: { + migraphx::value r = migraphx::value::object{}; + std::for_each(o.via.map.ptr, + o.via.map.ptr + o.via.map.size, + [&](const msgpack::object_kv& p) { + r[p.key.as()] = p.val.as(); + }); + v = r; + break; + } + case msgpack::type::EXT: { + MIGRAPHX_THROW("msgpack EXT type not supported."); + } + } + return o; + } + }; + + template <> + struct pack + { + template + packer& operator()(msgpack::packer& o, + const migraphx::value::binary& x) const + { + const auto* data = reinterpret_cast(x.data()); + auto size = x.size(); + o.pack_bin(size); + o.pack_bin_body(data, size); + return o; + } + }; + + template <> + struct pack + { + template + void write(msgpack::packer& o, const std::nullptr_t&) const + { + o.pack_nil(); + } + template + void write(msgpack::packer& o, const T& x) const + { + o.pack(x); + } + template + void write(msgpack::packer& o, const std::vector& v) const + { + if(v.empty()) + { + o.pack_array(0); + return; + } + if(not v.front().get_key().empty()) + { + o.pack_map(v.size()); + for(auto&& x : v) + { + o.pack(x.get_key()); + o.pack(x.without_key()); + } + } + else + { + o.pack_array(v.size()); + for(auto&& x : v) + { + o.pack(x); + } + } + } + template + packer& operator()(msgpack::packer& o, const migraphx::value& v) const + { + v.visit_value([&](auto&& x) { this->write(o, x); }); + return o; + } + }; + + } // namespace adaptor +} // MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) +} // namespace msgpack + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct vector_stream +{ + std::vector buffer{}; + vector_stream& write(const char* b, std::size_t n) + { + buffer.insert(buffer.end(), b, b + n); + return *this; + } +}; + +std::vector to_msgpack(const value& v) +{ + vector_stream vs; + msgpack::pack(vs, v); + return vs.buffer; +} +value from_msgpack(const char* buffer, std::size_t size) +{ + msgpack::object_handle oh = msgpack::unpack(buffer, size); + return oh.get().as(); +} +value from_msgpack(const std::vector& buffer) +{ + return from_msgpack(buffer.data(), buffer.size()); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/normalize_attributes.cpp b/src/normalize_attributes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98658424bf4b6ae40e9ac6c43b7246bac5dc92fb --- /dev/null +++ b/src/normalize_attributes.cpp @@ -0,0 +1,201 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +// different attributes +// 1) use_input(default)/use_output +// 2) use_rank(default)/use_len +// 3) clip_min(default)/not_clip_min +// 3.1) include_min(default)/exclude_min +// 4) clip_max(default)/not_clip_max +// 4.1) exclude_max(default)/include_max +auto tune_attribute(const std::vector& vec, + const std::vector& axes, + const value& val, + const std::vector& lens) +{ + std::vector result(vec); + int64_t n_rank = lens.size(); + std::vector vec_attrs = val.to_vector(); + if(contains(vec_attrs, op::normalize_attribute::use_output)) + { + n_rank = n_rank + vec.size(); + } + + std::vector max_vals(vec.size(), n_rank); + if(contains(vec_attrs, op::normalize_attribute::use_len)) + { + std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; }); + } + + if(contains(vec_attrs, op::normalize_attribute::clip_max)) + { + if(contains(vec_attrs, op::normalize_attribute::include_max)) + { + std::transform(result.begin(), + result.end(), + max_vals.begin(), + result.begin(), + [](auto v, auto mv) { return v > mv ? mv : v; }); + } + else + { + std::transform(result.begin(), + result.end(), + max_vals.begin(), + result.begin(), + [](auto v, auto mv) { return v >= mv ? mv - 1 : v; }); + } + } + else + { + if(contains(vec_attrs, op::normalize_attribute::include_max)) + { + if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{})) + { + MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); + } + } + else + { + if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{})) + { + MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); + } + } + } + + std::vector min_vals = max_vals; + std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; }); + if(contains(vec_attrs, op::normalize_attribute::clip_min)) + { + if(contains(vec_attrs, op::normalize_attribute::include_min)) + { + std::transform(result.begin(), + result.end(), + min_vals.begin(), + result.begin(), + [](auto v, auto mv) { return v < mv ? mv : v; }); + } + else + { + std::transform(result.begin(), + result.end(), + min_vals.begin(), + result.begin(), + [](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; }); + } + } + else + { + if(contains(vec_attrs, op::normalize_attribute::include_min)) + { + if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{})) + { + MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); + } + } + else + { + if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{})) + { + MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); + } + } + } + + std::transform( + result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) { + return v < 0 ? v + mv : v; + }); + + return result; +} + +auto tune_pad_attribute(const value& val) +{ + + std::vector vec_attrs = val.to_vector(); + std::vector result(vec_attrs.begin(), vec_attrs.end()); + std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result)); + + return result; +} + +bool normalize_attributes(operation& op, const std::vector& lens) +{ + bool tuned = false; + auto attrs = op.attributes(); + auto val = op.to_value(); + if(attrs.contains("normalize_padding")) + { + auto padding = val.at(attrs.at("normalize_padding").to()); + auto padding_size = padding.size(); + // for now, assume the dimensions to pad start at dim 2 + auto padding_start = 2; + + if(padding_size == 2 * (lens.size() - padding_start)) + tuned = true; + else if(padding_size != (lens.size() - padding_start)) + MIGRAPHX_THROW("inconsistent padding size"); + else + { + auto result = tune_pad_attribute(padding); + val["padding"] = result; + op.from_value(val); + tuned = true; + } + } + if(!attrs.contains("normalize_axes")) + { + return tuned; + } + + auto attr_v = attrs.at("normalize_axes").without_key(); + for(const auto& rv : attr_v) + { + const auto& key = rv.get_key(); + if(val.contains(key)) + { + auto vv = val.at(key).without_key(); + if(vv.is_array()) + { + std::vector axes; + if(val.contains("axes")) + { + axes = val.at("axes").without_key().to_vector(); + } + auto vec = vv.to_vector(); + auto result = tune_attribute(vec, axes, rv.without_key(), lens); + val[key] = result; + op.from_value(val); + val = op.to_value(); + tuned = true; + } + else + { + auto num = vv.to(); + auto result = tune_attribute({num}, {num}, rv.without_key(), lens); + val[key] = result.front(); + op.from_value(val); + val = op.to_value(); + tuned = true; + } + } + else + { + MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + key + + "\" not exist!"); + } + } + + return tuned; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/normalize_ops.cpp b/src/normalize_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..047992db9b9e6dc240bd3e42bbfb567c31a23840 --- /dev/null +++ b/src/normalize_ops.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void normalize_ops::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + auto inputs = ins->inputs(); + if(inputs.empty()) + continue; + + auto lens = inputs[0]->get_shape().lens(); + migraphx::operation tuned_op = ins->get_operator(); + if(normalize_attributes(tuned_op, lens)) + { + m.replace_instruction(ins, tuned_op, inputs); + ins->set_normalized(); + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/CMakeLists.txt b/src/onnx/CMakeLists.txt old mode 100644 new mode 100755 index 03890e657cf2debbe8f338db0dc008ca4cd8cfc2..e73f0eeac9b7c61cc5b7cb18b6b0f7239644eee3 --- a/src/onnx/CMakeLists.txt +++ b/src/onnx/CMakeLists.txt @@ -7,23 +7,15 @@ target_compile_options(onnx-proto PRIVATE -w) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) -add_library(migraphx_onnx onnx.cpp) +file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp) +add_library(migraphx_onnx ${ONNX_SRCS}) +target_include_directories(migraphx_onnx PRIVATE include) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_clang_tidy_check(migraphx_onnx) -target_link_libraries(migraphx_onnx PRIVATE onnx-proto) +target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_onnx PUBLIC migraphx) rocm_install_targets( TARGETS migraphx_onnx ) - -if(MIGRAPHX_ENABLE_GPU) -add_executable(mnist mnist.cpp) -rocm_clang_tidy_check(mnist) -target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx) - -add_executable(cifar10 cifar10.cpp) -rocm_clang_tidy_check(cifar10) -target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx) -endif() \ No newline at end of file diff --git a/src/onnx/checks.cpp b/src/onnx/checks.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b56d4eb4b9cb0508bc3f862d1f78fc9c24f59e2c --- /dev/null +++ b/src/onnx/checks.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void check_arg_empty(const argument& arg, const std::string& msg) +{ + if(arg.empty()) + { + MIGRAPHX_THROW(msg); + } +} + +void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg) +{ + if(kdims != attr_size) + { + MIGRAPHX_THROW(error_msg + " k-dims: " + std::to_string(kdims) + + " attribute size: " + std::to_string(attr_size)); + } +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/cifar10.cpp b/src/onnx/cifar10.cpp deleted file mode 100644 index eee88f4f45a4d510732156b9041fa0f7d3627c0c..0000000000000000000000000000000000000000 --- a/src/onnx/cifar10.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include "softmax.hpp" - -auto read_cifar10_images(const std::string& full_path) -{ - std::ifstream file(full_path, std::ios::binary); - - const size_t nimages = 10; - const size_t nbytes_per_image = 3072; - std::vector raw_data(nimages * (nbytes_per_image + 1)); - std::vector labels(nimages); - std::vector data(nimages * nbytes_per_image); - if(file.is_open()) - { - file.read(reinterpret_cast(raw_data.data()), - (nbytes_per_image + 1) * nimages * sizeof(uint8_t)); - uint8_t* pimage = raw_data.data(); - for(size_t i = 0; i < nimages; i++, pimage += nbytes_per_image) - { - labels[i] = *pimage++; - for(size_t j = 0; j < nbytes_per_image; j++) - { - float v = float(*(pimage + j)) / 255.0f; - data[i * nbytes_per_image + j] = v; - } - } - return std::make_pair(labels, data); - } - else - { - throw std::runtime_error("Cannot open file `" + full_path + "`!"); - } -} - -int main(int argc, char const* argv[]) -{ - if(argc < 4) - { - throw std::runtime_error("Usage: cifar10 [gpu | cpu] "); - } - std::string gpu_cpu = argv[1]; - std::string file = argv[2]; - std::string datafile = argv[3]; - auto prog = migraphx::parse_onnx(file); - std::cout << prog << std::endl; - auto imageset = read_cifar10_images(datafile); - - if(gpu_cpu == "gpu") - { - // GPU target - prog.compile(migraphx::gpu::target{}); - migraphx::program::parameter_map m; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}}; - for(auto&& x : prog.get_parameter_shapes()) - { - m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second)); - } - auto labels = imageset.first; - auto input = imageset.second; - auto ptr = input.data(); - for(int i = 0; i < 10; i++) - { - std::cout << "label: " << static_cast(labels[i]) << " ----> "; - m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]}); - auto result = migraphx::gpu::from_gpu(prog.eval(m)); - std::vector logits; - result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); - std::vector probs = softmax(logits); - for(auto x : probs) - std::cout << x << " "; - std::cout << std::endl << std::endl; - } - } - else - { - // CPU target - prog.compile(migraphx::cpu::target{}); - auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}}; - auto labels = imageset.first; - auto input = imageset.second; - auto ptr = input.data(); - for(int i = 0; i < 10; i++) - { - std::cout << "label: " << static_cast(labels[i]) << " ----> "; - auto input3 = migraphx::argument{s, &ptr[3072 * i]}; - auto result = prog.eval({{"0", input3}}); - std::vector logits; - result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); - std::vector probs = softmax(logits); - for(auto x : probs) - std::cout << x << " "; - std::cout << std::endl; - } - } -} diff --git a/src/onnx/conv.cpp b/src/onnx/conv.cpp new file mode 100755 index 0000000000000000000000000000000000000000..d4afaaca5d191f6f49d4e1bb7769b1f261a727e1 --- /dev/null +++ b/src/onnx/conv.cpp @@ -0,0 +1,29 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void recalc_conv_attributes(value& v, size_t kdims) +{ + if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2)) + { + v["padding"].resize(kdims); + std::fill_n(v["padding"].begin(), kdims, 0); + } + if(v["stride"].size() != kdims) + { + v["stride"].resize(kdims); + std::fill_n(v["stride"].begin(), kdims, 1); + } + if(v["dilation"].size() != kdims) + { + v["dilation"].resize(kdims); + std::fill_n(v["dilation"].begin(), kdims, 1); + } +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/include/migraphx/onnx/checks.hpp b/src/onnx/include/migraphx/onnx/checks.hpp new file mode 100755 index 0000000000000000000000000000000000000000..598a449c83030ad6a6692fe4a508e2a98147cbc7 --- /dev/null +++ b/src/onnx/include/migraphx/onnx/checks.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CHECKS_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CHECKS_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void check_arg_empty(const argument& arg, const std::string& msg); +void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg); + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/include/migraphx/onnx/conv.hpp b/src/onnx/include/migraphx/onnx/conv.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ea16fd367e0849ac88dd6a8b9b2c8f2fa6550f5c --- /dev/null +++ b/src/onnx/include/migraphx/onnx/conv.hpp @@ -0,0 +1,17 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CONV_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CONV_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void recalc_conv_attributes(value& v, size_t kdims); + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/include/migraphx/onnx/map_activation_functions.hpp b/src/onnx/include/migraphx/onnx/map_activation_functions.hpp new file mode 100755 index 0000000000000000000000000000000000000000..5e913b0c27c227226e4508985d5192705131dc0c --- /dev/null +++ b/src/onnx/include/migraphx/onnx/map_activation_functions.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_MAP_ACTIVATION_FUNCTIONS_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_MAP_ACTIVATION_FUNCTIONS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +const std::unordered_map& map_activation_functions(); + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/include/migraphx/onnx/onnx_parser.hpp b/src/onnx/include/migraphx/onnx/onnx_parser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..648ccc7b48f3a56118d76eb801b4b07d451119cd --- /dev/null +++ b/src/onnx/include/migraphx/onnx/onnx_parser.hpp @@ -0,0 +1,103 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PARSER_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PARSER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +namespace onnx = onnx_for_migraphx; + +struct onnx_parser +{ + std::string filename; + std::string path = "."; + using attribute_map = std::unordered_map; + struct node_info + { + attribute_map attributes{}; + std::size_t num_outputs = 1; + std::string name = ""; + module* mod = nullptr; + instruction_ref make_contiguous(instruction_ref ins) const; + instruction_ref add_bias(const std::vector& args, + instruction_ref curr_ins, + uint64_t axis) const; + + instruction_ref add_broadcastable_binary_op(const std::string& op_name, + instruction_ref arg0, + instruction_ref arg1) const; + + instruction_ref add_common_op(const std::string& op_name, + std::vector inputs) const; + + template + instruction_ref add_common_op(const std::string& op_name, Ts... xs) const + { + return add_common_op(op_name, {xs...}); + } + + instruction_ref add_instruction(const operation& op, + const std::vector& args) const; + + instruction_ref add_instruction(const operation& op, + const std::vector& args, + const std::vector& mods) const; + + template + instruction_ref add_instruction(const operation& op, Ts... xs) const + { + return add_instruction(op, {xs...}); + } + instruction_ref add_literal(literal l) const; + template + instruction_ref add_literal(Ts&&... xs) const + { + return add_literal(literal{std::forward(xs)...}); + } + }; + using node_map = std::unordered_map; + using op_func = std::function( + onnx_parser&, const node_info&, std::vector)>; + node_map nodes; + std::unordered_map instructions; + program prog = program(); + std::size_t default_dim_value = 1; + std::unordered_map> map_input_dims; + bool skip_unknown_operators = false; + int64_t max_loop_iterations = 10; + int64_t opset_version = 13; + + std::unordered_map ops; + + onnx_parser(); + operation load(const std::string& name, const node_info& info) const; + + void parse_undefined(module* mod, const std::string& name); + + static int64_t get_opset_version(const onnx::ModelProto& model); + + void parse_from(std::istream& is, std::string name = ""); + void parse_from(const void* data, std::size_t size); + void parse_graph(module* mod, const onnx::GraphProto& graph); + literal parse_value(const onnx::AttributeProto& attr) const; + literal parse_tensor(const onnx::TensorProto& t) const; + shape parse_type(const onnx::TypeProto& t, const std::vector& input_dims) const; +}; + +shape::type_t get_type(int dtype); + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/include/migraphx/onnx/op_parser.hpp b/src/onnx/include/migraphx/onnx/op_parser.hpp new file mode 100755 index 0000000000000000000000000000000000000000..c0c7cd4695a962ba9ad144a97904812ad2717a04 --- /dev/null +++ b/src/onnx/include/migraphx/onnx/op_parser.hpp @@ -0,0 +1,57 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_REGISTER_OP_PARSER_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_REGISTER_OP_PARSER_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct op_desc +{ + std::string onnx_name = ""; + std::string op_name = ""; +}; + +void register_op_parser(const std::string& name, onnx_parser::op_func f); +onnx_parser::op_func get_op_parser(const std::string& name); +std::vector get_op_parsers(); + +inline std::vector implicit_multi_op(std::vector inss) +{ + return inss; +} + +inline std::vector implicit_multi_op(instruction_ref ins) { return {ins}; } + +template +void register_op_parser() +{ + T parser; + for(auto&& opd : parser.operators()) + register_op_parser(opd.onnx_name, [opd, parser](auto&&... xs) { + return implicit_multi_op(parser.parse(opd, xs...)); + }); +} + +struct register_op_parser_action +{ + template + static void apply() + { + register_op_parser(); + } +}; + +template +using op_parser = auto_register; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/include/migraphx/onnx/padding.hpp b/src/onnx/include/migraphx/onnx/padding.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f22ea058512a5fa9f695965e26bc47dcc057bfc0 --- /dev/null +++ b/src/onnx/include/migraphx/onnx/padding.hpp @@ -0,0 +1,38 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PADDING_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PADDING_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +bool is_asym_padding(const std::vector& padding); + +void cal_auto_padding_size(onnx_parser::node_info info, + value& v, + const std::vector& k_lens, + const std::vector& dilation, + const std::vector& in_lens, + std::vector& paddings); + +void check_padding_mode(const onnx_parser::node_info& info, const std::string& op_name); + +void tune_padding_size(const value& v, + std::vector& padding, + int count_include_pad, + std::vector& s_start); + +void check_asym_padding(const onnx_parser::node_info& info, + instruction_ref& ins, + const std::vector& padding, + value& v, + int count_include_pad = 0, + float pad_val = 0); + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/onnx/map_activation_functions.cpp b/src/onnx/map_activation_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a296c55a52918afb34a3426c45d24f492c701cc --- /dev/null +++ b/src/onnx/map_activation_functions.cpp @@ -0,0 +1,21 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +const std::unordered_map& map_activation_functions() +{ + static const std::unordered_map m = { + {"tanh", make_op("tanh")}, + {"relu", make_op("relu")}, + {"sigmoid", make_op("sigmoid")}, + {"leakyrelu", make_op("leaky_relu")}, + {"elu", make_op("elu")}}; + return m; +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/mnist.cpp b/src/onnx/mnist.cpp deleted file mode 100644 index 1f4b8b086bebbc6d2e4c6d497da609efcb3797df..0000000000000000000000000000000000000000 --- a/src/onnx/mnist.cpp +++ /dev/null @@ -1,144 +0,0 @@ -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include "softmax.hpp" - -auto reverse_int(unsigned int i) -{ - unsigned char c1; - unsigned char c2; - unsigned char c3; - unsigned char c4; - c1 = i & 255u; - c2 = (i >> 8u) & 255u; - c3 = (i >> 16u) & 255u; - c4 = (i >> 24u) & 255u; - return (static_cast(c1) << 24u) + (static_cast(c2) << 16u) + - (static_cast(c3) << 8u) + c4; -}; - -std::vector -read_mnist_images(const std::string& full_path, int& number_of_images, int& image_size) -{ - using uchar = unsigned char; - - std::ifstream file(full_path, std::ios::binary); - - if(file.is_open()) - { - int magic_number = 0; - int n_rows = 0; - int n_cols = 0; - - file.read(reinterpret_cast(&magic_number), sizeof(magic_number)); - magic_number = reverse_int(magic_number); - - if(magic_number != 2051) - throw std::runtime_error("Invalid MNIST image file!"); - - file.read(reinterpret_cast(&number_of_images), sizeof(number_of_images)); - number_of_images = reverse_int(number_of_images); - file.read(reinterpret_cast(&n_rows), sizeof(n_rows)); - n_rows = reverse_int(n_rows); - file.read(reinterpret_cast(&n_cols), sizeof(n_cols)); - n_cols = reverse_int(n_cols); - - image_size = n_rows * n_cols; - - std::vector result(number_of_images * image_size); - for(int i = 0; i < number_of_images; i++) - { - for(int j = 0; j < image_size; j++) - { - uchar tmp; - file.read(reinterpret_cast(&tmp), 1); - result[i * image_size + j] = tmp / 255.0; - } - } - return result; - } - else - { - throw std::runtime_error("Cannot open file `" + full_path + "`!"); - } -} - -std::vector read_mnist_labels(const std::string& full_path, int& number_of_labels) -{ - using uchar = unsigned char; - - std::ifstream file(full_path, std::ios::binary); - - if(file.is_open()) - { - int magic_number = 0; - file.read(reinterpret_cast(&magic_number), sizeof(magic_number)); - magic_number = reverse_int(magic_number); - - if(magic_number != 2049) - throw std::runtime_error("Invalid MNIST label file!"); - - file.read(reinterpret_cast(&number_of_labels), sizeof(number_of_labels)); - number_of_labels = reverse_int(number_of_labels); - - std::vector result(number_of_labels); - for(int i = 0; i < number_of_labels; i++) - { - uchar tmp; - file.read(reinterpret_cast(&tmp), 1); - result[i] = tmp; - } - return result; - } - else - { - throw std::runtime_error("Unable to open file `" + full_path + "`!"); - } -} - -int main(int argc, char const* argv[]) -{ - if(argc > 3) - { - std::string datafile = argv[2]; - std::string labelfile = argv[3]; - int nimages = -1; - int image_size = -1; - int nlabels = -1; - std::vector input = read_mnist_images(datafile, nimages, image_size); - std::vector labels = read_mnist_labels(labelfile, nlabels); - - std::string file = argv[1]; - auto prog = migraphx::parse_onnx(file); - std::cout << prog << std::endl << std::endl; - prog.compile(migraphx::gpu::target{}); - auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}}; - std::cout << s << std::endl; - auto ptr = input.data(); - migraphx::program::parameter_map m; - m["output"] = - migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output"))); - for(int i = 0; i < 20; i++) - { - std::cout << "label: " << labels[i] << " ----> "; - m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]}); - auto result = migraphx::gpu::from_gpu(prog.eval(m)); - std::vector logits; - result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); - std::vector probs = softmax(logits); - for(auto x : probs) - std::cout << x << " "; - std::cout << std::endl; - } - std::cout << std::endl; - } -} diff --git a/src/onnx/onnx.cpp b/src/onnx/onnx.cpp index 837e6366d43e281cf38da35aa45b384b8a862678..b029d3373c9da6982357d4210bec47195cabef29 100644 --- a/src/onnx/onnx.cpp +++ b/src/onnx/onnx.cpp @@ -1,6 +1,5 @@ -#include -#include -#include +#include +#include #include #include #include @@ -9,1745 +8,58 @@ #include #include -#include #include -#include -#include -#include -#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct onnx_parser +template +program parse_onnx_from(const onnx_options& options, Ts&&... xs) { - using attribute_map = std::unordered_map; - using node_map = std::unordered_map; - using op_func = - std::function(attribute_map, std::vector)>; - node_map nodes; - std::unordered_map instructions; - program prog = program(); - bool is_pytorch = false; + onnx::onnx_parser parser; + parser.map_input_dims = options.map_input_dims; + parser.default_dim_value = options.default_dim_value; + parser.skip_unknown_operators = options.skip_unknown_operators; + parser.max_loop_iterations = options.max_loop_iterations; - std::unordered_map ops; - std::unordered_map map_actv_funcs; - - onnx_parser() - { - add_generic_op("Relu", op::relu{}); - add_generic_op("Sigmoid", op::sigmoid{}); - add_generic_op("Abs", op::abs{}); - add_generic_op("Exp", op::exp{}); - add_generic_op("Erf", op::erf{}); - add_generic_op("Log", op::log{}); - // disable dropout for inference - add_generic_op("Dropout", op::identity{}); - add_generic_op("Identity", op::identity{}); - add_generic_op("Sin", op::sin{}); - add_generic_op("Cos", op::cos{}); - add_generic_op("Tan", op::tan{}); - add_generic_op("Sinh", op::sinh{}); - add_generic_op("Cosh", op::cosh{}); - add_generic_op("Tanh", op::tanh{}); - add_generic_op("Asin", op::asin{}); - add_generic_op("Acos", op::acos{}); - add_generic_op("Atan", op::atan{}); - add_generic_op("Sqrt", op::sqrt{}); - add_generic_op("Round", op::round{}); - add_generic_op("Sign", op::sign{}); - add_generic_op("Ceil", op::ceil{}); - add_generic_op("Floor", op::floor{}); - - add_binary_op("Add", op::add{}); - add_binary_op("Div", op::div{}); - add_binary_op("Mul", op::mul{}); - add_binary_op("Sub", op::sub{}); - add_binary_op("Pow", op::pow{}); - - add_variadic_op("Sum", op::add{}); - add_variadic_op("Max", op::max{}); - add_variadic_op("Min", op::min{}); - - add_mem_op("ArgMax", &onnx_parser::parse_arg_op); - add_mem_op("ArgMin", &onnx_parser::parse_arg_op); - add_mem_op("Cast", &onnx_parser::parse_cast); - add_mem_op("Clip", &onnx_parser::parse_clip); - add_mem_op("LRN", &onnx_parser::parse_lrn); - add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); - add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); - add_mem_op("Elu", &onnx_parser::parse_elu); - add_mem_op("Expand", &onnx_parser::parse_expand); - add_mem_op("Constant", &onnx_parser::parse_constant); - add_mem_op("Conv", &onnx_parser::parse_conv); - add_mem_op("MaxPool", &onnx_parser::parse_pooling); - add_mem_op("AveragePool", &onnx_parser::parse_pooling); - add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling); - add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling); - add_mem_op("Reshape", &onnx_parser::parse_reshape); - add_mem_op("Flatten", &onnx_parser::parse_flatten); - add_mem_op("Gemm", &onnx_parser::parse_gemm); - add_mem_op("MatMul", &onnx_parser::parse_matmul); - add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); - add_mem_op("Softmax", &onnx_parser::parse_softmax); - add_mem_op("LogSoftmax", &onnx_parser::parse_softmax); - add_mem_op("Squeeze", &onnx_parser::parse_squeeze); - add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); - add_mem_op("Slice", &onnx_parser::parse_slice); - add_mem_op("Concat", &onnx_parser::parse_concat); - add_mem_op("Gather", &onnx_parser::parse_gather); - add_mem_op("Shape", &onnx_parser::parse_shape); - add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill); - add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape); - add_mem_op("Transpose", &onnx_parser::parse_transpose); - add_mem_op("RNN", &onnx_parser::parse_rnn); - add_mem_op("GRU", &onnx_parser::parse_gru); - add_mem_op("LSTM", &onnx_parser::parse_lstm); - add_mem_op("Pad", &onnx_parser::parse_pad); - add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper); - add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper); - add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper); - add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper); - - // init the activation function map - init_actv_func(); - } - - void init_actv_func() - { - // Support name format of all lower case or the first letter capital - map_actv_funcs.insert(std::make_pair("tanh", op::tanh{})); - map_actv_funcs.insert(std::make_pair("relu", op::relu{})); - map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{})); - map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{})); - map_actv_funcs.insert(std::make_pair("elu", op::elu{})); - } - - template - void add_op(std::string name, F f) - { - ops.emplace(name, [=](auto&&... xs) { - return std::vector{f(std::forward(xs)...)}; - }); - } - - // Multi output op - template - void add_multi_op(std::string name, F f) - { - ops.emplace(name, f); - } - - template - void add_mem_op(std::string name, F f) - { - add_op(name, [=](auto&&... xs) { - return std::mem_fn(f)(*this, name, std::forward(xs)...); - }); - } - - template - void add_binary_op(std::string name, T x) - { - add_op(name, [this, x](attribute_map attributes, std::vector args) { - if(args.size() != 2) - MIGRAPHX_THROW("binary operators should have 2 operands"); - if(contains(attributes, "broadcast") and contains(attributes, "axis")) - { - uint64_t broadcasted = parse_value(attributes.at("broadcast")).at(); - if(broadcasted != 0) - { - uint64_t axis = parse_value(attributes.at("axis")).at(); - auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, - args[1]); - return prog.add_instruction(x, args[0], l); - } - return prog.add_instruction(x, args); - } - else - { - return add_broadcastable_binary_op(args[0], args[1], x); - } - }); - } - - std::vector compute_broadcasted_lens(std::vector s0, - std::vector s1) - { - // Example: - // s0 = (3,2,4,5) and s1 = (2,1,1) - // - // In this case we need to broadcast (:,1,1) portion of - // s1 plus broadcast the 1st dimension of s1 - // giving output_lens = (3,2,4,5) - // - // Another example: - // s0 = (3,2,1,5) and s1 = (2,7,5) - // In this case we need to broadcast the (:,:,1:,:) axis - // of s0 plus the 1st dimension of s1 giving - // output_lens = (3,2,7,5) - if(s0.size() > s1.size()) - { - s0.swap(s1); - } - - std::vector out_lens(s1); - auto offset = s1.size() - s0.size(); - std::transform(s0.begin(), - s0.end(), - s1.begin() + offset, - out_lens.begin() + offset, - [&](auto a, auto b) { - if(a != b and a != 1 and b != 1) - { - MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + - to_string_range(s0) + "} and {" + - to_string_range(s1) + "} mismatch!"); - } - return std::max(a, b); - }); - - return out_lens; - } - - instruction_ref make_contiguous(instruction_ref ins) - { - if(ins->get_shape().standard()) - { - return ins; - } - - return prog.add_instruction(op::contiguous{}, ins); - } - - template - instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) - { - if(arg0->get_shape().lens() != arg1->get_shape().lens()) - { - // Get lengths for both arguments - auto s0 = arg0->get_shape().lens(); - auto s1 = arg1->get_shape().lens(); - auto out_lens = compute_broadcasted_lens(s0, s1); - auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0); - auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1); - return prog.add_instruction(x, l0, l1); - } - else - { - return prog.add_instruction(x, {arg0, arg1}); - } - } - - template - void add_generic_op(std::string name, T x) - { - add_op(name, [this, x](const attribute_map&, std::vector args) { - return prog.add_instruction(x, args); - }); - } - - template - void add_variadic_op(std::string name, T x) - { - add_op(name, [this, x](const attribute_map&, std::vector args) { - return std::accumulate(std::next(args.begin()), - args.end(), - args.front(), - [this, x](instruction_ref a, instruction_ref b) { - return add_broadcastable_binary_op(a, b, x); - }); - }); - } - - instruction_ref parse_clip(const std::string&, - const attribute_map& attributes, - std::vector args) - { - op::clip op; - if(contains(attributes, "max")) - { - op.max_val = parse_value(attributes.at("max")).at(); - } - if(contains(attributes, "min")) - { - op.min_val = parse_value(attributes.at("min")).at(); - } - return prog.add_instruction(op, std::move(args)); - } - - template - instruction_ref parse_softmax(const std::string&, - const attribute_map& attributes, - std::vector args) - { - int axis = 1; - if(contains(attributes, "axis")) - { - axis = parse_value(attributes.at("axis")).at(); - } - - return prog.add_instruction(Op{axis}, std::move(args)); - } - - template - instruction_ref parse_arg_op(const std::string&, - const attribute_map& attributes, - std::vector args) - { - int64_t axis = 0; - if(contains(attributes, "axis")) - { - axis = static_cast(parse_value(attributes.at("axis")).at()); - } - - int keep_dims = 1; - if(contains(attributes, "keepdims")) - { - keep_dims = parse_value(attributes.at("keepdims")).at(); - } - - if(keep_dims == 0) - { - auto ins = prog.add_instruction(Op{axis}, std::move(args)); - return prog.add_instruction(op::squeeze{{axis}}, ins); - } - else - { - return prog.add_instruction(Op{axis}, std::move(args)); - } - } - - instruction_ref - parse_conv(const std::string&, attribute_map attributes, std::vector args) - { - op::convolution op; - auto l0 = args[0]; - if(contains(attributes, "pads")) - { - if(contains(attributes, "auto_pad")) - { - auto s = attributes["auto_pad"].s(); - if(contains(attributes, "pads") and to_upper(s) != "NOTSET") - { - MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); - } - } - std::vector padding; - copy(attributes["pads"].ints(), std::back_inserter(padding)); - if(padding.size() != 4) - { - MIGRAPHX_THROW("padding should have 4 values"); - } - if(padding[0] != padding[2] || padding[1] != padding[3]) - { - // insert zeros for pad op (args[0] has 4 dims) - padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}; - l0 = prog.add_instruction(op::pad{padding}, l0); - } - else - { - op.padding[0] = padding[0]; - op.padding[1] = padding[1]; - } - } - if(contains(attributes, "strides")) - { - copy(attributes["strides"].ints(), op.stride.begin()); - } - if(contains(attributes, "dilations")) - { - copy(attributes["dilations"].ints(), op.dilation.begin()); - } - if(contains(attributes, "auto_pad")) - { - auto s = attributes["auto_pad"].s(); - if(contains(attributes, "pads") and to_upper(s) != "NOTSET") - { - MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); - } - - if(s.find("SAME") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::same; - } - } - if(contains(attributes, "group")) - { - op.group = parse_value(attributes.at("group")).at(); - } - if(args.size() == 3) - { - uint64_t axis = 1; - auto l1 = prog.add_instruction(op, l0, args[1]); - auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]); - return prog.add_instruction(op::add{}, l1, l2); - } - return prog.add_instruction(op, l0, args[1]); - } - - instruction_ref parse_pooling(const std::string& name, - attribute_map attributes, - std::vector args) - { - op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"}; - auto l0 = args[0]; - if(starts_with(name, "Global")) - { - auto lens = args.front()->get_shape().lens(); - op.lengths = {lens[2], lens[3]}; - } - if(contains(attributes, "pads")) - { - std::vector padding; - copy(attributes["pads"].ints(), std::back_inserter(padding)); - if(padding.size() != 4) - { - MIGRAPHX_THROW("padding should have 4 values"); - } - if(padding[0] != padding[2] || padding[1] != padding[3]) - { - // insert zeros for pad op (args[0] has 4 dims) - padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}; - l0 = prog.add_instruction(op::pad{padding, std::numeric_limits::lowest()}, - l0); - } - else - { - op.padding[0] = padding[0]; - op.padding[1] = padding[1]; - } - } - if(contains(attributes, "strides")) - { - copy(attributes["strides"].ints(), op.stride.begin()); - } - if(contains(attributes, "kernel_shape")) - { - copy(attributes["kernel_shape"].ints(), op.lengths.begin()); - } - if(contains(attributes, "auto_pad")) - { - auto s = attributes["auto_pad"].s(); - if(s.find("SAME_UPPER") == std::string::npos) - { - MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling"); - } - op.padding_mode = op::padding_mode_t::same; - } - - return prog.add_instruction(op, l0); - } - - instruction_ref - parse_reshape(const std::string&, attribute_map attributes, std::vector args) - { - op::reshape op; - if(args.size() == 1) - { - literal s = parse_value(attributes.at("shape")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); - } - if(args.size() == 2) - { - auto s = args[1]->eval(); - check_arg_empty(s, "Reshape: dynamic shape is not supported"); - s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); - } - - return prog.add_instruction(op, make_contiguous(args[0])); - } - - instruction_ref - parse_flatten(const std::string&, attribute_map attributes, std::vector args) - { - uint64_t axis = 1; - if(contains(attributes, "axis")) - { - axis = parse_value(attributes.at("axis")).at(); - } - return prog.add_instruction(op::flatten{axis}, args[0]); - } - - instruction_ref - parse_squeeze(const std::string&, attribute_map attributes, std::vector args) - { - op::squeeze op; - literal s = parse_value(attributes.at("axes")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); - return prog.add_instruction(op, args[0]); - } - - instruction_ref - parse_unsqueeze(const std::string&, attribute_map attributes, std::vector args) - { - op::unsqueeze op; - literal s = parse_value(attributes.at("axes")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); - return prog.add_instruction(op, args[0]); - } - - instruction_ref - parse_concat(const std::string&, attribute_map attributes, std::vector args) - { - // change to hande axis to be negative values - if(!contains(attributes, "axis")) - { - MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!"); - } - - int axis = parse_value(attributes.at("axis")).at(); - op::concat op{axis}; - return prog.add_instruction(op, std::move(args)); - } - - instruction_ref - parse_gather(const std::string&, attribute_map attributes, std::vector args) - { - int axis = 0; - if(contains(attributes, "axis")) - { - axis = parse_value(attributes.at("axis")).at(); - } - - op::gather op{axis}; - return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1])); - } - - instruction_ref - parse_slice(const std::string&, attribute_map attributes, std::vector args) - { - op::slice op; - std::vector dims = args[0]->get_shape().lens(); - size_t num_dims = dims.size(); - if(contains(attributes, "axes")) - { - literal s = parse_value(attributes.at("axes")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); - } - else - { - op.axes = std::vector(num_dims); - std::iota(op.axes.begin(), op.axes.end(), 0); - } - - if(contains(attributes, "ends")) - { - op.ends = get_indices(attributes.at("ends")); - } - if(contains(attributes, "starts")) - { - literal s = parse_value(attributes.at("starts")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); - } - return prog.add_instruction(op, args[0]); - } - - instruction_ref parse_constant(const std::string&, - attribute_map attributes, - const std::vector&) - { - literal v = parse_value(attributes.at("value")); - // return empty literal - if(v.get_shape().elements() == 0) - { - return prog.add_literal(literal{}); - } - - auto dim_size = attributes.at("value").t().dims_size(); - // if dim_size is 0, it is a scalar - if(dim_size == 0) - { - migraphx::shape scalar_shape{v.get_shape().type()}; - return prog.add_literal(migraphx::literal{scalar_shape, v.data()}); - } - - return prog.add_literal(v); - } - - instruction_ref - parse_gemm(const std::string&, attribute_map attributes, std::vector args) - { - float alpha = 1.0f; - float beta = 1.0f; - bool transa = false; - bool transb = false; - if(contains(attributes, "alpha")) - { - alpha = parse_value(attributes.at("alpha")).at(); - } - if(contains(attributes, "beta")) - { - beta = parse_value(attributes.at("beta")).at(); - } - if(contains(attributes, "transA")) - { - transa = parse_value(attributes.at("transA")).at(); - } - if(contains(attributes, "transB")) - { - transb = parse_value(attributes.at("transB")).at(); - } - - std::vector perm(args[0]->get_shape().lens().size()); - std::iota(perm.begin(), perm.end(), int64_t{0}); - // swap the last two elements - std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); - - auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; - auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; - if(args.size() == 3) - { - if(beta != 0.f && args[2]->get_shape().elements() > 0) - { - auto out_lens = l1->get_shape().lens(); - out_lens.back() = l2->get_shape().lens().back(); - auto l3 = args[2]; - auto l3_lens = l3->get_shape().lens(); - if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) - { - l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]); - } - return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3); - } - } - - return prog.add_instruction(op::dot{alpha, beta}, l1, l2); - } - - instruction_ref - parse_matmul(const std::string&, const attribute_map&, std::vector args) - { - auto l0 = args[0]; - auto l1 = args[1]; - auto l0_lens = l0->get_shape().lens(); - auto l1_lens = l1->get_shape().lens(); - - // args[0] is a vector, prepend 1 to the shape - bool is_a_prepended = false; - if(l0_lens.size() == 1) - { - is_a_prepended = true; - l0_lens.insert(l0_lens.begin(), 1); - l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]); - } - - bool is_b_appended = false; - if(l1_lens.size() == 1) - { - is_b_appended = true; - l1_lens.push_back(1); - l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]); - } - - instruction_ref bl0 = l0; - instruction_ref bl1 = l1; - if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend())) - { - auto l0_it = l0_lens.begin() + l0_lens.size() - 2; - std::vector l0_broadcasted_lens(l0_lens.begin(), l0_it); - auto l1_it = l1_lens.begin() + l1_lens.size() - 2; - std::vector l1_broadcasted_lens(l1_lens.begin(), l1_it); - auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); - l0_broadcasted_lens = output_lens; - l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end()); - l1_broadcasted_lens = output_lens; - l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); - if(l0_lens != l0_broadcasted_lens) - { - bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0); - } - if(l1_lens != l1_broadcasted_lens) - { - bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1); - } - } - - auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1); - int64_t num_axis = static_cast(dot_res->get_shape().lens().size()); - if(is_a_prepended) - { - dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res); - --num_axis; - } - if(is_b_appended) - { - dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res); - } - - return dot_res; - } - - instruction_ref - parse_batchnorm(const std::string&, attribute_map attributes, std::vector args) - { - float epsilon = 1e-5f; - float momentum = 0.9f; - op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; - if(contains(attributes, "epsilon")) - { - epsilon = parse_value(attributes.at("epsilon")).at(); - } - if(contains(attributes, "momentum")) - { - momentum = parse_value(attributes.at("momentum")).at(); - } - if(contains(attributes, "spatial")) - { - bn_mode = (parse_value(attributes.at("spatial")).at() > 0) - ? op::batch_norm_inference::spatial - : op::batch_norm_inference::per_activation; - } - op::batch_norm_inference op{epsilon, momentum, bn_mode}; - return prog.add_instruction(op, std::move(args)); - } - - instruction_ref parse_leaky_relu(const std::string&, - attribute_map attributes, - std::vector args) - { - float alpha = 0.01; // default alpha val for leaky relu - if(contains(attributes, "alpha")) - { - alpha = parse_value(attributes.at("alpha")).at(); - } - op::leaky_relu op{alpha}; - return prog.add_instruction(op, args.front()); - } - - instruction_ref - parse_elu(const std::string&, attribute_map attributes, std::vector args) - { - float alpha = 1.0; // default alpha val for elu - if(contains(attributes, "alpha")) - { - alpha = parse_value(attributes.at("alpha")).at(); - } - op::elu op{alpha}; - return prog.add_instruction(op, args.front()); - } - - instruction_ref - parse_lrn(const std::string&, attribute_map attributes, std::vector args) - { - float alpha = 0.0001; - float beta = 0.75; - float bias = 1.0; - int size = 1; - if(contains(attributes, "alpha")) - alpha = parse_value(attributes.at("alpha")).at(); - if(contains(attributes, "beta")) - beta = parse_value(attributes.at("beta")).at(); - if(contains(attributes, "bias")) - bias = parse_value(attributes.at("bias")).at(); - if(contains(attributes, "size")) - size = parse_value(attributes.at("size")).at(); - op::lrn op{alpha, beta, bias, size}; - return prog.add_instruction(op, args.front()); - } - - instruction_ref parse_imagescaler(const std::string&, - attribute_map attributes, - std::vector args) - { - float scale = 1.0; - std::vector bias{}; - if(contains(attributes, "scale")) - { - scale = parse_value(attributes.at("scale")).at(); - } - - if(contains(attributes, "bias")) - { - auto&& bias_floats = attributes["bias"].floats(); - bias = std::vector(bias_floats.begin(), bias_floats.end()); - } - auto input_lens = args.front()->get_shape().lens(); - - auto scale_val = prog.add_literal(scale); - auto bias_vals = prog.add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias}); - - auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val); - auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); - auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals); - return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); - } - - instruction_ref - parse_transpose(const std::string&, attribute_map attributes, std::vector args) - { - std::vector perm{}; - if(contains(attributes, "perm")) - { - auto&& perm_vals = attributes["perm"].ints(); - perm = std::vector(perm_vals.begin(), perm_vals.end()); - } - return prog.add_instruction(migraphx::op::transpose{perm}, args.front()); - } - - instruction_ref - parse_pad(const std::string&, attribute_map attributes, std::vector args) - { - std::vector pads{}; - float value = 0.0f; - if(contains(attributes, "pads")) - { - auto&& pad_vals = attributes["pads"].ints(); - pads = std::vector(pad_vals.begin(), pad_vals.end()); - } - // check if padding is actually being done (at least one value is nonzero) - if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) - { - return prog.add_instruction(migraphx::op::identity{}, args.front()); - } - if(contains(attributes, "value")) - { - value = parse_value(attributes.at("value")).at(); - } - if(contains(attributes, "mode")) - { - auto mode = attributes.at("mode").s(); - if(mode != "constant") - MIGRAPHX_THROW("migraphx currently only supports constant padding"); - } - return prog.add_instruction(migraphx::op::pad{pads, value}, args.front()); - } - // Use a literal instruction to replace the shape since, output of - // shape operator are literals in migraphx - instruction_ref - parse_shape(const std::string&, const attribute_map&, std::vector args) - { - if(args.size() != 1) - MIGRAPHX_THROW("Shape: operator should have 1 operand"); - std::vector arg_shape = args[0]->get_shape().lens(); - std::vector vec_shape(arg_shape.size()); - migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()}); - std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { - return int64_t(i); - }); - return prog.add_literal(migraphx::literal{s, vec_shape}); - } - - // Use a literal instruction to replace the constantFill operator. In RNN, input shape - // and value are fixed, so no need to do the actual computation for the constantFill - // operator - instruction_ref parse_constant_fill(const std::string&, - attribute_map attributes, - std::vector args) - { - int input_as_shape = 0; - int dtype = 1; - float value = 0.0f; - - if(contains(attributes, "dtype")) - { - dtype = parse_value(attributes.at("dtype")).at(); - } - shape::type_t type = get_type(dtype); - - if(contains(attributes, "input_as_shape")) - { - input_as_shape = parse_value(attributes.at("input_as_shape")).at(); - } - - if(contains(attributes, "value")) - { - value = parse_value(attributes.at("value")).at(); - } - - if(contains(attributes, "extra_shape")) - { - MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute"); - } - - if(input_as_shape == 1) - { - if(args.size() != 1) - { - MIGRAPHX_THROW("ConstantFill: need an input argument as output shape"); - } - - if(contains(attributes, "shape")) - { - MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input " - "at the same time"); - } - - migraphx::argument in = args[0]->eval(); - check_arg_empty(in, "ConstantFill: dynamic shape is not supported"); - - std::vector dims; - in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); - migraphx::shape s(type, dims); - std::vector values(s.elements(), value); - return prog.add_literal(migraphx::literal(s, values)); - } - else if(input_as_shape == 0) - { - if(!contains(attributes, "shape")) - { - MIGRAPHX_THROW("ConstantFill: attribute output shape is needed"); - } - - literal ls = parse_value(attributes.at("shape")); - std::vector dims; - ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); - migraphx::shape s{type, dims}; - std::vector values(s.elements(), value); - return prog.add_literal(migraphx::literal(s, values)); - } - else - { - MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape"); - } - } - - instruction_ref parse_constant_of_shape(const std::string&, - attribute_map attributes, - std::vector args) + if(options.print_program_on_error) { - literal l_val{}; - if(contains(attributes, "value")) - { - l_val = parse_value(attributes.at("value")); - if(l_val.get_shape().elements() != 1) - { - MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!"); - } - } - else + // Log the program when it can't be parsed + try { - l_val = literal({shape::float_type, {1}, {0}}, {0.0f}); + parser.parse_from(std::forward(xs)...); } - - // input is empty, output is a scalar - auto type = l_val.get_shape().type(); - - if(args.empty()) - { - MIGRAPHX_THROW("ConstantOfShape : must have 1 input!"); - } - else + catch(...) { - migraphx::shape s; - // empty input tensor, output is a scalar - if(args[0]->get_shape().elements() == 0) - { - s = migraphx::shape{type, {1}, {0}}; - } - else - { - migraphx::argument in = args[0]->eval(); - check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported"); - - std::vector dims; - in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); - s = migraphx::shape{type, dims}; - } - - literal l_out{}; - l_val.visit([&](auto val) { - using val_type = std::remove_cv_t; - // l_val contains only one element - std::vector out_vec(s.elements(), val.front()); - l_out = literal(s, out_vec); - }); - - return prog.add_literal(l_out); + std::cerr << parser.prog << std::endl; + throw; } } - - instruction_ref - parse_expand(const std::string&, const attribute_map&, std::vector args) - { - auto in_lens = args[0]->get_shape().lens(); - migraphx::argument arg_s = args[1]->eval(); - check_arg_empty(arg_s, "Expand: dynamic shape is not supported"); - std::vector dims; - arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); - auto out_lens = compute_broadcasted_lens(in_lens, dims); - return prog.add_instruction(op::multibroadcast{out_lens}, args[0]); - } - - std::vector - parse_rnn(const std::string&, attribute_map attributes, std::vector args) + else { - migraphx::shape input_shape = args[0]->get_shape(); - std::size_t hidden_size = args[1]->get_shape().lens()[1]; - - if(contains(attributes, "hidden_size")) - { - std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at(); - if(hidden_size != hidden_size_att) - { - MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute"); - } - } - - // Handling of direction to be added later - std::string direction{"forward"}; - if(contains(attributes, "direction")) - { - direction = attributes.at("direction").s(); - } - - op::rnn_direction dirct = op::rnn_direction::forward; - if(direction == "bidirectional") - { - dirct = op::rnn_direction::bidirectional; - } - else if(direction == "reverse") - { - dirct = op::rnn_direction::reverse; - } - - std::vector vec_names{"tanh"}; - if(contains(attributes, "activations")) - { - auto names = attributes.at("activations").strings(); - vec_names.clear(); - vec_names.resize(names.size()); - std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { - return to_lower(name); - }); - } - - auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { - return (map_actv_funcs.count(name) == 0); - }); - if(name_it != vec_names.end()) - { - MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported"); - } - - // bidirectional case should have two activation functions. - // one is for forward, and the other is for reverse. - // if only one actv function is provided, we use it in both - // forward and reverse direction - if(dirct == op::rnn_direction::bidirectional) - { - if(vec_names.size() == 1) - { - vec_names.push_back(vec_names.at(0)); - } - } - - std::vector vec_actv_funcs(vec_names.size()); - std::transform(vec_names.begin(), - vec_names.end(), - vec_actv_funcs.begin(), - [&](const auto& fn) { return map_actv_funcs[fn]; }); - - // To be added later - float clip = 0.0; - if(contains(attributes, "clip")) - { - clip = parse_value(attributes.at("clip")).at(); - } - - // if the number of arguments is less than 6, append - // undefined operator to have 6 arguments - if(args.size() < 6) - { - auto ins = prog.add_instruction(op::undefined{}); - args.insert(args.end(), (6 - args.size()), ins); - } - - // first output for the concatenation of hidden states - auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, - std::move(args)); - - // second output for the last hidden state - auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); - - return {hidden_states, last_output}; - } - - std::vector - parse_gru(const std::string&, attribute_map attributes, std::vector args) - { - migraphx::shape input_shape = args[0]->get_shape(); - std::size_t hidden_size = args[2]->get_shape().lens()[2]; - - if(contains(attributes, "hidden_size")) - { - std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at(); - if(hidden_size != hidden_size_att) - { - MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute"); - } - } - - // Handling of direction to be added later - std::string direction{"forward"}; - if(contains(attributes, "direction")) - { - direction = attributes.at("direction").s(); - } - - op::rnn_direction dirct = op::rnn_direction::forward; - if(direction == "bidirectional") - { - dirct = op::rnn_direction::bidirectional; - } - else if(direction == "reverse") - { - dirct = op::rnn_direction::reverse; - } - - std::vector vec_names = {"sigmoid", "tanh"}; - if(contains(attributes, "activations")) - { - auto names = attributes.at("activations").strings(); - vec_names.clear(); - vec_names.resize(names.size()); - std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { - return to_lower(name); - }); - } - - // need 4 activation functions - if(dirct == op::rnn_direction::bidirectional) - { - // 4 activation functions are used in the bidirectional - // scenario. No spec is provided in onnx::operator. we - // use the algorithm that: if 1 actv function is provided, - // repeat 1 four times. If 2 actv functins are provided, - // assume forward and reverse use the same pair of actv - // functions. For the case of 3 actv functions provided, - // assume the 3rd one is repeated once and used by the - // reverse direction. - // This may need change later - if(vec_names.size() == 1) - { - vec_names.insert(vec_names.end(), 3, vec_names.at(0)); - } - else if(vec_names.size() == 2) - { - // repeat the activation functions - vec_names.push_back(vec_names.at(0)); - vec_names.push_back(vec_names.at(1)); - } - else if(vec_names.size() == 3) - { - vec_names.push_back(vec_names.at(2)); - } - } - else - { - if(vec_names.size() == 1) - { - vec_names.push_back(vec_names.at(0)); - } - } - - auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { - return (map_actv_funcs.count(name) == 0); - }); - if(name_it != vec_names.end()) - { - MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported"); - } - - std::vector vec_actv_funcs(vec_names.size()); - std::transform(vec_names.begin(), - vec_names.end(), - vec_actv_funcs.begin(), - [&](const auto& name) { return map_actv_funcs[name]; }); - - float clip = 0.0; - if(contains(attributes, "clip")) - { - clip = parse_value(attributes.at("clip")).at(); - } - - int linear_before_reset = 0; - if(contains(attributes, "linear_before_reset")) - { - linear_before_reset = parse_value(attributes.at("linear_before_reset")).at(); - } - - // append undefined opeator to make 6 arguments - if(args.size() < 6) - { - auto ins = prog.add_instruction(op::undefined{}); - args.insert(args.end(), 6 - args.size(), ins); - } - - // first output for concatenation of hidden states - auto hidden_states = prog.add_instruction( - op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset}, - std::move(args)); - - // second output for last gru output - auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); - - return {hidden_states, last_output}; - } - - std::vector - parse_lstm(const std::string&, attribute_map attributes, std::vector args) - { - migraphx::shape input_shape = args[0]->get_shape(); - std::size_t hidden_size = args[2]->get_shape().lens()[2]; - - if(contains(attributes, "hidden_size")) - { - std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at(); - if(hidden_size != hidden_size_att) - { - MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute"); - } - } - - // Handling of direction to be added later - std::string direction{"forward"}; - if(contains(attributes, "direction")) - { - direction = attributes.at("direction").s(); - } - - op::rnn_direction dirct = op::rnn_direction::forward; - if(direction == "bidirectional") - { - dirct = op::rnn_direction::bidirectional; - } - else if(direction == "reverse") - { - dirct = op::rnn_direction::reverse; - } - else if(direction == "forward") - { - dirct = op::rnn_direction::forward; - } - else - { - MIGRAPHX_THROW("LSTM: incorrect direction attribute"); - } - - std::vector vec_names = {"sigmoid", "tanh", "tanh"}; - if(contains(attributes, "activations")) - { - auto names = attributes.at("activations").strings(); - vec_names.clear(); - vec_names.resize(names.size()); - std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { - return to_lower(name); - }); - } - - // need 6 activation functions for bidirectional directions - if(dirct == op::rnn_direction::bidirectional) - { - // 6 activation functions are used in the bidirectional - // scenario. No spec is provided in onnx::operator. we - // use the algorithm that: if 1 actv function is provided, - // repeat 1st six times. If 2 actv functins are provided, - // repeat 2nd once, then repeat all three once - // if 3 actv funcs are provide, repeat all three once. - // the same algorithm is used for 4, 5, and 6 actv funcions - // provided. This may need change later - switch(vec_names.size()) - { - case 1: - vec_names = {vec_names.at(0), - vec_names.at(0), - vec_names.at(0), - vec_names.at(0), - vec_names.at(0), - vec_names.at(0)}; - break; - - case 2: - // repeat the 2nd actv func once, then repeat all three another time - vec_names = {vec_names.at(0), - vec_names.at(1), - vec_names.at(1), - vec_names.at(0), - vec_names.at(1), - vec_names.at(1)}; - break; - - case 3: - // repeat all three actv funcs once - vec_names = {vec_names.at(0), - vec_names.at(1), - vec_names.at(2), - vec_names.at(0), - vec_names.at(1), - vec_names.at(2)}; - break; - - case 4: - vec_names = {vec_names.at(0), - vec_names.at(1), - vec_names.at(2), - vec_names.at(3), - vec_names.at(3), - vec_names.at(3)}; - break; - - case 5: - vec_names = {vec_names.at(0), - vec_names.at(1), - vec_names.at(2), - vec_names.at(3), - vec_names.at(4), - vec_names.at(4)}; - break; - - default: break; - } - } - else - { - switch(vec_names.size()) - { - case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break; - - case 2: - // repeat the 2nd actv func once, so we have 3 actv funcs - vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)}; - break; - - default: break; - } - } - - auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { - return (map_actv_funcs.count(name) == 0); - }); - if(name_it != vec_names.end()) - { - MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported"); - } - - std::vector vec_actv_funcs(vec_names.size()); - std::transform(vec_names.begin(), - vec_names.end(), - vec_actv_funcs.begin(), - [&](const auto& name) { return map_actv_funcs[name]; }); - - float clip = 0.0; - if(contains(attributes, "clip")) - { - clip = parse_value(attributes.at("clip")).at(); - } - - int input_forget = 0; - if(contains(attributes, "input_forget")) - { - input_forget = parse_value(attributes.at("input_forget")).at(); - } - - // append undefined opeator to make 6 arguments - if(args.size() < 8) - { - auto ins = prog.add_instruction(op::undefined{}); - args.insert(args.end(), 8 - args.size(), ins); - } - - // first output for concatenation of hidden states - auto hidden_states = prog.add_instruction( - op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args)); - - // second output for last lstm output - auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); - - // third output for last cell output - auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states); - - return {hidden_states, last_output, last_cell_output}; - } - - template - instruction_ref parse_reduce_oper(const std::string&, - attribute_map attributes, - std::vector args) - { - std::size_t n_dim = args.front()->get_shape().lens().size(); - - // default to reduce over all dimensions - std::vector axes(n_dim); - std::iota(axes.begin(), axes.end(), 0); - if(contains(attributes, "axes")) - { - axes.clear(); - auto&& attr_axes = attributes["axes"].ints(); - axes = std::vector(attr_axes.begin(), attr_axes.end()); - } - - int keep_dims = 1; - if(contains(attributes, "keepdims")) - { - keep_dims = parse_value(attributes.at("keepdims")).at(); - } - - if(keep_dims == 1) - { - return prog.add_instruction(T{axes}, std::move(args)); - } - else - { - auto ins = prog.add_instruction(T{axes}, std::move(args)); - return prog.add_instruction(op::squeeze{axes}, ins); - } - } - - instruction_ref - parse_cast(const std::string&, attribute_map attributes, std::vector args) - { - if(!contains(attributes, "to")) - { - MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!"); - } - - int to_type = parse_value(attributes.at("to")).at(); - shape::type_t type = get_type(to_type); - return prog.add_instruction(op::convert{type}, std::move(args)); - } - - void parse_from(std::istream& is) - { - onnx::ModelProto model; - if(model.ParseFromIstream(&is)) - { - if(model.has_graph()) - { - this->parse_graph(model.graph()); - } - } - else - { - MIGRAPHX_THROW("Failed reading onnx file."); - } - } - - void parse_graph(const onnx::GraphProto& graph) - { - nodes = get_nodes(graph); - for(auto&& f : graph.initializer()) - instructions[f.name()] = prog.add_literal(parse_tensor(f)); - - for(auto&& input : graph.input()) - { - const std::string& name = input.name(); - // input not in initializer_data, so it is a real input - if(!contains(instructions, name)) - { - // TODO: Get shape of input parameter - shape s = parse_type(input.type()); - instructions[name] = prog.add_parameter(name, s); - } - } - for(auto&& output : graph.output()) - { - this->parse_node(output.name()); - } - } - - void parse_undefined(const std::string& name) - { - auto ins = prog.add_instruction(op::undefined{}); - instructions[name] = ins; - } - - void parse_node(const std::string& name) - { - if(name.empty()) - MIGRAPHX_THROW("Onnx node must have a name"); - if(instructions.count(name) == 0) - { - auto&& node = nodes.at(name); - std::vector args; - for(auto&& input : node.input()) - { - if(nodes.count(input) > 0) - { - assert(name != input); - this->parse_node(input); - } - else if(input.empty()) - { - this->parse_undefined(input); - } - args.push_back(instructions.at(input)); - } - std::vector result; - if(ops.count(node.op_type()) == 0) - { - result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args)); - } - else - { - result = ops[node.op_type()](get_attributes(node), args); - } - // Even no output nodes produce output in migraphx - if(node.output().empty() and result.size() == 1) - { - instructions[name] = result.front(); - } - else - { - assert(node.output().size() >= result.size()); - std::transform(result.begin(), - result.end(), - node.output().begin(), - std::inserter(instructions, instructions.end()), - [](auto&& x, auto&& y) { return std::make_pair(y, x); }); - } - } - } - - static attribute_map get_attributes(const onnx::NodeProto& node) - { - std::unordered_map result; - for(auto&& attr : node.attribute()) - { - result[attr.name()] = attr; - } - return result; - } - - static node_map get_nodes(const onnx::GraphProto& graph) - { - std::unordered_map result; - std::size_t n = 0; - for(auto&& node : graph.node()) - { - if(node.output().empty()) - { - if(node.name().empty()) - { - result["migraphx_unamed_node_" + std::to_string(n)] = node; - n++; - } - else - { - result[node.name()] = node; - } - } - for(auto&& output : node.output()) - { - result[output] = node; - } - } - return result; - } - - static std::vector get_indices(const onnx::AttributeProto& attr) - { - std::vector result; - literal s = parse_value(attr); - s.visit([&](auto v) { copy(v, std::back_inserter(result)); }); - // Clamp large indices to -1 - std::replace_if( - result.begin(), - result.end(), - [](auto x) { return x > int64_t{std::numeric_limits::max()} / 2; }, - -1); - return result; - } - - template - static literal from_repeated(shape::type_t t, const T& r) - { - std::size_t size = r.size(); - return literal{{t, {size}}, r.begin(), r.end()}; - } - - static literal parse_value(const onnx::AttributeProto& attr) - { - switch(attr.type()) - { - case onnx::AttributeProto::FLOAT: return literal{attr.f()}; - case onnx::AttributeProto::INT: return literal{attr.i()}; - case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t()); - case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats()); - case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints()); - case onnx::AttributeProto::UNDEFINED: - case onnx::AttributeProto::GRAPH: - case onnx::AttributeProto::STRING: - case onnx::AttributeProto::STRINGS: - case onnx::AttributeProto::TENSORS: - case onnx::AttributeProto::GRAPHS: return {}; - } - MIGRAPHX_THROW("Invalid attribute type"); - } - - static literal parse_tensor(const onnx::TensorProto& t) - { - std::vector dims(t.dims().begin(), t.dims().end()); - if(t.has_raw_data()) - { - const std::string& s = t.raw_data(); - switch(t.data_type()) - { - case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data()); - case onnx::TensorProto::FLOAT16: - return create_literal(shape::half_type, dims, s.data()); - case onnx::TensorProto::DOUBLE: - return create_literal(shape::double_type, dims, s.data()); - case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data()); - case onnx::TensorProto::INT8: - case onnx::TensorProto::UINT16: - case onnx::TensorProto::INT16: - case onnx::TensorProto::INT32: - case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data()); - case onnx::TensorProto::UINT8: - case onnx::TensorProto::STRING: - case onnx::TensorProto::UNDEFINED: - case onnx::TensorProto::UINT32: - case onnx::TensorProto::UINT64: - case onnx::TensorProto::COMPLEX64: - case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); - } - MIGRAPHX_THROW("Invalid tensor type"); - } - switch(t.data_type()) - { - case onnx::TensorProto::INT8: - case onnx::TensorProto::UINT16: - case onnx::TensorProto::INT16: - case onnx::TensorProto::INT32: - case onnx::TensorProto::BOOL: - return create_literal(shape::int32_type, dims, t.int32_data()); - case onnx::TensorProto::INT64: - return create_literal(shape::int64_type, dims, t.int64_data()); - case onnx::TensorProto::DOUBLE: - return create_literal(shape::double_type, dims, t.double_data()); - case onnx::TensorProto::FLOAT: - return create_literal(shape::float_type, dims, t.float_data()); - case onnx::TensorProto::FLOAT16: - { - std::vector data_uint16(t.int32_data().begin(), t.int32_data().end()); - std::vector data_half; - std::transform(data_uint16.begin(), - data_uint16.end(), - std::back_inserter(data_half), - [](uint16_t raw_val) { return *reinterpret_cast(&raw_val); }); - return create_literal(shape::half_type, dims, data_half); - } - case onnx::TensorProto::UNDEFINED: - case onnx::TensorProto::UINT8: - case onnx::TensorProto::STRING: - case onnx::TensorProto::UINT32: - case onnx::TensorProto::UINT64: - case onnx::TensorProto::COMPLEX64: - case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); - } - MIGRAPHX_THROW("Invalid tensor type"); - } - - static literal - create_literal(shape::type_t shape_type, const std::vector& dims, const char* data) - { - // in case of scalar constants in onnx file, use dims=1 to fill initializer data - if(dims.empty()) - return literal{{shape_type}, data}; - return literal{{shape_type, dims}, data}; - } - - template {})> - static literal create_literal(shape::type_t shape_type, const std::vector& dims, T data) - { - if(dims.empty()) - return literal{{shape_type}, data.begin(), data.end()}; - return literal{{shape_type, dims}, data.begin(), data.end()}; - } - - static shape parse_type(const onnx::TypeProto& t) - { - shape::type_t shape_type{}; - switch(t.tensor_type().elem_type()) - { - case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break; - case onnx::TensorProto::INT8: shape_type = shape::int8_type; break; - case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break; - case onnx::TensorProto::INT16: shape_type = shape::int16_type; break; - case onnx::TensorProto::INT32: shape_type = shape::int32_type; break; - case onnx::TensorProto::INT64: shape_type = shape::int64_type; break; - case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break; - case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break; - case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break; - case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break; - case onnx::TensorProto::UINT8: - case onnx::TensorProto::STRING: - case onnx::TensorProto::BOOL: - case onnx::TensorProto::UNDEFINED: - case onnx::TensorProto::COMPLEX64: - case onnx::TensorProto::COMPLEX128: - break; // throw std::runtime_error("Unsupported type"); - } - std::vector dims; - auto&& tensor_dims = t.tensor_type().shape().dim(); - std::transform(tensor_dims.begin(), - tensor_dims.end(), - std::back_inserter(dims), - [](auto&& d) -> std::size_t { - if(not d.has_dim_value()) - { - long default_batch_size = 1; // FIXME - return default_batch_size; - } - return d.dim_value(); - }); - return {shape_type, dims}; + parser.parse_from(std::forward(xs)...); } + return std::move(parser.prog); +} - shape::type_t get_type(int dtype) - { - switch(dtype) - { - case 1: return shape::float_type; - case 2: return shape::uint8_type; - case 3: return shape::int8_type; - case 4: return shape::uint16_type; - case 5: return shape::int16_type; - case 6: return shape::int32_type; - case 7: return shape::int64_type; - case 10: return shape::half_type; - case 11: return shape::double_type; - case 12: return shape::uint32_type; - case 13: return shape::uint64_type; - default: - { - MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); - } - } - } +program parse_onnx(const std::string& name, const onnx_options& options) +{ + std::fstream input(name.c_str(), std::ios::in | std::ios::binary); + return parse_onnx_from(options, input, name); +} - void check_arg_empty(const argument& arg, const std::string& msg) - { - if(arg.empty()) - { - MIGRAPHX_THROW(msg); - } - } -}; +program parse_onnx_buffer(const std::string& buffer, const onnx_options& options) +{ + return parse_onnx_from(options, buffer.data(), buffer.size()); +} -program parse_onnx(const std::string& name) +program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options) { - std::fstream input(name.c_str(), std::ios::in | std::ios::binary); - onnx_parser parser; -#ifndef NDEBUG - // Log the program when it can't be parsed - try - { - parser.parse_from(input); - } - catch(...) - { - std::cerr << parser.prog << std::endl; - throw; - } -#else - parser.parse_from(input); -#endif - return std::move(parser.prog); + return parse_onnx_from(options, data, size); } +std::vector get_onnx_operators() { return onnx::get_op_parsers(); } + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/onnx/onnx.proto b/src/onnx/onnx.proto index b577ff0868395e18771ceb46d95fb3f82601e4b4..112e39191c6790af5be40c5127129e25f737cb37 100644 --- a/src/onnx/onnx.proto +++ b/src/onnx/onnx.proto @@ -3,24 +3,42 @@ // -// Copyright (c) Facebook Inc. and Microsoft Corporation. +// Copyright (c) ONNX Project Contributors. // Licensed under the MIT license. syntax = "proto2"; -package onnx; +package onnx_for_migraphx; -// Note [Release] +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// // We are still in the very early stage of defining ONNX. The current // version of ONNX is a starting point. While we are actively working // towards a complete spec, we would like to get the community involved // by sharing our working version of ONNX. - -// Note [Protobuf compatibility] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Based on experience working with downstream vendors, we generally can't -// assume recent versions of protobufs. This means that we do not use any -// protobuf features that are only available in proto3. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. // // Here are the most notable contortions we have to carry out to work around // these limitations: @@ -29,30 +47,11 @@ package onnx; // of key-value pairs, where order does not matter and duplicates // are not allowed. -// Note [Namespaces] -// ~~~~~~~~~~~~~~~~~ -// ONNX gives explicit names to graphs, intermediate values and -// serialized tensors. To make it easier to generate names, we organize -// these into separate namespaces (so, e.g., a graph can have the same -// name as a serialized tensor.) The namespaces are as follows: -// -// - Node: These names identify specific nodes in the graph (but not, necessarily -// any particular input or output of the node. -// - Graph: These names identify graphs in the protobuf. -// - Attribute: These names identify attribute names for extra attributes that -// are passed to operators. -// - Operator: These names identify particular operators. -// - Value: These names identify intermediate values (typically tensors) flowing through -// the computation of a graph. -// - Shape: These names represent parameters for unknown shape dimensions. + +// Versioning // -// We specify the namespace of a name in ONNX as comments in the form -// of "namespace {Node,Graph,Operator,Attribute,Value,Shape}". Framework is responsible -// for supporting the namespaces. +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md // -// Naming things is hard. Every element with a name has an optional doc_string associated -// with it, providing a human-readable description in text markdown. - // To be compatible with both proto2 and proto3, we will use a version number // that is not defined by the default value but an explicit enum number. enum Version { @@ -61,26 +60,53 @@ enum Version { _START_VERSION = 0; // The version field is always serialized and we will use it to store the // version that the graph is generated from. This helps us set up version - // control. We should use version as - // xx(major) - xx(minor) - xxxx(bugfix) - // and we are starting with 0x00000001 (0.0.1), which was the - // version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x00000001; + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; - // IR_VERSION 0.0.2 published on Oct 30, 2017 + // IR_VERSION 2 published on Oct 30, 2017 // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x00000002; + IR_VERSION_2017_10_30 = 0x0000000000000002; - // IR VERSION 0.0.3 published on Nov 3, 2017 + // IR VERSION 3 published on Nov 3, 2017 // - For operator versioning: // - Added new message OperatorSetIdProto // - Added opset_import in ModelProto // - For vendor extensions, added domain in NodeProto - IR_VERSION = 0x00000003; + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Make inference graph callable from TrainingInfoProto via GraphCall operator. + IR_VERSION = 0x0000000000000007; } -// A named attribute containing either singular float, integer, string -// and tensor values, or repeated float, integer, string and tensor values. +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. // An AttributeProto MUST contain the name field, and *only one* of the // following content fields, effectively enforcing a C/C++ union equivalent. message AttributeProto { @@ -94,26 +120,34 @@ message AttributeProto { STRING = 3; TENSOR = 4; GRAPH = 5; + SPARSE_TENSOR = 11; FLOATS = 6; INTS = 7; STRINGS = 8; TENSORS = 9; GRAPHS = 10; + SPARSE_TENSORS = 12; } // The name field MUST be present for this version of the IR. optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; // A human-readable documentation for this attribute. Markdown is allowed. optional string doc_string = 13; // The type field MUST be present for this version of the IR. // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine + // implementations needed to use has_field heuristics to determine // which value field was in use. For IR_VERSION 0.0.2 or later, this // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. + // change was made to accommodate proto3 implementations. optional AttributeType type = 20; // discriminator that indicates which field below is in use // Exactly ONE of the following fields must be present for this version of the IR @@ -122,6 +156,7 @@ message AttributeProto { optional bytes s = 4; // UTF-8 string optional TensorProto t = 5; // tensor value optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value // Do not use field below, it's deprecated. // optional ValueProto v = 12; // value - subsumes everything but graph @@ -130,6 +165,7 @@ message AttributeProto { repeated bytes strings = 9; // list of UTF-8 strings repeated TensorProto tensors = 10; // list of tensors repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors } // Defines information on value, including the name, the type, and @@ -137,16 +173,20 @@ message AttributeProto { message ValueInfoProto { // This field MUST be present in this version of the IR. optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR. + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. optional TypeProto type = 2; // A human-readable documentation for this value. Markdown is allowed. optional string doc_string = 3; } -// NodeProto stores a node that is similar to the notion of "layer" -// or "operator" in many deep learning frameworks. For example, it can be a -// node of type "Conv" that takes in an image, a filter tensor and a bias -// tensor, and produces the convolved output. +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. message NodeProto { repeated string input = 1; // namespace Value repeated string output = 2; // namespace Value @@ -161,18 +201,125 @@ message NodeProto { optional string domain = 7; // namespace Domain // Additional named attributes. - // NOTE: Simply using ValueProto.NameValuePairProto is the most general - // solution. I kept AttributeProto to minimize churn on CI results. repeated AttributeProto attribute = 5; // A human-readable documentation for this node. Markdown is allowed. optional string doc_string = 6; } -// ModelProto is a top-level file/container format for bundling a ML model. -// The semantics of the model are described by the GraphProto that represents -// a parameterized computation graph against a set of named operators that are -// defined independently from the graph. +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been consumed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update stages (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each stage. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. + optional GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this graph contains loss node, gradient node, + // optimizer node, increment of iteration count, and some calls to the inference + // graph. + // + // The field algorithm.node is the only place the user can use GraphCall + // operator. The only callable graph is the one stored in ModelProto.graph. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. + optional GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable and globally-visible variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two global + // variables may not have the same name. This ensures that one + // global variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm". + // 4. If an optional input of a graph is omitted when using GraphCall, the + // global variable with the same name may be used. + // 5. When using GraphCall, the users always can pass values to optional + // inputs of the called graph even if the associated initializers appears + // as keys in "update_binding"s. + // 6. The graphs in TrainingInfoProto's can use global variables as + // their operator inputs. + // 7. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto's. message ModelProto { // The version of the IR this model targets. See Version enum above. // This field MUST be present. @@ -217,6 +364,17 @@ message ModelProto { // Named metadata values; keys should be distinct. repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; }; // StringStringEntryProto follows the pattern for cross-proto-version maps. @@ -226,25 +384,38 @@ message StringStringEntryProto { optional string value= 2; }; -// GraphProto defines a parameterized series of nodes to form a directed acyclic graph. -// This is the equivalent of the "network" and "graph" in many deep learning +message TensorAnnotation { + optional string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning // frameworks. message GraphProto { - // The nodes in the graph. + // The nodes in the graph, sorted topologically. repeated NodeProto node = 1; // The name of the graph. optional string name = 2; // namespace Graph - // A list of named tensor values (constants), used to specify default - // values for some of the inputs of the graph. + // A list of named tensor values, used to specify constant inputs of the graph. // Each TensorProto entry must have a distinct name (within the list) that - // also appears in the input list. - // In an evaluation, the default value specified here is used if and only if - // user specifies no value for the corresponding input parameter. - // May be used to pass serialized parameters for networks. + // MAY also appear in the input list. repeated TensorProto initializer = 5; + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + // A human-readable documentation for this graph. Markdown is allowed. optional string doc_string = 10; @@ -256,7 +427,13 @@ message GraphProto { // must be distinct. It is optional for a value to appear in value_info list. repeated ValueInfoProto value_info = 13; - // DO NOT USE the following fields, they were deprecated before + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. // repeated string input = 3; // repeated string output = 4; // optional int64 ir_version = 6; @@ -265,7 +442,9 @@ message GraphProto { // optional string domain = 9; } -// A message defined to store a tensor in its serialized format. +// Tensors +// +// A serialized tensor value. message TensorProto { enum DataType { UNDEFINED = 0; @@ -280,13 +459,21 @@ message TensorProto { STRING = 8; // string BOOL = 9; // bool - // Advanced types + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. FLOAT16 = 10; + DOUBLE = 11; UINT32 = 12; UINT64 = 13; COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + // Future extensions go here. } @@ -294,7 +481,8 @@ message TensorProto { repeated int64 dims = 1; // The data type of the tensor. - optional DataType data_type = 2; + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; // For very large tensors, we may want to store them in chunks, in which // case the following fields will specify the segment that is stored in @@ -305,7 +493,7 @@ message TensorProto { } optional Segment segment = 3; - // Tensor content must be in the row major order. + // Tensor content must be organized in row-major order. // // Depending on the data_type field, exactly one of the fields below with // name ending in _data is used to store the elements of the tensor. @@ -313,7 +501,7 @@ message TensorProto { // For float and complex64 values // Complex64 tensors are encoded as a single array of floats, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -323,7 +511,7 @@ message TensorProto { // float16 values must be bit-wise converted to an uint16_t prior // to writing to the buffer. // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 repeated int32 int32_data = 5 [packed = true]; // For strings. @@ -360,10 +548,32 @@ message TensorProto { // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED optional bytes raw_data = 9; + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + // For double - // Complex64 tensors are encoded as a single array of doubles, + // Complex128 tensors are encoded as a single array of doubles, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -375,6 +585,28 @@ message TensorProto { repeated uint64 uint64_data = 11 [packed = true]; } +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + // Defines a tensor shape. A dimension can be either an integer value // or a symbolic variable. A symbolic variable represents an unknown // dimension. @@ -384,28 +616,73 @@ message TensorShapeProto { int64 dim_value = 1; string dim_param = 2; // namespace Shape }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; }; repeated Dimension dim = 1; } -// Define the types. +// Types +// +// The standard ONNX data types. message TypeProto { message Tensor { // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. - optional TensorProto.DataType elem_type = 1; + optional int32 elem_type = 1; optional TensorShapeProto shape = 2; } + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + oneof value { // The type of a tensor. Tensor tensor_type = 1; + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; } +// Operator Sets +// // OperatorSets are uniquely identified by a (domain, opset_version) pair. message OperatorSetIdProto { // The domain of the operator set being identified. @@ -418,3 +695,8 @@ message OperatorSetIdProto { // This field MUST be present in this version of the IR. optional int64 version = 2; } + + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; + diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf6cded0d1aa3b22b6495c736aaf387080426dfb --- /dev/null +++ b/src/onnx/onnx_parser.cpp @@ -0,0 +1,469 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) +{ + std::unordered_map result; + for(auto&& attr : node.attribute()) + { + result[attr.name()] = attr; + } + return result; +} + +static literal +create_literal(shape::type_t shape_type, const std::vector& dims, const char* data) +{ + // empty input + auto elem_num = + std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies()); + if(elem_num == 0) + { + return {}; + } + + // in case of scalar constants in onnx file, use dims=1 to fill initializer data + if(dims.empty()) + return literal{{shape_type}, data}; + return literal{{shape_type, dims}, data}; +} + +template {})> +static literal create_literal(shape::type_t shape_type, const std::vector& dims, T data) +{ + // empty input + auto elem_num = + std::accumulate(dims.begin(), dims.end(), std::size_t(1), std::multiplies()); + if(elem_num == 0) + { + return {}; + } + + // scalar input + if(dims.empty()) + return literal{{shape_type}, data.begin(), data.end()}; + return literal{{shape_type, dims}, data.begin(), data.end()}; +} + +template +static literal from_repeated(shape::type_t t, const T& r) +{ + std::size_t size = r.size(); + return literal{{t, {size}}, r.begin(), r.end()}; +} + +instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const +{ + auto attr = ins->get_operator().to_value(); + std::string key = "require_std_shape"; + if((attr.get(key, false)) or (not ins->get_shape().standard())) + { + return add_instruction(make_op("contiguous"), ins); + } + + return ins; +} + +instruction_ref onnx_parser::node_info::add_bias(const std::vector& args, + instruction_ref curr_ins, + uint64_t axis) const +{ + if(args.size() == 3) + { + auto bias_bcast = mod->add_instruction( + make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}), + args[2]); + return mod->add_instruction(make_op("add"), curr_ins, bias_bcast); + } + return curr_ins; +} + +instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name, + instruction_ref arg0, + instruction_ref arg1) const +{ + return this->add_common_op(op_name, arg0, arg1); +} + +instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, + std::vector inputs) const +{ + return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs)); +} + +instruction_ref +onnx_parser::node_info::add_instruction(const operation& op, + const std::vector& args) const +{ + return mod->add_instruction(op, args); +} + +instruction_ref onnx_parser::node_info::add_instruction(const operation& op, + const std::vector& args, + const std::vector& mods) const +{ + return mod->add_instruction(op, args, mods); +} + +instruction_ref onnx_parser::node_info::add_literal(literal l) const +{ + return mod->add_literal(std::move(l)); +} + +onnx_parser::onnx_parser() +{ + // Add all registered op parsers + for(auto&& name : get_op_parsers()) + ops.emplace(name, get_op_parser(name)); +} + +operation onnx_parser::load(const std::string& name, const node_info& info) const +{ + auto op = make_op(name); + auto v = op.to_value(); + for(auto&& x : v) + { + if(info.attributes.count(x.get_key()) == 0) + continue; + literal s = parse_value(info.attributes.at(x.get_key())); + if(x.is_array()) + { + std::vector values; + s.visit([&](auto y) { + std::transform(y.begin(), y.end(), std::back_inserter(values), [](auto z) { + return value(z); + }); + }); + x = values; + } + else + { + s.visit([&](auto y) { x = y.front(); }); + } + } + op.from_value(v); + return op; +} + +void onnx_parser::parse_undefined(module* mod, const std::string& name) +{ + if(!contains(instructions, name)) + { + auto ins = mod->add_instruction(make_op("undefined")); + instructions[name] = ins; + } +} + +void onnx_parser::parse_from(std::istream& is, std::string name) +{ + auto* mm = prog.get_main_module(); + this->filename = std::move(name); + auto parent_path = fs::path(this->filename).parent_path(); + if(not parent_path.empty()) + this->path = parent_path; + + onnx::ModelProto model; + if(model.ParseFromIstream(&is)) + { + auto version = get_opset_version(model); + opset_version = (version == -1) ? opset_version : version; + + if(model.has_graph()) + { + this->parse_graph(mm, model.graph()); + } + } + else + { + MIGRAPHX_THROW("PARSE_FROM: Failed reading onnx file: " + this->filename); + } +} + +void onnx_parser::parse_from(const void* data, std::size_t size) +{ + auto* mm = prog.get_main_module(); + onnx::ModelProto model; + if(model.ParseFromArray(data, size)) + { + auto version = get_opset_version(model); + opset_version = (version == -1) ? opset_version : version; + + if(model.has_graph()) + { + this->parse_graph(mm, model.graph()); + } + } + else + { + MIGRAPHX_THROW("Failed reading onnx file."); + } +} + +int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) +{ + const auto& opset_import = model.opset_import(); + int64_t version = -1; + for(const auto& opset : opset_import) + { + if(opset.has_version()) + { + version = std::max(version, opset.version()); + } + } + + return version; +} + +void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) +{ + std::unordered_map mod_insts; + for(auto&& f : graph.initializer()) + { + // backup instructions in parent mod + mod_insts[f.name()] = mod->add_literal(parse_tensor(f)); + } + + for(auto&& input : graph.input()) + { + const std::string& name = input.name(); + // input not in initializer_data, so it is a real input + if(!contains(mod_insts, name)) + { + // ONNX specification does not specify hwo to deal with the + // scenario that a nested subgraph contains a parameter with the + // name existed in its parent graph. + // In the current implementation, MIGraphX throws an exception for that. + if(contains(instructions, name)) + { + MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name + + "\" existing in parent graph!"); + } + + std::vector dims; + if(map_input_dims.count(name) > 0) + { + dims = map_input_dims.at(name); + } + + shape s = parse_type(input.type(), dims); + mod_insts[name] = mod->add_parameter(name, s); + } + } + + std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end())); + + for(auto&& node : graph.node()) + { + std::vector args; + for(auto&& input : node.input()) + { + if(input.empty()) + { + this->parse_undefined(mod, input); + } + if(instructions.count(input) == 0) + { + MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input + + "\" is unavailable due to unordered nodes!"); + } + args.push_back(instructions.at(input)); + } + + std::vector result; + std::size_t output_num = static_cast(node.output().size()); + if(ops.count(node.op_type()) == 0) + { + if(skip_unknown_operators) + result.push_back(mod->add_instruction(op::unknown{node.op_type()}, args)); + else + MIGRAPHX_THROW("Unknown operator: " + node.op_type()); + } + else + { + std::string node_name = node.op_type() + "_" + std::to_string(mod->size()); + result = ops[node.op_type()]( + *this, {get_attributes(node), output_num, node_name, mod}, args); + } + + output_num = std::min(output_num, result.size()); + std::transform(node.output().begin(), + node.output().begin() + output_num, + result.begin(), + std::inserter(instructions, instructions.end()), + [](auto&& x, auto&& y) { return std::make_pair(x, y); }); + } + + // Find instructions corresponding to the output + auto prog_output = graph.output(); + std::vector all_output_names; + std::vector prog_output_names; + std::transform(prog_output.begin(), + prog_output.end(), + std::back_inserter(all_output_names), + [](auto& node) { return node.name(); }); + std::copy_if( + all_output_names.begin(), + all_output_names.end(), + std::back_inserter(prog_output_names), + [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); }); + + std::vector output_ins; + std::transform(prog_output_names.begin(), + prog_output_names.end(), + std::back_inserter(output_ins), + [&](const auto& name) { return instructions[name]; }); + + // add the return instuction + mod->add_return(output_ins); + + // remove instructions added in this mod + erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); }); +} + +literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const +{ + switch(attr.type()) + { + case onnx::AttributeProto::FLOAT: return literal{attr.f()}; + case onnx::AttributeProto::INT: return literal{attr.i()}; + case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t()); + case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats()); + case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints()); + case onnx::AttributeProto::UNDEFINED: + case onnx::AttributeProto::GRAPH: + case onnx::AttributeProto::STRING: + case onnx::AttributeProto::STRINGS: + case onnx::AttributeProto::TENSORS: + case onnx::AttributeProto::SPARSE_TENSOR: + case onnx::AttributeProto::SPARSE_TENSORS: + case onnx::AttributeProto::GRAPHS: return {}; + } + MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); +} + +literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const +{ + std::vector dims(t.dims().begin(), t.dims().end()); + if(not t.external_data().empty()) + { + const std::string& data_file = t.external_data().at(0).value(); + auto raw_buffer = read_buffer(path + "/" + data_file); + std::string s(raw_buffer.begin(), raw_buffer.end()); + auto type = get_type(t.data_type()); + return create_literal(type, dims, s.data()); + } + if(t.has_raw_data()) + { + const std::string& s = t.raw_data(); + auto type = get_type(t.data_type()); + return create_literal(type, dims, s.data()); + } + + switch(t.data_type()) + { + case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data()); + case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data()); + case onnx::TensorProto::UINT8: return create_literal(shape::uint8_type, dims, t.int32_data()); + case onnx::TensorProto::INT16: return create_literal(shape::int16_type, dims, t.int32_data()); + case onnx::TensorProto::UINT16: return create_literal(shape::uint16_type, dims, t.int32_data()); + case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, t.int32_data()); + case onnx::TensorProto::UINT32: + return create_literal(shape::uint32_type, dims, t.uint64_data()); + case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data()); + case onnx::TensorProto::UINT64: + return create_literal(shape::uint64_type, dims, t.uint64_data()); + case onnx::TensorProto::FLOAT16: { + std::vector data_uint16(t.int32_data().begin(), t.int32_data().end()); + std::vector data_half; + std::transform(data_uint16.begin(), + data_uint16.end(), + std::back_inserter(data_half), + [](uint16_t raw_val) { return *reinterpret_cast(&raw_val); }); + return create_literal(shape::half_type, dims, data_half); + } + case onnx::TensorProto::DOUBLE: + return create_literal(shape::double_type, dims, t.double_data()); + case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); + case onnx::TensorProto::UNDEFINED: + case onnx::TensorProto::STRING: + case onnx::TensorProto::COMPLEX64: + case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); + } + MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type"); +} +shape onnx_parser::parse_type(const onnx::TypeProto& t, + const std::vector& input_dims) const +{ + shape::type_t shape_type = get_type(t.tensor_type().elem_type()); + if(!input_dims.empty()) + { + return {shape_type, input_dims}; + } + + std::vector dims; + auto&& tensor_dims = t.tensor_type().shape().dim(); + std::transform(tensor_dims.begin(), + tensor_dims.end(), + std::back_inserter(dims), + [&](auto&& d) -> std::size_t { + if(d.has_dim_value()) + { + if(static_cast(d.dim_value()) <= 0) + { + return default_dim_value; + } + return d.dim_value(); + } + else + { + return default_dim_value; + } + }); + + if(dims.empty()) + return {shape_type}; + + return {shape_type, dims}; +} + +shape::type_t get_type(int dtype) +{ + switch(dtype) + { + case 1: return shape::float_type; + case 2: return shape::uint8_type; + case 3: return shape::int8_type; + case 4: return shape::uint16_type; + case 5: return shape::int16_type; + case 6: return shape::int32_type; + case 7: return shape::int64_type; + case 9: return shape::bool_type; + case 10: return shape::half_type; + case 11: return shape::double_type; + case 12: return shape::uint32_type; + case 13: return shape::uint64_type; + default: { + MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); + } + } +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/op_parser.cpp b/src/onnx/op_parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5757dbce66fa85ad7951f9b001c75f5b52ed2e7 --- /dev/null +++ b/src/onnx/op_parser.cpp @@ -0,0 +1,31 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +std::unordered_map& op_parser_map() +{ + static std::unordered_map m; // NOLINT + return m; +} + +void register_op_parser(const std::string& name, onnx_parser::op_func f) +{ + op_parser_map()[name] = std::move(f); +} +onnx_parser::op_func get_op_parser(const std::string& name) { return op_parser_map().at(name); } +std::vector get_op_parsers() +{ + std::vector result; + std::transform(op_parser_map().begin(), + op_parser_map().end(), + std::back_inserter(result), + [&](auto&& p) { return p.first; }); + return result; +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/padding.cpp b/src/onnx/padding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09dc15ae39e6a1723b5c11bac213be79dfb3de87 --- /dev/null +++ b/src/onnx/padding.cpp @@ -0,0 +1,156 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void cal_auto_padding_size(onnx_parser::node_info info, + value& v, + const std::vector& k_lens, + const std::vector& dilation, + const std::vector& in_lens, + std::vector& paddings) +{ + size_t kdims = in_lens.size() - 2; + assert(k_lens.size() == kdims and dilation.size() == kdims); + + if(!contains(info.attributes, "auto_pad")) + { + return; + } + + auto auto_pad = info.attributes["auto_pad"].s(); + if(auto_pad.find("SAME") != std::string::npos) + { + bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos); + paddings.resize(2 * kdims); + + for(size_t i = 0; i < paddings.size() / 2; i++) + { + calculate_padding(i, + paddings, + in_lens[i + 2], + v["stride"][i].to(), + dilation[i], + k_lens[i], + is_same_upper); + } + } +} + +bool is_asym_padding(const std::vector& padding) +{ + assert(padding.size() % 2 == 0); + size_t pad_ndims = padding.size() / 2; + + for(size_t i = 0; i < pad_ndims; i++) + { + if(padding[i] != padding[i + pad_ndims]) + { + return true; + } + } + return false; +} + +void check_padding_mode(const onnx_parser::node_info& info, const std::string& op_name) +{ + // ensure pads availabe only when auto_pad is "NOT_SET" + if(contains(info.attributes, "pads") and contains(info.attributes, "auto_pad")) + { + auto s = info.attributes.at("auto_pad").s(); + if(to_upper(s) != "NOTSET") + { + MIGRAPHX_THROW("PARSE_" + op_name + + ": auto_pad and padding cannot be specified simultaneously"); + } + } +} + +static void +tune_padding_to_symmetric(int64_t& left, int64_t& right, const int stride, int64_t& s_start) +{ + s_start = 0; + if(left > right) + { + right = left; + } + else if(left < right) + { + auto diff = right - left; + s_start = (diff + stride - 1) / stride; + left = left + s_start * stride; + right = left; + } +} + +void tune_padding_size(const value& v, + std::vector& padding, + int count_include_pad, + std::vector& s_start) +{ + // maxpooling or count_include_pad is 1, no change is required. + if(v.at("mode").to() == op::pooling_mode::max or count_include_pad == 1) + { + return; + } + + // if padding is symmetric, return directly + if(!is_asym_padding(padding)) + { + return; + } + + // asymmetric padding, make it symmetric + std::size_t n_dims = padding.size() / 2; + s_start.resize(n_dims); + for(std::size_t i = 0; i < n_dims; ++i) + { + tune_padding_to_symmetric( + padding[i], padding[i + n_dims], v.at("stride")[i].to(), s_start[i]); + } +} + +void check_asym_padding(const onnx_parser::node_info& info, + instruction_ref& ins, + const std::vector& padding, + value& v, + int count_include_pad, + float pad_val) +{ + size_t pad_ndims = padding.size() / 2; + auto left_pad_it = padding.begin(); + auto right_pad_it = left_pad_it + pad_ndims; + + if(count_include_pad == 1) + { + std::vector asym_pads{0, 0, 0, 0}; // don't pad N and C + // add left pads + asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it); + // add right pads + asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end()); + ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins); + std::vector new_padding(padding.size()); + // subtract asym padding originally found from parsing the operator + std::transform(padding.begin(), + left_pad_it, + asym_pads.begin() + 2, + new_padding.begin(), + std::minus()); + std::transform(right_pad_it, + padding.end(), + asym_pads.begin() + pad_ndims + 4, + new_padding.begin() + pad_ndims, + std::minus()); + v["padding"] = new_padding; + } +} + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_arg_op.cpp b/src/onnx/parse_arg_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41110f498300536d6302c1806a52a088819da15a --- /dev/null +++ b/src/onnx/parse_arg_op.cpp @@ -0,0 +1,44 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_arg_op : op_parser +{ + std::vector operators() const { return {{"ArgMax", "argmax"}, {"ArgMin", "argmin"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + const std::vector& args) const + { + int64_t axis = 0; + if(contains(info.attributes, "axis")) + { + axis = static_cast(parser.parse_value(info.attributes.at("axis")).at()); + } + + int keep_dims = 1; + if(contains(info.attributes, "keepdims")) + { + keep_dims = parser.parse_value(info.attributes.at("keepdims")).at(); + } + + if(keep_dims == 0) + { + auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); + return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins); + } + else + { + return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_aten.cpp b/src/onnx/parse_aten.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d900baf75abcd44f72ad3c360786cb677625e2ad --- /dev/null +++ b/src/onnx/parse_aten.cpp @@ -0,0 +1,67 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +enum class reduce_mode_t +{ + sum = 0, + mean = 1, + max = 2 +}; + +struct parse_aten : op_parser +{ + std::vector operators() const { return {{"ATen"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + if(contains(info.attributes, "operator")) + { + auto op_name = info.attributes.at("operator").s(); + if(op_name.find("embedding_bag") != std::string::npos) + { + return parse_embedding_bag(info, std::move(args)); + } + } + MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator"); + } + + instruction_ref parse_embedding_bag(onnx_parser::node_info info, + std::vector args) const + { + if(args[2]->get_shape().elements() != 1) + MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1"); + reduce_mode_t reduce_mode = reduce_mode_t::sum; + if(contains(info.attributes, "mode")) + { + reduce_mode = static_cast(info.attributes.at("mode").i()); + } + + auto l0 = info.add_instruction(make_op("gather"), args[0], args[1]); + switch(reduce_mode) + { + case reduce_mode_t::sum: + l0 = info.add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0); + break; + case reduce_mode_t::mean: + l0 = info.add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0); + break; + case reduce_mode_t::max: + l0 = info.add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0); + break; + } + return l0; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_batchnorm.cpp b/src/onnx/parse_batchnorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc3718ca6d6a6234fb5df445cba66fb0d6a84257 --- /dev/null +++ b/src/onnx/parse_batchnorm.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_batchnorm : op_parser +{ + std::vector operators() const { return {{"BatchNormalization"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + const std::vector& args) const + { + float epsilon = 1e-5f; + float momentum = 0.9f; + op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; + if(contains(info.attributes, "epsilon")) + { + epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); + } + if(contains(info.attributes, "momentum")) + { + momentum = parser.parse_value(info.attributes.at("momentum")).at(); + } + if(contains(info.attributes, "spatial")) + { + bn_mode = (parser.parse_value(info.attributes.at("spatial")).at() > 0) + ? op::batch_norm_inference::spatial + : op::batch_norm_inference::per_activation; + } + op::batch_norm_inference op{epsilon, momentum, bn_mode}; + return info.add_instruction(op, args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_binary_op.cpp b/src/onnx/parse_binary_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed8f4ca6f7e983e1de64aad9d6d5d32e05add390 --- /dev/null +++ b/src/onnx/parse_binary_op.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_binary_op : op_parser +{ + std::vector operators() const + { + return {{"Add", "add"}, + {"Div", "div"}, + {"And", "logical_and"}, + {"Or", "logical_or"}, + {"Xor", "logical_xor"}, + {"Mul", "mul"}, + {"PRelu", "prelu"}, + {"Sub", "sub"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + if(args.size() != 2) + MIGRAPHX_THROW("binary operators should have 2 operands"); + if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis")) + { + uint64_t broadcasted = + parser.parse_value(info.attributes.at("broadcast")).at(); + if(broadcasted != 0) + { + uint64_t axis = parser.parse_value(info.attributes.at("axis")).at(); + auto l = info.add_instruction( + make_op("broadcast", + {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}), + args[1]); + return info.add_instruction(make_op(opd.op_name), args[0], l); + } + return info.add_instruction(make_op(opd.op_name), args); + } + else + { + return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_cast.cpp b/src/onnx/parse_cast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f797d3241eb52d5978e6e74a36479ded48830272 --- /dev/null +++ b/src/onnx/parse_cast.cpp @@ -0,0 +1,31 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_cast : op_parser +{ + std::vector operators() const { return {{"Cast"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + const std::vector& args) const + { + if(!contains(info.attributes, "to")) + { + MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!"); + } + + int to_type = parser.parse_value(info.attributes.at("to")).at(); + shape::type_t type = get_type(to_type); + return info.add_instruction(make_op("convert", {{"target_type", type}}), args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_celu.cpp b/src/onnx/parse_celu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1140d839aef883c782a10bd55bdf139ad1ff5150 --- /dev/null +++ b/src/onnx/parse_celu.cpp @@ -0,0 +1,57 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_celu : op_parser +{ + std::vector operators() const { return {{"Celu"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser&, + const onnx_parser::node_info& info, + std::vector args) const + { + float alpha = 1.0; + if(contains(info.attributes, "alpha")) + { + alpha = info.attributes.at("alpha").f(); + } + if(float_equal(alpha, 0.0f)) + { + MIGRAPHX_THROW("CELU: alpha is zero (division by zero)"); + } + + auto input_lens = args[0]->get_shape().lens(); + auto input_type = args[0]->get_shape().type(); + if(input_type != migraphx::shape::float_type) + { + MIGRAPHX_THROW("CELU: input tensor not float type"); + } + auto zero_lit = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}})); + auto one_lit = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}})); + auto alpha_lit = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]); + auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit); + auto expo = info.add_instruction(migraphx::make_op("exp"), divi); + auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit); + auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub); + auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul); + return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_clip.cpp b/src/onnx/parse_clip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72fd113a71b31c929f95a4fe515f03cbd8d475e6 --- /dev/null +++ b/src/onnx/parse_clip.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_clip : op_parser +{ + std::vector operators() const { return {{"Clip"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + instruction_ref min_arg; + instruction_ref max_arg; + bool min_used = false; + bool max_used = false; + + if(args.size() == 3 and args[2]->name() != "undefined") + { + max_arg = args[2]; + max_used = true; + } + + if(args.size() >= 2 and args[1]->name() != "undefined") + { + min_arg = args[1]; + min_used = true; + } + // if using previous opset for attributes + else if(contains(info.attributes, "min") and contains(info.attributes, "max")) + { + + float min_val = parser.parse_value(info.attributes.at("min")).at(); + float max_val = parser.parse_value(info.attributes.at("max")).at(); + min_arg = info.add_literal(min_val); + max_arg = info.add_literal(max_val); + min_used = true; + max_used = true; + } + + if(min_used and max_used) + { + return info.add_common_op("clip", args[0], min_arg, max_arg); + } + else if(max_used) + { + return info.add_broadcastable_binary_op("min", args[0], max_arg); + } + else if(min_used) + { + return info.add_broadcastable_binary_op("max", args[0], min_arg); + } + else + { + return info.add_instruction(make_op("identity"), args[0]); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_compare_op.cpp b/src/onnx/parse_compare_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f4944c6430dfbcee6f85e6e6ddb5ca7a010aaf7 --- /dev/null +++ b/src/onnx/parse_compare_op.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_compare_op : op_parser +{ + std::vector operators() const + { + return {{"Equal", "equal"}, {"Greater", "greater"}, {"Less", "less"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto l = info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]); + if(l->get_shape().type() != shape::bool_type) + { + l = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l); + } + return l; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_constant.cpp b/src/onnx/parse_constant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77eb3582dca48e33a442a358f195c887e1becc7a --- /dev/null +++ b/src/onnx/parse_constant.cpp @@ -0,0 +1,40 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_constant : op_parser +{ + std::vector operators() const { return {{"Constant"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + const std::vector& /*args*/) const + { + literal v = parser.parse_value(info.attributes.at("value")); + // return empty literal + if(v.get_shape().elements() == 0) + { + return info.add_literal(literal{}); + } + + auto dim_size = info.attributes.at("value").t().dims_size(); + // if dim_size is 0, it is a scalar + if(dim_size == 0) + { + migraphx::shape scalar_shape{v.get_shape().type()}; + return info.add_literal(migraphx::literal{scalar_shape, v.data()}); + } + + return info.add_literal(v); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_constant_fill.cpp b/src/onnx/parse_constant_fill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0881c5c0507b2d716b0eb05128ac6030997d137b --- /dev/null +++ b/src/onnx/parse_constant_fill.cpp @@ -0,0 +1,94 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +// Use a literal instruction to replace the constantFill operator. In RNN, input shape +// and value are fixed, so no need to do the actual computation for the constantFill +// operator +struct parse_constant_fill : op_parser +{ + std::vector operators() const { return {{"ConstantFill"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + int input_as_shape = 0; + int dtype = 1; + float value = 0.0f; + + if(contains(info.attributes, "dtype")) + { + dtype = parser.parse_value(info.attributes.at("dtype")).at(); + } + shape::type_t type = get_type(dtype); + + if(contains(info.attributes, "input_as_shape")) + { + input_as_shape = parser.parse_value(info.attributes.at("input_as_shape")).at(); + } + + if(contains(info.attributes, "value")) + { + value = parser.parse_value(info.attributes.at("value")).at(); + } + + if(contains(info.attributes, "extra_shape")) + { + MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute"); + } + + if(input_as_shape == 1) + { + if(args.size() != 1) + { + MIGRAPHX_THROW("ConstantFill: need an input argument as output shape"); + } + + if(contains(info.attributes, "shape")) + { + MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input " + "at the same time"); + } + + migraphx::argument in = args[0]->eval(); + check_arg_empty(in, "ConstantFill: dynamic shape is not supported"); + + std::vector dims; + in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); + migraphx::shape s(type, dims); + std::vector values(s.elements(), value); + return info.add_literal(migraphx::literal(s, values)); + } + else if(input_as_shape == 0) + { + if(!contains(info.attributes, "shape")) + { + MIGRAPHX_THROW("ConstantFill: attribute output shape is needed"); + } + + literal ls = parser.parse_value(info.attributes.at("shape")); + std::vector dims; + ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); + migraphx::shape s{type, dims}; + std::vector values(s.elements(), value); + return info.add_literal(migraphx::literal(s, values)); + } + else + { + MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape"); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_constant_of_shape.cpp b/src/onnx/parse_constant_of_shape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ebe832bbce81b87c26f726428c0a2dedf8ccc36 --- /dev/null +++ b/src/onnx/parse_constant_of_shape.cpp @@ -0,0 +1,75 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_constant_of_shape : op_parser +{ + std::vector operators() const { return {{"ConstantOfShape"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + literal l_val{}; + if(contains(info.attributes, "value")) + { + l_val = parser.parse_value(info.attributes.at("value")); + if(l_val.get_shape().elements() != 1) + { + MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!"); + } + } + else + { + l_val = literal({shape::float_type, {1}, {0}}, {0.0f}); + } + + // input is empty, output is a scalar + auto type = l_val.get_shape().type(); + + if(args.empty()) + { + MIGRAPHX_THROW("ConstantOfShape : must have 1 input!"); + } + else + { + migraphx::shape s; + // empty input tensor, output is a scalar + if(args[0]->get_shape().elements() == 0) + { + s = migraphx::shape{type, {1}, {0}}; + } + else + { + migraphx::argument in = args[0]->eval(); + check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported"); + + std::vector dims; + in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); + s = migraphx::shape{type, dims}; + } + + literal l_out{}; + l_val.visit([&](auto val) { + using val_type = std::remove_cv_t; + // l_val contains only one element + std::vector out_vec(s.elements(), val.front()); + l_out = literal(s, out_vec); + }); + + return info.add_literal(l_out); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_convolution.cpp b/src/onnx/parse_convolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8462d9a6b5712b78b8d70a59e8331343071012e5 --- /dev/null +++ b/src/onnx/parse_convolution.cpp @@ -0,0 +1,93 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_convolution : op_parser +{ + std::vector operators() const + { + return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + auto op = make_op(opd.op_name); + auto values = op.to_value(); + auto l0 = args[0]; + auto weights = args[1]; + auto in_lens = l0->get_shape().lens(); + assert(in_lens.size() > 2); + auto kdims = in_lens.size() - 2; + + // ensure pads availabe only when auto_pad is "NOT_SET" + check_padding_mode(info, "CONV"); + + if(contains(info.attributes, "strides")) + { + values["stride"].clear(); + copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"])); + check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides"); + } + if(contains(info.attributes, "dilations")) + { + values["dilation"].clear(); + copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"])); + check_attr_sizes( + kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations"); + } + + std::vector padding; + if(contains(info.attributes, "pads")) + { + values["padding"].clear(); + copy(info.attributes["pads"].ints(), std::back_inserter(padding)); + check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings"); + } + + if(contains(info.attributes, "auto_pad")) + { + auto weight_lens = weights->get_shape().lens(); + std::vector k_lens(weight_lens.begin() + 2, weight_lens.end()); + cal_auto_padding_size(info, + values, + k_lens, + values["dilation"].to_vector(), + in_lens, + padding); + auto auto_pad = info.attributes["auto_pad"].s(); + if(auto_pad.find("SAME") != std::string::npos) + { + values["padding_mode"] = to_value(op::padding_mode_t::same); + } + } + values["padding"] = std::vector(padding.begin(), padding.end()); + + if(contains(info.attributes, "group")) + { + values["group"] = parser.parse_value(info.attributes.at("group")).at(); + } + + recalc_conv_attributes(values, kdims); + + op.from_value(values); + auto l1 = info.add_instruction(op, l0, args[1]); + return info.add_bias(args, l1, 1); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_deconvolution.cpp b/src/onnx/parse_deconvolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed6e9d36d865facd34e739e94795a6548ea2cf84 --- /dev/null +++ b/src/onnx/parse_deconvolution.cpp @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +template +std::vector to_int64_vector(const std::vector& input_vector) +{ + std::vector output_vector(input_vector.begin(), input_vector.end()); + return output_vector; +} + +struct parse_deconvolution : op_parser +{ + std::vector operators() const { return {{"ConvTranspose"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + operation op = make_op("deconvolution"); + value values = op.to_value(); + // op::deconvolution op; + auto l0 = args[0]; + std::vector padding; + bool asym_padding = false; + auto in_lens = l0->get_shape().lens(); + assert(in_lens.size() > 2); + auto kdims = in_lens.size() - 2; + + // ensure pads availabe only when auto_pad is "NOT_SET" + check_padding_mode(info, "CONV_TRANSPOSE"); + + if(contains(info.attributes, "pads")) + { + copy(info.attributes["pads"].ints(), std::back_inserter(padding)); + + asym_padding = is_asym_padding(padding); + + if(not asym_padding) + { + size_t pad_ndims = padding.size() / 2; + check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings"); + values["padding"].clear(); + std::transform(padding.begin(), + padding.begin() + pad_ndims, + std::back_inserter(values["padding"]), + [](auto pad_val) { return pad_val; }); + } + } + if(contains(info.attributes, "strides")) + { + values["stride"].clear(); + copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"])); + check_attr_sizes( + kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides"); + } + if(contains(info.attributes, "dilations")) + { + values["dilation"].clear(); + copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"])); + check_attr_sizes( + kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations"); + } + if(contains(info.attributes, "auto_pad")) + { + auto s = info.attributes["auto_pad"].s(); + if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET") + { + MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified " + "simultaneously"); + } + + if(s.find("SAME") != std::string::npos) + { + values["padding_mode"] = to_value(op::padding_mode_t::same); + } + } + + if(contains(info.attributes, "group")) + { + values["group"] = parser.parse_value(info.attributes.at("group")).at(); + } + + recalc_conv_attributes(values, kdims); + + op.from_value(values); + auto l1 = info.add_instruction(op, l0, args[1]); + std::vector dims = to_int64_vector(l1->get_shape().lens()); + std::vector curr_shape(dims.begin() + 2, dims.end()); + if(asym_padding) + { + std::vector axes(kdims); + std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims + + auto pad_kdim_start = padding.begin() + kdims; + std::vector starts(padding.begin(), pad_kdim_start); + + std::vector ends{}; + std::transform(curr_shape.begin(), + curr_shape.end(), + pad_kdim_start, + std::back_inserter(ends), + [](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; }); + + l1 = info.add_instruction( + make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1); + } + + if(contains(info.attributes, "output_padding")) + { + size_t non_kdims = dims.size() * 2 - kdims; + std::vector output_padding(non_kdims, 0); + copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding)); + check_attr_sizes(kdims, + output_padding.size() - non_kdims, + "PARSE_CONV_TRANSPOSE: inconsistent output padding"); + l1 = info.add_instruction(make_op("pad", {{"pads", output_padding}}), l1); + } + + if(contains(info.attributes, "output_shape")) + { + std::vector output_shape; + copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape)); + check_attr_sizes( + kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape"); + dims = to_int64_vector(l1->get_shape().lens()); + copy(dims.begin() + 2, dims.end(), curr_shape.begin()); + if(curr_shape != output_shape) + { + std::vector target_padding(dims.size() * 2 - kdims, 0); + std::transform(output_shape.begin(), + output_shape.end(), + curr_shape.begin(), + std::back_inserter(target_padding), + [](auto out_dim, auto curr_dim) { return out_dim - curr_dim; }); + l1 = info.add_instruction(make_op("pad", {{"pads", target_padding}}), l1); + } + } + + return info.add_bias(args, l1, 1); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_depthtospace.cpp b/src/onnx/parse_depthtospace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86cd9462eafeb0636e7085f9864b144b0a44ea9f --- /dev/null +++ b/src/onnx/parse_depthtospace.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_depthtospace : op_parser +{ + std::vector operators() const { return {{"DepthToSpace"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto s = args[0]->get_shape(); + // mode attribute of DepthToSpace + auto mode = std::string("DCR"); + if(contains(info.attributes, "mode")) + { + mode = info.attributes.at("mode").s(); // DCR or CRD? + } + // blocksize attribute of DepthToSpace + int blocksize = 0; + if(contains(info.attributes, "blocksize")) + { + blocksize = info.attributes.at("blocksize").i(); + } + if(blocksize < 1) + { + MIGRAPHX_THROW("DepthToSpace: blocksize is less than 1"); + } + // calculate dimensions + auto lens1 = s.lens(); + auto lens2 = s.lens(); + unsigned long divisor = std::pow(blocksize, 2); + if((lens2[1] % divisor) == 0) + lens2[1] = lens2[1] / divisor; + else + MIGRAPHX_THROW("DepthToSpace: div by blocksize quotient not int "); + lens1.push_back(lens1[2]); + lens1.push_back(lens1[3]); + lens2[2] = lens2[2] * blocksize; + lens2[3] = lens2[3] * blocksize; + lens1[2] = blocksize; + std::vector perm; + if(mode == "DCR") + { + lens1[3] = lens1[1] / divisor; + lens1[1] = blocksize; + perm = {0, 3, 4, 1, 5, 2}; + } + else if(mode == "CRD") + { + lens1[1] = lens1[1] / divisor; + lens1[3] = blocksize; + perm = {0, 1, 4, 2, 5, 3}; + } + else + MIGRAPHX_THROW("DepthToSpace: mode attribute cannot be read."); + + auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]); + auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1); + return info.add_instruction(make_op("reshape", {{"dims", lens2}}), + info.make_contiguous(temp2)); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_dequantizelinear.cpp b/src/onnx/parse_dequantizelinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4bef48e6304834bc2666dd3ddc4047feb5732ff0 --- /dev/null +++ b/src/onnx/parse_dequantizelinear.cpp @@ -0,0 +1,66 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_dequantizelinear : op_parser +{ + std::vector operators() const { return {{"DequantizeLinear"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + const std::vector& args) const + { + int axis = 1; + if(contains(info.attributes, "axis")) + axis = info.attributes.at("axis").i(); + + auto input_lens = args[0]->get_shape().lens(); + auto n_dim = input_lens.size(); + + instruction_ref x_scale; + if(args[1]->get_shape().elements() != 1) + { + auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); + x_scale = info.add_instruction( + make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); + } + else + { + x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), + args[1]); + } + + if(args.size() == 3) + { + auto x_zero_point = args[2]; + if(x_zero_point->get_shape().elements() != 1) + { + auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); + x_zero_point = info.add_instruction( + make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), + x_zero_point); + } + else + { + x_zero_point = info.add_instruction( + make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point); + } + + return info.add_instruction( + make_op("dequantizelinear"), args[0], x_scale, x_zero_point); + } + + return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_dropout.cpp b/src/onnx/parse_dropout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0c93dd059e23e150c75ff983786c08fb2abcdf8b --- /dev/null +++ b/src/onnx/parse_dropout.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_dropout : op_parser +{ + std::vector operators() const { return {{"Dropout"}}; } + + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto out = info.add_instruction(make_op("identity"), args[0]); + auto s = args[0]->get_shape(); + std::vector vec(s.elements(), 1); + shape mask_s{shape::bool_type, s.lens()}; + auto mask = info.add_literal(literal(mask_s, vec)); + + return {out, mask}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_expand.cpp b/src/onnx/parse_expand.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bfe575d6d154cbd65c00154a9f5e9393ab03458b --- /dev/null +++ b/src/onnx/parse_expand.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_expand : op_parser +{ + std::vector operators() const { return {{"Expand"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto in_lens = args[0]->get_shape().lens(); + migraphx::argument arg_s = args[1]->eval(); + check_arg_empty(arg_s, "Expand: dynamic shape is not supported"); + std::vector dims; + arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); + auto out_lens = compute_broadcasted_lens(in_lens, dims); + return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_eyelike.cpp b/src/onnx/parse_eyelike.cpp new file mode 100644 index 0000000000000000000000000000000000000000..649cbcb5ddd45c310675894404356db46c5def3b --- /dev/null +++ b/src/onnx/parse_eyelike.cpp @@ -0,0 +1,73 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_eyelike : op_parser +{ + std::vector operators() const { return {{"EyeLike"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser&, + const onnx_parser::node_info& info, + std::vector args) const + { + auto input_shape = args[0]->get_shape(); + auto input_lens = input_shape.lens(); + if(input_lens.size() != 2) + { + MIGRAPHX_THROW("EYELIKE: tensor input not of rank 2"); + } + std::ptrdiff_t num_rows = input_lens.front(); + std::ptrdiff_t num_cols = input_lens.back(); + + shape::type_t output_type = args[0]->get_shape().type(); + if(contains(info.attributes, "dtype")) + { + output_type = get_type(info.attributes.at("dtype").i()); + } + + std::ptrdiff_t k = 0; + if(contains(info.attributes, "k")) + { + k = info.attributes.at("k").i(); + } + if(k >= 0) + { + if(k >= num_cols) + { + std::ostringstream oss; + oss << "EYELIKE: positive k out of bounds, k = " << k << " num_cols = " << num_cols; + MIGRAPHX_THROW(oss.str()); + } + } + else + { + if(std::abs(k) >= num_rows) + { + std::ostringstream oss; + oss << "EYELIKE: negative k out of bounds, k = " << k << " num_rows = " << num_cols; + MIGRAPHX_THROW(oss.str()); + } + } + + std::vector eyelike_mat(num_rows * num_cols, 0); + for(std::ptrdiff_t i = 0; i < num_rows; ++i) + { + auto idx = i + k; + if(idx < num_cols and idx >= 0) + eyelike_mat[(num_cols + 1) * i + k] = char{1}; + } + return info.add_literal( + migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_gather_elements.cpp b/src/onnx/parse_gather_elements.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e866214ef80a65606ad9d86503ec3905994dc246 --- /dev/null +++ b/src/onnx/parse_gather_elements.cpp @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_gather_elements : op_parser +{ + std::vector operators() const { return {{"GatherElements"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + int axis = 0; + if(contains(info.attributes, "axis")) + { + axis = parser.parse_value(info.attributes.at("axis")).at(); + } + + // standardize input data and index + auto arg_data = info.make_contiguous(args[0]); + auto arg_ind = info.make_contiguous(args[1]); + + auto data_s = arg_data->get_shape(); + auto ind_s = arg_ind->get_shape(); + + if(data_s.lens().size() != ind_s.lens().size()) + { + MIGRAPHX_THROW("PARSE_GATHER_ELEMENTS: input data and index must have the same rank!"); + } + + int n_rank = static_cast(data_s.lens().size()); + int tuned_axis = tune_axis(n_rank, axis, opd.op_name); + + auto axis_stride = data_s.strides()[tuned_axis]; + int64_t data_elem_num = data_s.elements(); + // reshape the input data as one dimension and used as input data + // to the gather operator + arg_data = info.add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data); + + std::size_t elem_num = ind_s.elements(); + std::vector ind_index(elem_num); + std::iota(ind_index.begin(), ind_index.end(), 0); + + // convert index in input indices to that in input data + std::vector data_indices(elem_num); + std::transform(ind_index.begin(), ind_index.end(), data_indices.begin(), [&](auto i) { + return data_s.index(ind_s.multi(i)); + }); + + std::vector vec_axis_ind(elem_num); + std::transform(ind_index.begin(), ind_index.end(), vec_axis_ind.begin(), [&](auto i) { + return ind_s.multi(i)[tuned_axis]; + }); + + auto l_shape_idx = + info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end())); + auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end())); + auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}}); + l_stride = + info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride); + auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx); + auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride); + auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta); + + auto op = make_op("gather", {{"axis", 0}}); + return info.add_instruction(op, arg_data, ind); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_gemm.cpp b/src/onnx/parse_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9934a83180125ec9c1043a6a87b886c05f9b1028 --- /dev/null +++ b/src/onnx/parse_gemm.cpp @@ -0,0 +1,97 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_gemm : op_parser +{ + std::vector operators() const { return {{"Gemm"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + float alpha = 1.0f; + float beta = 1.0f; + bool transa = false; + bool transb = false; + if(contains(info.attributes, "alpha")) + { + alpha = parser.parse_value(info.attributes.at("alpha")).at(); + } + if(contains(info.attributes, "beta")) + { + beta = parser.parse_value(info.attributes.at("beta")).at(); + } + if(contains(info.attributes, "transA")) + { + transa = parser.parse_value(info.attributes.at("transA")).at(); + } + if(contains(info.attributes, "transB")) + { + transb = parser.parse_value(info.attributes.at("transB")).at(); + } + + std::vector perm(args[0]->get_shape().lens().size()); + std::iota(perm.begin(), perm.end(), int64_t{0}); + // swap the last two elements + std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); + + auto l1 = args[0]; + auto dot_type = l1->get_shape().type(); + + if(alpha != 1.0f) + { + auto alpha_literal = info.add_literal(alpha); + l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1); + if(l1->get_shape().type() != dot_type) + { + l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1); + } + } + + l1 = + (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1; + auto l2 = (transb) + ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) + : args[1]; + + auto ret = info.add_instruction(make_op("dot"), l1, l2); + + if(args.size() == 3) + { + if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0) + { + auto out_lens = l1->get_shape().lens(); + out_lens.back() = l2->get_shape().lens().back(); + auto l3 = args[2]; + auto l3_lens = l3->get_shape().lens(); + if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) + { + l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), + args[2]); + } + auto beta_literal = info.add_literal(beta); + auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal); + if(beta_l3->get_shape().type() != dot_type) + { + beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), + beta_l3); + } + + return info.add_instruction(make_op("add"), ret, beta_l3); + } + } + + return ret; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_generic_op.cpp b/src/onnx/parse_generic_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bfb4296f7c3f7103699df4a1f64c0ef919c02c5 --- /dev/null +++ b/src/onnx/parse_generic_op.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_generic_op : op_parser +{ + std::vector operators() const + { + // clang-format off + return {{"Abs", "abs"}, + {"Acos", "acos"}, + {"Acosh", "acosh"}, + {"Asin", "asin"}, + {"Asinh", "asinh"}, + {"Atan", "atan"}, + {"Atanh", "atanh"}, + {"Ceil", "ceil"}, + {"Concat", "concat"}, + {"Cos", "cos"}, + {"Cosh", "cosh"}, + {"Elu", "elu"}, + {"Erf", "erf"}, + {"Exp", "exp"}, + {"Flatten", "flatten"}, + {"Floor", "floor"}, + {"Gather", "gather"}, + {"GatherND", "gathernd"}, + {"Identity", "identity"}, + {"IsNaN", "isnan"}, + {"LeakyRelu", "leaky_relu"}, + {"Log", "log"}, + {"LRN", "lrn"}, + {"Neg", "neg"}, + {"NonMaxSuppression", "nonmaxsuppression"}, + {"Reciprocal", "recip"}, + {"Relu", "relu"}, + {"Round", "round"}, + {"Sigmoid", "sigmoid"}, + {"Sign", "sign"}, + {"Sin", "sin"}, + {"Sinh", "sinh"}, + {"Sqrt", "sqrt"}, + {"Tan", "tan"}, + {"Tanh", "tanh"}, + {"Not", "not"}}; + // clang-format on + } + + bool needs_contiguous(const std::string& op_name) const + { + return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name); + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + auto op = parser.load(opd.op_name, info); + if(needs_contiguous(opd.op_name)) + { + std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) { + return info.make_contiguous(arg); + }); + } + return info.add_instruction(op, args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_greaterorequal.cpp b/src/onnx/parse_greaterorequal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67319364bd5f4edb52dd708e409384de885f6c08 --- /dev/null +++ b/src/onnx/parse_greaterorequal.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_greaterorequal : op_parser +{ + std::vector operators() const { return {{"GreaterOrEqual"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto in_res = info.add_broadcastable_binary_op("less", args[0], args[1]); + if(in_res->get_shape().type() != shape::bool_type) + { + in_res = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), + in_res); + } + return info.add_instruction(make_op("not"), in_res); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_gru.cpp b/src/onnx/parse_gru.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abddc1779ef91d69ae215a44bba46a0de3cf1b54 --- /dev/null +++ b/src/onnx/parse_gru.cpp @@ -0,0 +1,151 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_gru : op_parser +{ + std::vector operators() const { return {{"GRU"}}; } + + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + migraphx::shape input_shape = args[0]->get_shape(); + std::size_t hidden_size = args[2]->get_shape().lens()[2]; + + if(contains(info.attributes, "hidden_size")) + { + std::size_t hidden_size_att = + parser.parse_value(info.attributes.at("hidden_size")).at(); + if(hidden_size != hidden_size_att) + { + MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute"); + } + } + + // Handling of direction to be added later + std::string direction{"forward"}; + if(contains(info.attributes, "direction")) + { + direction = info.attributes.at("direction").s(); + } + + op::rnn_direction dirct = op::rnn_direction::forward; + if(direction == "bidirectional") + { + dirct = op::rnn_direction::bidirectional; + } + else if(direction == "reverse") + { + dirct = op::rnn_direction::reverse; + } + + std::vector vec_names = {"sigmoid", "tanh"}; + if(contains(info.attributes, "activations")) + { + auto names = info.attributes.at("activations").strings(); + vec_names.clear(); + vec_names.resize(names.size()); + std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { + return to_lower(name); + }); + } + + // need 4 activation functions + if(dirct == op::rnn_direction::bidirectional) + { + // 4 activation functions are used in the bidirectional + // scenario. No spec is provided in onnx::operator. we + // use the algorithm that: if 1 actv function is provided, + // repeat 1 four times. If 2 actv functins are provided, + // assume forward and reverse use the same pair of actv + // functions. For the case of 3 actv functions provided, + // assume the 3rd one is repeated once and used by the + // reverse direction. + // This may need change later + if(vec_names.size() == 1) + { + vec_names.insert(vec_names.end(), 3, vec_names.at(0)); + } + else if(vec_names.size() == 2) + { + // repeat the activation functions + vec_names.push_back(vec_names.at(0)); + vec_names.push_back(vec_names.at(1)); + } + else if(vec_names.size() == 3) + { + vec_names.push_back(vec_names.at(2)); + } + } + else + { + if(vec_names.size() == 1) + { + vec_names.push_back(vec_names.at(0)); + } + } + + auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { + return (map_activation_functions().count(name) == 0); + }); + if(name_it != vec_names.end()) + { + MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported"); + } + + std::vector vec_actv_funcs(vec_names.size()); + std::transform(vec_names.begin(), + vec_names.end(), + vec_actv_funcs.begin(), + [&](const auto& name) { return map_activation_functions().at(name); }); + + float clip = 0.0; + if(contains(info.attributes, "clip")) + { + clip = parser.parse_value(info.attributes.at("clip")).at(); + } + + int linear_before_reset = 0; + if(contains(info.attributes, "linear_before_reset")) + { + linear_before_reset = + parser.parse_value(info.attributes.at("linear_before_reset")).at(); + } + + // append undefined opeator to make 6 arguments + if(args.size() < 6) + { + auto ins = info.add_instruction(make_op("undefined")); + args.insert(args.end(), 6 - args.size(), ins); + } + + // first output for concatenation of hidden states + auto hidden_states = + info.add_instruction(make_op("gru", + {{"hidden_size", hidden_size}, + {"actv_func", to_value(vec_actv_funcs)}, + {"direction", dirct}, + {"clip", clip}, + {"linear_before_reset", linear_before_reset}}), + args); + + // second output for last gru output + auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); + + return {hidden_states, last_output}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_hardsigmoid.cpp b/src/onnx/parse_hardsigmoid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c1df9a70ee9abe523e0f9116d1fc66f3cd6c243 --- /dev/null +++ b/src/onnx/parse_hardsigmoid.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_hardsigmoid : op_parser +{ + std::vector operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + float alpha = 0.2; + float beta = 0.5; + if(opd.onnx_name == "HardSwish") + { + alpha = 1.0 / 6.0; + } + else + { + if(contains(info.attributes, "alpha")) + alpha = info.attributes.at("alpha").f(); + + if(contains(info.attributes, "beta")) + beta = info.attributes.at("beta").f(); + } + + auto input_lens = args[0]->get_shape().lens(); + auto input_type = args[0]->get_shape().type(); + auto mb_alpha = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto mb_beta = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}})); + auto mb_zero = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0}})); + auto mb_one = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + + auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]); + auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul); + auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); + if(opd.onnx_name == "HardSwish") + return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid); + + return hardsigmoid; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_if.cpp b/src/onnx/parse_if.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb10f09b18244d7481f23ed50c3923e065da225e --- /dev/null +++ b/src/onnx/parse_if.cpp @@ -0,0 +1,70 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_if : op_parser +{ + std::vector operators() const { return {{"If"}}; } + + std::vector parse(const op_desc& /*opd*/, + onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + const auto& then_graph = info.attributes.at("then_branch").g(); + const auto& else_graph = info.attributes.at("else_branch").g(); + + if(args.front()->get_shape().elements() != 1) + { + MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!"); + } + + std::string then_name = info.name + "_if"; + module_ref then_mdl = parser.prog.create_module(then_name); + + std::string else_name = info.name + "_else"; + module_ref else_mdl = parser.prog.create_module(else_name); + + // parse the then sub_graph + parser.parse_graph(then_mdl, then_graph); + + // parse_the else sub_graph + parser.parse_graph(else_mdl, else_graph); + + auto then_out_shapes = then_mdl->get_output_shapes(); + auto else_out_shapes = else_mdl->get_output_shapes(); + if(not std::equal(then_out_shapes.begin(), + then_out_shapes.end(), + else_out_shapes.begin(), + else_out_shapes.end())) + { + MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!"); + } + + auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); + auto out_s = if_ret->get_shape(); + assert(out_s.type() == shape::tuple_type); + + const auto& vec_shapes = out_s.sub_shapes(); + std::vector out_inss; + for(std::size_t i = 0; i < vec_shapes.size(); ++i) + { + auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret); + out_inss.push_back(ret); + } + + return out_inss; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_imagescalar.cpp b/src/onnx/parse_imagescalar.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e236b11770c1614c7ebbab507f8ed2ac70b1a289 --- /dev/null +++ b/src/onnx/parse_imagescalar.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_imagescalar : op_parser +{ + std::vector operators() const { return {{"ImageScaler"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + float scale = 1.0; + std::vector bias{}; + if(contains(info.attributes, "scale")) + { + scale = parser.parse_value(info.attributes.at("scale")).at(); + } + + if(contains(info.attributes, "bias")) + { + auto&& bias_floats = info.attributes["bias"].floats(); + bias = std::vector(bias_floats.begin(), bias_floats.end()); + } + auto input_shape = args.front()->get_shape(); + auto const& input_lens = input_shape.lens(); + auto input_type = input_shape.type(); + + auto scale_val = info.add_literal(literal{shape{input_type}, {scale}}); + auto bias_vals = info.add_literal(literal{shape{input_type, {bias.size()}}, bias}); + + auto scale_tensor = info.add_instruction( + migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val); + auto img_scaled = + info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor); + auto bias_bcast = info.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals); + return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_instancenorm.cpp b/src/onnx/parse_instancenorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..01507bf5f32721d9d2a17341e9c856aeecb06d3d --- /dev/null +++ b/src/onnx/parse_instancenorm.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_instancenorm : op_parser +{ + std::vector operators() const { return {{"InstanceNormalization"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias + // mean = reduce_mean({D1, D2, ... Dk}, x) + // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) + + float epsilon = 1e-5f; + if(contains(info.attributes, "epsilon")) + { + epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); + } + auto x = args[0]; + auto scale = args[1]; + auto bias = args[2]; + auto dims = x->get_shape().lens(); + auto ndims = dims.size(); + assert(ndims >= 2); + auto kdims = ndims - 2; + + std::vector axes(kdims); + std::iota(axes.begin(), axes.end(), 2); + + auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); + auto mean_bcast = + info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); + auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); + auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); + auto epsilon_literal = info.add_literal(epsilon); + auto epsilon_bcast = + info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); + auto variance_bcast = + info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance); + auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); + auto l3 = info.add_instruction(make_op("rsqrt"), l2); + auto l4 = info.add_instruction(make_op("mul"), l1, l3); + auto scale_bcast = + info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); + ; + auto bias_bcast = + info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); + auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); + return info.add_instruction(make_op("add"), l5, bias_bcast); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_lessorequal.cpp b/src/onnx/parse_lessorequal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d4279f071ab03da1de2a36d14814f35a085678d4 --- /dev/null +++ b/src/onnx/parse_lessorequal.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_lessorequal : op_parser +{ + std::vector operators() const { return {{"LessOrEqual"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto in_res = info.add_broadcastable_binary_op("greater", args[0], args[1]); + if(in_res->get_shape().type() != shape::bool_type) + { + in_res = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), + in_res); + } + return info.add_instruction(make_op("not"), in_res); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_loop.cpp b/src/onnx/parse_loop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..446c9065549869c9c9db5a2f77981a0783104bdf --- /dev/null +++ b/src/onnx/parse_loop.cpp @@ -0,0 +1,72 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_loop : op_parser +{ + std::vector operators() const { return {{"Loop"}}; } + + std::vector parse(const op_desc& /*opd*/, + onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + // default value of the max_iter_num + int64_t max_iterations = parser.max_loop_iterations; + // iteration input is empty + if(args.at(0)->name() == "undefined") + { + shape iter_s{shape::int64_type}; + args[0] = info.add_literal(literal(iter_s, {max_iterations})); + } + else + { + auto arg_iters = args.at(0)->eval(); + if(not arg_iters.empty()) + { + max_iterations = arg_iters.at(); + } + } + + // condition input is empty + if(args.at(1)->name() == "undefined") + { + shape cond_s{shape::bool_type}; + args[1] = info.add_literal(literal(cond_s, {true})); + } + + // retrieve the subgraph + const auto& sub_graph = info.attributes.at("body").g(); + std::string mod_name = info.name + "_loop"; + module_ref sub_mod = parser.prog.create_module(mod_name); + + // parse the sub_graph + parser.parse_graph(sub_mod, sub_graph); + + auto ret = info.add_instruction( + make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod}); + auto out_s = ret->get_shape(); + assert(out_s.type() == shape::tuple_type); + + const auto& vec_shapes = out_s.sub_shapes(); + std::vector out_inss; + for(std::size_t i = 0; i < vec_shapes.size(); ++i) + { + auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret); + out_inss.push_back(r); + } + + return out_inss; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_lpnormalization.cpp b/src/onnx/parse_lpnormalization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efb12e7d24ff6a549d3084cf31b7aa2defc71b5c --- /dev/null +++ b/src/onnx/parse_lpnormalization.cpp @@ -0,0 +1,85 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +//! Parser for LpNormalization ONNX operator. +/*! + Normalizes a tensor by the L1 or L2 norms along a given axis. + Norms that evaluate to 0 are changed to 1 to prevent division by zero. +*/ +struct parse_lpnormalization : op_parser +{ + std::vector operators() const { return {{"LpNormalization"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser&, + const onnx_parser::node_info& info, + std::vector args) const + { + int p = 2; + if(contains(info.attributes, "p")) + { + p = info.attributes.at("p").i(); + } + if(p != 1 and p != 2) + { + MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported"); + } + auto input = args.front(); + auto input_shape = input->get_shape(); + const auto& input_lens = input_shape.lens(); + auto input_type = input_shape.type(); + std::ptrdiff_t num_axes = input_lens.size(); + std::ptrdiff_t axis = -1; + if(contains(info.attributes, "axis")) + { + axis = info.attributes.at("axis").i(); + if(axis < -num_axes or axis >= num_axes) + { + // handled in normalize_attributes but throwing here might be clearer + MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds"); + } + } + migraphx::instruction_ref p_val; + if(p == 1) + { + p_val = info.add_instruction(migraphx::make_op("abs"), input); + } + else + { + p_val = info.add_instruction(migraphx::make_op("mul"), input, input); + } + + // need to check for zeros from lp norm to prevent division by zero + // change them to 1 for the element-wise division + auto norms = + info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val); + if(p == 2) + { + norms = info.add_instruction(migraphx::make_op("sqrt"), norms); + } + // broadcast back to initial shape, negative axis option doesn't work with unidirectional + norms = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms); + auto zero_mb = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}})); + auto one_mb = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}})); + auto is_zero = info.add_instruction(migraphx::make_op("equal"), norms, zero_mb); + auto norms_zeros_to_one = + info.add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms); + return info.add_instruction(migraphx::make_op("div"), input, norms_zeros_to_one); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_lstm.cpp b/src/onnx/parse_lstm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15086d9b99b56f033a21a5a31906691f20ad46c5 --- /dev/null +++ b/src/onnx/parse_lstm.cpp @@ -0,0 +1,210 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void lstm_actv_functions(op::rnn_direction dirct, std::vector& actv_func_names) +{ + // need 6 activation functions for bidirectional directions + if(dirct == op::rnn_direction::bidirectional) + { + // 6 activation functions are used in the bidirectional + // scenario. No spec is provided in onnx::operator. we + // use the algorithm that: if 1 actv function is provided, + // repeat 1st six times. If 2 actv functins are provided, + // repeat 2nd once, then repeat all three once + // if 3 actv funcs are provide, repeat all three once. + // the same algorithm is used for 4, 5, and 6 actv funcions + // provided. This may need change later + switch(actv_func_names.size()) + { + case 1: + actv_func_names = {actv_func_names.at(0), + actv_func_names.at(0), + actv_func_names.at(0), + actv_func_names.at(0), + actv_func_names.at(0), + actv_func_names.at(0)}; + break; + + case 2: + // repeat the 2nd actv func once, then repeat all three another time + actv_func_names = {actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(1), + actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(1)}; + break; + + case 3: + // repeat all three actv funcs once + actv_func_names = {actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(2), + actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(2)}; + break; + + case 4: + actv_func_names = {actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(2), + actv_func_names.at(3), + actv_func_names.at(3), + actv_func_names.at(3)}; + break; + + case 5: + actv_func_names = {actv_func_names.at(0), + actv_func_names.at(1), + actv_func_names.at(2), + actv_func_names.at(3), + actv_func_names.at(4), + actv_func_names.at(4)}; + break; + + default: break; + } + } + else + { + switch(actv_func_names.size()) + { + case 1: + actv_func_names = {actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)}; + break; + + case 2: + // repeat the 2nd actv func once, so we have 3 actv funcs + actv_func_names = {actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)}; + break; + + default: break; + } + } +} + +struct parse_lstm : op_parser +{ + std::vector operators() const { return {{"LSTM"}}; } + + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + migraphx::shape input_shape = args[0]->get_shape(); + std::size_t hidden_size = args[2]->get_shape().lens()[2]; + + if(contains(info.attributes, "hidden_size")) + { + std::size_t hidden_size_att = + parser.parse_value(info.attributes.at("hidden_size")).at(); + if(hidden_size != hidden_size_att) + { + MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute"); + } + } + + // Handling of direction to be added later + std::string direction{"forward"}; + if(contains(info.attributes, "direction")) + { + direction = info.attributes.at("direction").s(); + } + + op::rnn_direction dirct = op::rnn_direction::forward; + if(direction == "bidirectional") + { + dirct = op::rnn_direction::bidirectional; + } + else if(direction == "reverse") + { + dirct = op::rnn_direction::reverse; + } + else if(direction == "forward") + { + dirct = op::rnn_direction::forward; + } + else + { + MIGRAPHX_THROW("LSTM: incorrect direction attribute"); + } + + std::vector vec_names = {"sigmoid", "tanh", "tanh"}; + if(contains(info.attributes, "activations")) + { + auto names = info.attributes.at("activations").strings(); + vec_names.clear(); + vec_names.resize(names.size()); + std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { + return to_lower(name); + }); + } + + lstm_actv_functions(dirct, vec_names); + + auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { + return (map_activation_functions().count(name) == 0); + }); + if(name_it != vec_names.end()) + { + MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported"); + } + + std::vector vec_actv_funcs(vec_names.size()); + std::transform(vec_names.begin(), + vec_names.end(), + vec_actv_funcs.begin(), + [&](const auto& name) { return map_activation_functions().at(name); }); + + float clip = 0.0; + if(contains(info.attributes, "clip")) + { + clip = parser.parse_value(info.attributes.at("clip")).at(); + } + + int input_forget = 0; + if(contains(info.attributes, "input_forget")) + { + input_forget = parser.parse_value(info.attributes.at("input_forget")).at(); + } + + // append undefined opeator to make 6 arguments + if(args.size() < 8) + { + auto ins = info.add_instruction(make_op("undefined")); + args.insert(args.end(), 8 - args.size(), ins); + } + + // first output for concatenation of hidden states + auto hidden_states = info.add_instruction(make_op("lstm", + {{"hidden_size", hidden_size}, + {"actv_func", to_value(vec_actv_funcs)}, + {"direction", dirct}, + {"clip", clip}, + {"input_forget", input_forget}}), + args); + + auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); + + // third output for last cell output + auto last_cell_output = + info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); + + return {hidden_states, last_output, last_cell_output}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_matmul.cpp b/src/onnx/parse_matmul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90796c4b927ab855043c74d40cdedf6c86572a99 --- /dev/null +++ b/src/onnx/parse_matmul.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_matmul : op_parser +{ + std::vector operators() const + { + return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto l0 = args[0]; + auto l1 = args[1]; + auto l0_lens = l0->get_shape().lens(); + auto l1_lens = l1->get_shape().lens(); + + // args[0] is a vector, prepend 1 to the shape + bool is_a_prepended = false; + if(l0_lens.size() == 1) + { + is_a_prepended = true; + l0_lens.insert(l0_lens.begin(), 1); + l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]); + } + + bool is_b_appended = false; + if(l1_lens.size() == 1) + { + is_b_appended = true; + l1_lens.push_back(1); + l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]); + } + + instruction_ref bl0 = l0; + instruction_ref bl1 = l1; + if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend())) + { + auto l0_it = l0_lens.begin() + l0_lens.size() - 2; + std::vector l0_broadcasted_lens(l0_lens.begin(), l0_it); + auto l1_it = l1_lens.begin() + l1_lens.size() - 2; + std::vector l1_broadcasted_lens(l1_lens.begin(), l1_it); + auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); + l0_broadcasted_lens = output_lens; + l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end()); + l1_broadcasted_lens = output_lens; + l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); + if(l0_lens != l0_broadcasted_lens) + { + bl0 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0); + } + if(l1_lens != l1_broadcasted_lens) + { + bl1 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1); + } + } + instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1); + int64_t num_axis = static_cast(dot_res->get_shape().lens().size()); + if(is_a_prepended) + { + dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res); + --num_axis; + } + if(is_b_appended) + { + dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 1}}}), dot_res); + } + + return dot_res; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_mean.cpp b/src/onnx/parse_mean.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02d7ecede98cce17e8660663039a6630105b9ee4 --- /dev/null +++ b/src/onnx/parse_mean.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_mean : op_parser +{ + const std::set float_types = { + shape::float_type, shape::half_type, shape::double_type}; + + std::vector operators() const { return {{"Mean"}}; } + + /// Calculates the element-wise mean of n>=1 input tensors + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto num_data = args.size(); + if(num_data == 1) + return args[0]; + + auto divisor = info.add_literal( + migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}}); + + if(contains(float_types, args[0]->get_shape().type())) + { + return std::accumulate(args.begin() + 1, + args.end(), + info.add_broadcastable_binary_op("div", args[0], divisor), + [&](auto mean, auto data_i) { + // Pre-divide each tensor element-wise by n to reduce risk of + // overflow during summation + auto div = + info.add_broadcastable_binary_op("div", data_i, divisor); + return info.add_broadcastable_binary_op("add", mean, div); + }); + } + else + { + // Compute sum before division for integral types + auto sum = std::accumulate( + args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) { + return info.add_broadcastable_binary_op("add", accum, data_i); + }); + + return info.add_broadcastable_binary_op("div", sum, divisor); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_multinomial.cpp b/src/onnx/parse_multinomial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29067d8212fbe157ca0db5b9ebcc2c995288c15d --- /dev/null +++ b/src/onnx/parse_multinomial.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_multinomial : op_parser +{ + std::vector operators() const { return {{"Multinomial"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + int dtype = 6; + if(contains(info.attributes, "dtype")) + dtype = info.attributes.at("dtype").i(); + shape::type_t output_type = get_type(dtype); + + size_t sample_size = 1; + if(contains(info.attributes, "sample_size")) + sample_size = info.attributes.at("sample_size").i(); + + // Subtract the per-batch maximum log-probability, making the per-batch max 0 + auto maxes = + info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); + auto mb_maxes = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}), + maxes); + auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes); + // Take the element-wise exponent to get probabilities in the range (0, 1] + cdf = info.add_instruction(migraphx::make_op("exp"), cdf); + // Compute the cumulative density function + cdf = info.add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); + + // Pre-compute random distribution + std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if(contains(info.attributes, "seed")) + gen.seed(info.attributes.at("seed").f()); + + std::uniform_real_distribution<> dis(0.0, 1.0); + size_t batch_size = args[0]->get_shape().lens().front(); + migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; + + std::vector random_dist(batch_size * sample_size); + std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); }); + auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist}); + + return info.add_instruction( + migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_nonzero.cpp b/src/onnx/parse_nonzero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..488adfae739efcf734b2e84ccd6d16941b5ab5de --- /dev/null +++ b/src/onnx/parse_nonzero.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +template +static std::vector nonzero_indices(const std::vector& data) +{ + std::vector indices; + for(std::size_t i = 0; i < data.size(); ++i) + { + if(!float_equal(data[i], 0)) + indices.push_back(i); + } + + return indices; +} + +struct parse_nonzero : op_parser +{ + std::vector operators() const { return {{"NonZero"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + migraphx::argument data_arg = args.back()->eval(); + if(data_arg.empty()) + { + return info.add_instruction(make_op("nonzero"), args); + } + else + { + std::vector indices; + data_arg.visit([&](auto val) { + using val_type = std::remove_cv_t; + std::vector vec_data; + vec_data.assign(val.begin(), val.end()); + indices = nonzero_indices(vec_data); + }); + + shape in_s = args[0]->get_shape(); + shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}}; + + std::vector out_data(out_s.elements()); + for(std::size_t i = 0; i < indices.size(); ++i) + { + auto idx = in_s.multi(indices[i]); + for(std::size_t j = 0; j < in_s.lens().size(); ++j) + { + out_data[out_s.index({j, i})] = idx[j]; + } + } + + return info.add_literal(literal(out_s, out_data)); + } + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_onehot.cpp b/src/onnx/parse_onehot.cpp new file mode 100644 index 0000000000000000000000000000000000000000..174809185fc5fefdeec616b7f383bfc4cc2a35af --- /dev/null +++ b/src/onnx/parse_onehot.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_onehot : op_parser +{ + std::vector operators() const { return {{"OneHot"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + migraphx::argument depth_arg = args[1]->eval(); + check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported"); + size_t depth = depth_arg.at(); + + int64_t axis = -1; + if(contains(info.attributes, "axis")) + { + axis = info.attributes.at("axis").i(); + } + + std::vector depth_input(depth * depth, 0.0f); + for(int i = 0; i < depth; i++) + { + depth_input[depth * i + i] = 1.0f; + } + + auto type = args[2]->get_shape().type(); + shape s{type, {depth, depth}}; + auto l_val = info.add_literal({s, depth_input}); + auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]}); + + // Finally, we need a transpose to move the inner most dim to the axis dim + int n_rank = gather_out->get_shape().lens().size(); + int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name); + std::vector perm(n_rank - 1); + std::iota(perm.begin(), perm.end(), 0); + perm.insert(perm.begin() + tuned_axis, n_rank - 1); + auto tr_out = + info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out); + auto lens = tr_out->get_shape().lens(); + + auto off_val = info.add_instruction( + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]); + auto on_val = info.add_instruction( + make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); + auto diff = info.add_instruction(make_op("sub"), on_val, off_val); + auto unsq_off_val = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val); + auto unsq_diff_val = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff); + auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val); + return info.add_instruction(make_op("add"), l_mul, unsq_off_val); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_pad.cpp b/src/onnx/parse_pad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4a07ce5353b7ce50efc6e0d7cd3ea6c42af37ff --- /dev/null +++ b/src/onnx/parse_pad.cpp @@ -0,0 +1,163 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +void calc_reflect_indices(std::vector& indices, const int64_t num_dims) +{ + int k = 0; + bool reversed = false; + // in reflect padding, if the num_pads > num_dims, + // compute the extra pad indices periodically, ex. ( 1, 2, 3, 2, 1, 0) + for(int& idx : indices) + { + if(k == num_dims - 1) + reversed = true; + if(k == 0) + reversed = false; + if(reversed) + k--; + else + k++; + idx = k; + } +} + +instruction_ref reflect_pad(const onnx_parser::node_info& info, + const std::vector& pads, + instruction_ref input) +{ + size_t num_dims = pads.size() / 2; + std::vector ldims(pads.begin(), pads.begin() + num_dims); + std::vector rdims(pads.begin() + num_dims, pads.end()); + assert(ldims.size() == rdims.size()); + + std::vector axes(num_dims); + std::iota(axes.begin(), axes.end(), int64_t{0}); + + // iterate over dimensions, starting from lowest dimension + for(int64_t i = num_dims - 1; i >= 0; i--) + { + auto axis = i; + auto lcount = ldims.at(i); + auto rcount = rdims.at(i); + if(lcount == 0 and rcount == 0) // no padding for current dim + continue; + + // calculate starts and ends for each iteration since shape may change + std::vector dims = input->get_shape().lens(); + std::vector starts(axes.size(), 0); + std::vector ends(dims.begin(), dims.end()); + std::vector slices; + + auto starts_it = starts.begin() + i; + auto ends_it = ends.begin() + i; + auto dims_it = dims.begin() + i; + + std::vector l_indices(lcount); + std::vector r_indices(rcount); + + // compute slice indices in a periodic fashion + calc_reflect_indices(l_indices, *dims_it); + calc_reflect_indices(r_indices, *dims_it); + + for(int idx : l_indices) + { + *starts_it = idx; + *ends_it = *starts_it + 1; + slices.push_back(info.add_instruction( + make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); + } + // when padding on the left side, the outermost pad should be at the beginning + std::reverse(slices.begin(), slices.end()); + slices.push_back(input); + for(int idx : r_indices) + { + *starts_it = *dims_it - idx - 1; + *ends_it = *starts_it + 1; + slices.push_back(info.add_instruction( + make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); + } + input = info.add_instruction(make_op("concat", {{"axis", axis}}), slices); + } + return input; +} + +struct parse_pad : op_parser +{ + std::vector operators() const { return {{"Pad"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + std::vector pads{}; + if(args.size() >= 2) + { + auto pad_arg = args.at(1)->eval(); + check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant"); + pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); + } + else if(contains(info.attributes, "pads")) + { + auto&& pad_vals = info.attributes["pads"].ints(); + pads = std::vector(pad_vals.begin(), pad_vals.end()); + } + else + { + MIGRAPHX_THROW("PARSE_PAD: pad must be available"); + } + + // check if padding is actually being done (at least one value is nonzero) + if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) + { + return info.add_instruction(make_op("identity"), args.front()); + } + + if(contains(info.attributes, "mode")) + { + auto mode = info.attributes.at("mode").s(); + if(mode == "reflect") + return reflect_pad(info, pads, args.front()); + if(mode != "constant") + { + MIGRAPHX_THROW( + "PARSE_PAD: migraphx currently only supports constant and reflect padding"); + } + } + + float value = 0.0f; + // third input is the value + if(args.size() == 3) + { + auto val_ins = args.at(2); + if(!val_ins->can_eval()) + { + MIGRAPHX_THROW("PARSE_PAD: input value must be constant"); + } + auto val_arg = val_ins->eval(); + if(val_arg.get_shape().elements() != 1) + { + MIGRAPHX_THROW("PARSE_PAD: value should contain only one element"); + } + value = val_arg.at(); + } + else if(contains(info.attributes, "value")) + { + value = parser.parse_value(info.attributes.at("value")).at(); + } + + return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}), + args.front()); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_pooling.cpp b/src/onnx/parse_pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b53f4c4903c486527687798c4ddc7fdd962843ad --- /dev/null +++ b/src/onnx/parse_pooling.cpp @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_pooling : op_parser +{ + std::vector operators() const + { + return {{"AveragePool", "average"}, + {"GlobalAveragePool", "average"}, + {"GlobalMaxPool", "max"}, + {"MaxPool", "max"}, + {"LpPool", "lpnorm"}, + {"GlobalLpPool", "lpnorm"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + const std::unordered_map mode_map = { + {"max", op::pooling_mode::max}, + {"average", op::pooling_mode::average}, + {"lpnorm", op::pooling_mode::lpnorm}}; + std::string mode = opd.op_name; + if(not contains(mode_map, mode)) + { + MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]"); + } + operation op = make_op("pooling", {{"mode", mode_map.at(mode)}}); + value values = op.to_value(); + auto l0 = args[0]; + auto in_lens = l0->get_shape().lens(); + assert(in_lens.size() > 2); + auto kdims = in_lens.size() - 2; + + if(starts_with(opd.onnx_name, "Global")) + { + values["lengths"] = std::vector(in_lens.begin() + 2, in_lens.end()); + } + + // does not support ceil_mode + if(contains(info.attributes, "ceil_mode")) + { + values["ceil_mode"] = static_cast(info.attributes.at("ceil_mode").i()); + } + + // count include padding, if count include pad is 1, we always use + // explicit pad + int count_include_pad = 0; + if(contains(info.attributes, "count_include_pad")) + { + count_include_pad = info.attributes.at("count_include_pad").i(); + } + + if(contains(info.attributes, "strides")) + { + values["stride"].clear(); + copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"])); + check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides"); + } + if(contains(info.attributes, "kernel_shape")) + { + values["lengths"].clear(); + copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"])); + check_attr_sizes( + kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths"); + } + + // lp_order attribute + if(contains(info.attributes, "p")) + { + values["lp_order"] = info.attributes.at("p").i(); + } + + // ensure pads availabe only when auto_pad is "NOT_SET" + check_padding_mode(info, "POOLING"); + + std::vector paddings; + float pad_val = ((mode == "max") ? std::numeric_limits::lowest() : 0.0f); + + if(contains(info.attributes, "pads")) + { + values["padding"].clear(); + copy(info.attributes["pads"].ints(), std::back_inserter(paddings)); + check_attr_sizes( + kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings"); + } + + if(contains(info.attributes, "auto_pad")) + { + values["padding"].clear(); + // return paddings could be empty, then setting to 0 for no padding + cal_auto_padding_size(info, + values, + values["lengths"].to_vector(), + {1, 1}, + in_lens, + paddings); + } + + if(paddings.size() != 2 * kdims) + { + paddings.resize(kdims * 2); + std::fill_n(paddings.begin(), 2 * kdims, 0); + } + + if(values["padding"].size() != kdims) + { + values["padding"].resize(kdims); + std::fill_n(values["padding"].begin(), kdims, 0); + } + + if(values["stride"].size() != kdims) + { + values["stride"].resize(kdims); + std::fill_n(values["stride"].begin(), kdims, 1); + } + // used to calculate the supposed output shape + std::vector orig_padding = paddings; + + std::vector slice_start; + std::vector slice_end; + tune_padding_size(values, paddings, count_include_pad, slice_start); + + if(!slice_start.empty()) + { + // calculate expected output shape + orig_padding.insert(orig_padding.begin() + kdims, 2, 0); + orig_padding.insert(orig_padding.begin(), 2, 0); + op::pad pad{orig_padding, 0.0f}; + shape padded_shape = pad.compute_shape({l0->get_shape()}); + auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens(); + + // compute slice_end information + slice_end.resize(slice_start.size()); + std::transform(out_lens.begin() + 2, + out_lens.end(), + slice_start.begin(), + slice_end.begin(), + [](auto i, auto j) { return i + j; }); + } + values["padding"] = std::vector(paddings.begin(), paddings.end()); + + check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val); + op.from_value(values); + + auto l1 = info.add_instruction(op, l0); + if(!slice_start.empty()) + { + std::vector axes(kdims); + std::iota(axes.begin(), axes.end(), 2); + l1 = info.add_instruction( + make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), + l1); + } + + return l1; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_pow.cpp b/src/onnx/parse_pow.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b919de467a372457863452b8a8586caa05fbd73 --- /dev/null +++ b/src/onnx/parse_pow.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +auto compute_type(shape::type_t t1, shape::type_t t2) +{ + const static std::unordered_map op_order = {{shape::int8_type, 1}, + {shape::uint8_type, 2}, + {shape::int16_type, 3}, + {shape::uint16_type, 4}, + {shape::int32_type, 5}, + {shape::uint32_type, 6}, + {shape::int64_type, 7}, + {shape::uint64_type, 8}, + {shape::half_type, 9}, + {shape::float_type, 10}, + {shape::double_type, 11}}; + + int it1 = t1; + int it2 = t2; + if(!contains(op_order, it1) or !contains(op_order, it2)) + { + MIGRAPHX_THROW("PARSE_POW: Input data type not supported!"); + } + + return ((op_order.at(it1) >= op_order.at(it2)) ? t1 : t2); +} + +struct parse_pow : op_parser +{ + std::vector operators() const { return {{"Pow"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto type_base = args[0]->get_shape().type(); + auto type_exponent = args[1]->get_shape().type(); + + auto type_compute = compute_type(type_base, type_exponent); + if(type_compute != type_base) + { + args[0] = + info.add_instruction(make_op("convert", {{"target_type", type_compute}}), args[0]); + } + + if(type_compute != type_exponent) + { + args[1] = + info.add_instruction(make_op("convert", {{"target_type", type_compute}}), args[1]); + } + + auto ret = info.add_broadcastable_binary_op("pow", args[0], args[1]); + if(type_compute != type_base) + { + ret = info.add_instruction(make_op("convert", {{"target_type", type_base}}), ret); + } + + return ret; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_prefix_scan.cpp b/src/onnx/parse_prefix_scan.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4a8cd9f121829da772eaa3cfaa8eb962f8c4387 --- /dev/null +++ b/src/onnx/parse_prefix_scan.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +instruction_ref parse_prefix_scan_oper(const std::string& op_name, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) +{ + migraphx::argument in = args[1]->eval(); + check_arg_empty(in, "PARSE_PREFIX_SCAN: axis - dynamic shape not supported"); + std::vector axis_in; + in.visit([&](auto input) { axis_in.assign(input.begin(), input.end()); }); + int64_t axis = axis_in[0]; + + bool exclusive = false; + bool reverse = false; + + if(contains(info.attributes, "exclusive")) + { + exclusive = parser.parse_value(info.attributes.at("exclusive")).at(); + } + + if(contains(info.attributes, "reverse")) + { + reverse = parser.parse_value(info.attributes.at("reverse")).at(); + } + + return info.add_instruction( + make_op(op_name, {{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}), + args[0]); +} + +struct parse_prefix_scan_op : op_parser +{ + std::vector operators() const { return {{"CumSum", "prefix_scan_sum"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + return parse_prefix_scan_oper(opd.op_name, parser, std::move(info), std::move(args)); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_quantizelinear.cpp b/src/onnx/parse_quantizelinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ee1bcdb425e04cafcef1dd619b95b7489e9eab4 --- /dev/null +++ b/src/onnx/parse_quantizelinear.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_quantizelinear : op_parser +{ + std::vector operators() const { return {{"QuantizeLinear"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + const std::vector& args) const + { + int axis = 1; + if(contains(info.attributes, "axis")) + axis = info.attributes.at("axis").i(); + + auto input_lens = args[0]->get_shape().lens(); + auto n_dim = input_lens.size(); + + instruction_ref y_scale; + if(args[1]->get_shape().elements() != 1) + { + auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); + y_scale = info.add_instruction( + make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]); + } + else + { + y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), + args[1]); + } + + if(args.size() == 3) + { + auto y_zero_point = args[2]; + if(y_zero_point->get_shape().elements() != 1) + { + auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); + y_zero_point = info.add_instruction( + make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), + y_zero_point); + } + else + { + y_zero_point = info.add_instruction( + make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point); + } + + return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); + } + + return info.add_instruction(make_op("quantizelinear"), args[0], y_scale); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_randomnormal_ops.cpp b/src/onnx/parse_randomnormal_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c0b0b6498b5250bac1ddd63621715d88b52dd51 --- /dev/null +++ b/src/onnx/parse_randomnormal_ops.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_randomnormal_ops : op_parser +{ + const std::set valid_types = { + shape::float_type, shape::half_type, shape::double_type}; + + std::vector operators() const { return {{"RandomNormal"}, {"RandomNormalLike"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + int dtype = 1; + bool use_dtype = false; + if(contains(info.attributes, "dtype")) + { + dtype = info.attributes.at("dtype").i(); + use_dtype = true; + } + shape::type_t out_type = get_type(dtype); + if(not contains(valid_types, out_type)) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + + ". Valid types are 1 (float), 10 (half), and 11 (double)."); + + float mean = 0.0; + if(contains(info.attributes, "mean")) + mean = info.attributes.at("mean").f(); + + float scale = 1.0; + if(contains(info.attributes, "scale")) + scale = info.attributes.at("scale").f(); + + shape out_shape; + if(contains(info.attributes, "shape")) + { + // RandomNormal: + // output type and shape must come from attributes + std::vector out_lens; + literal ls = parser.parse_value(info.attributes.at("shape")); + ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); }); + out_shape = shape{out_type, out_lens}; + } + else if(args.size() == 1) + { + // RandomNormalLike: + // output type and shape are the same as the input's by default + // dtype is used instead when attribute is set + if(not contains(valid_types, args[0]->get_shape().type())) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + + std::to_string(args[0]->get_shape().type()) + + ". Valid types are float, half, and double."); + out_shape = + use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape(); + } + else + { + MIGRAPHX_THROW(opd.op_name + + ": cannot deduce shape without shape attribute or argument."); + } + + std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if(contains(info.attributes, "seed")) + gen.seed(info.attributes.at("seed").f()); + + std::normal_distribution<> d(mean, scale); + std::vector rand_vals(out_shape.elements()); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + return info.add_literal(literal{out_shape, rand_vals}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_randomuniform_ops.cpp b/src/onnx/parse_randomuniform_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f37c8e4de7f6f3bc09ea96c2bbdc287e82937b41 --- /dev/null +++ b/src/onnx/parse_randomuniform_ops.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_randomuniform_ops : op_parser +{ + const std::set valid_types = { + shape::float_type, shape::half_type, shape::double_type}; + + std::vector operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + int dtype = 1; + bool use_dtype = false; + if(contains(info.attributes, "dtype")) + { + dtype = info.attributes.at("dtype").i(); + use_dtype = true; + } + shape::type_t out_type = get_type(dtype); + if(not contains(valid_types, out_type)) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + + ". Valid types are 1 (float), 10 (half), and 11 (double)."); + + float high = 1.0; + if(contains(info.attributes, "high")) + high = info.attributes.at("high").f(); + + float low = 0.0; + if(contains(info.attributes, "low")) + low = info.attributes.at("low").f(); + + shape out_shape; + if(contains(info.attributes, "shape")) + { + // RandomUniform: + // output type and shape must come from attributes + std::vector out_lens; + literal ls = parser.parse_value(info.attributes.at("shape")); + ls.visit([&](auto s) { out_lens.assign(s.begin(), s.end()); }); + out_shape = shape{out_type, out_lens}; + } + else if(args.size() == 1) + { + // RandomUniformLike: + // output type and shape are the same as the input by default + // dtype is used instead when attribute is set + if(not contains(valid_types, args[0]->get_shape().type())) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + + std::to_string(args[0]->get_shape().type()) + + ". Valid types are float, half, and double."); + out_shape = + use_dtype ? shape{out_type, args[0]->get_shape().lens()} : args[0]->get_shape(); + } + else + { + MIGRAPHX_THROW(opd.op_name + + ": cannot deduce shape without shape attribute or argument."); + } + + std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if(contains(info.attributes, "seed")) + gen.seed(info.attributes.at("seed").f()); + + std::uniform_real_distribution<> d(high, low); + std::vector rand_vals(out_shape.elements()); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + return info.add_literal(literal{out_shape, rand_vals}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_range.cpp b/src/onnx/parse_range.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd6fb1787ba46249e99ba4b10f43bf8c2a7ce1b9 --- /dev/null +++ b/src/onnx/parse_range.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_range : op_parser +{ + std::vector operators() const { return {{"Range"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + auto start_arg = args[0]->eval(); + check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported"); + auto limit_arg = args[1]->eval(); + check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported"); + auto delta_arg = args[2]->eval(); + check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported"); + + assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and + args[2]->get_shape().elements() == 1); + + instruction_ref l0; + + visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) { + auto start_val = start.front(); + auto limit_val = limit.front(); + auto delta_val = delta.front(); + + size_t num_elements = static_cast( + ceil(static_cast(limit_val - start_val) / static_cast(delta_val))); + + assert(num_elements > 0); + + using type = decltype(start_val); + + std::vector range_vals(num_elements); + + std::generate(range_vals.begin(), range_vals.end(), [&]() { + auto result = start_val; + start_val += delta_val; + return result; + }); + + l0 = info.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals}); + }); + return l0; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_reduce_op.cpp b/src/onnx/parse_reduce_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d1986c965172b10255f77af388a61a795ad3ae6 --- /dev/null +++ b/src/onnx/parse_reduce_op.cpp @@ -0,0 +1,165 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +instruction_ref parse_reduce_oper(const std::string& op_name, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) +{ + // default to reduce over all dimensions + std::vector axes; + if(args.size() == 2) + { + auto arg_axes = args.at(1)->eval(); + check_arg_empty(arg_axes, "PARSE_" + op_name + ": cannot handle variable axes!"); + axes.clear(); + arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "axes")) + { + axes.clear(); + auto&& attr_axes = info.attributes["axes"].ints(); + axes = std::vector(attr_axes.begin(), attr_axes.end()); + } + + bool noop_with_empty_axes = false; + if(contains(info.attributes, "noop_with_empty_axes")) + { + noop_with_empty_axes = static_cast( + parser.parse_value(info.attributes.at("noop_with_empty_axes")).at()); + } + + // empty axes behavior + if(axes.empty()) + { + if(noop_with_empty_axes) + { + return args.at(0); + } + else + { + std::size_t n_dim = args.front()->get_shape().lens().size(); + axes.resize(n_dim); + std::iota(axes.begin(), axes.end(), 0); + } + } + + int keep_dims = 1; + if(contains(info.attributes, "keepdims")) + { + keep_dims = parser.parse_value(info.attributes.at("keepdims")).at(); + } + + if(keep_dims == 1) + { + return info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front()); + } + else + { + auto ins = info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front()); + return info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins); + } +} + +struct parse_reduce_op : op_parser +{ + std::vector operators() const + { + return {{"ReduceMax", "reduce_max"}, + {"ReduceMean", "reduce_mean"}, + {"ReduceMin", "reduce_min"}, + {"ReduceProd", "reduce_prod"}, + {"ReduceSum", "reduce_sum"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + return parse_reduce_oper(opd.op_name, parser, std::move(info), std::move(args)); + } +}; + +struct parse_reduce_l1 : op_parser +{ + std::vector operators() const { return {{"ReduceL1"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + auto abs_ins = info.add_instruction(make_op("abs"), args[0]); + return parse_reduce_oper("reduce_sum", parser, std::move(info), {abs_ins}); + } +}; + +struct parse_reduce_l2 : op_parser +{ + std::vector operators() const { return {{"ReduceL2"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]); + auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {square_ins}); + return info.add_instruction(make_op("sqrt"), sum_ins); + } +}; + +struct parse_reduce_log_sum : op_parser +{ + std::vector operators() const { return {{"ReduceLogSum"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, std::move(args)); + return info.add_instruction(make_op("log"), sum_ins); + } +}; + +struct parse_reduce_log_sum_exp : op_parser +{ + std::vector operators() const { return {{"ReduceLogSumExp"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + auto exp_ins = info.add_instruction(make_op("exp"), args[0]); + auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {exp_ins}); + return info.add_instruction(make_op("log"), sum_ins); + } +}; + +struct parse_reduce_sum_square : op_parser +{ + std::vector operators() const { return {{"ReduceSumSquare"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]); + return parse_reduce_oper("reduce_sum", parser, std::move(info), {square_ins}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_reshape.cpp b/src/onnx/parse_reshape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e366c07952c01f9032dddd410f2bb9dd42ad709a --- /dev/null +++ b/src/onnx/parse_reshape.cpp @@ -0,0 +1,40 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_reshape : op_parser +{ + std::vector operators() const { return {{"Reshape"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + std::vector dims; + if(args.size() == 1) + { + literal s = parser.parse_value(info.attributes.at("shape")); + s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + } + if(args.size() == 2) + { + auto s = args[1]->eval(); + check_arg_empty(s, "Reshape: dynamic shape is not supported"); + s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + } + + return info.add_instruction(make_op("reshape", {{"dims", dims}}), + info.make_contiguous(args[0])); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_resize.cpp b/src/onnx/parse_resize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94f96e6024b461812a1d48d61a3d815ca7cc699c --- /dev/null +++ b/src/onnx/parse_resize.cpp @@ -0,0 +1,363 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +const auto& get_nearest_op(const std::string& mode) +{ + using nearest_op = std::function; + static std::unordered_map const nearest_ops = { + {"round_prefer_floor", + [=](std::size_t d_in, double val) { + val = std::max(0.0, std::min(d_in - 1.0, val)); + return static_cast(std::ceil((val - 0.5))); + }}, + {"round_prefer_ceil", + [=](std::size_t d_in, double val) { + val = std::max(0.0, std::min(d_in - 1.0, val)); + return static_cast(std::round((val))); + }}, + {"floor", + [=](std::size_t d_in, double val) { + val = std::max(0.0, std::min(d_in - 1.0, val)); + return static_cast(std::floor((val))); + }}, + {"ceil", [=](std::size_t d_in, double val) { + val = std::max(0.0, std::min(d_in - 1.0, val)); + return static_cast(std::ceil((val))); + }}}; + + if(!contains(nearest_ops, mode)) + { + MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!"); + } + + return nearest_ops.at(mode); +} + +const auto& get_original_idx_op(const std::string& mode) +{ + using original_idx_op = std::function; + static std::unordered_map const idx_ops = { + {"half_pixel", + [=](std::size_t, std::size_t, std::size_t idx, double scale) { + return (idx + 0.5) / scale - 0.5; + }}, + {"pytorch_half_pixel", + [=](std::size_t, std::size_t l_out, std::size_t idx, double scale) { + return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0; + }}, + {"align_corners", + [=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) { + return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0)); + }}, + {"asymmetric", + [=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }}, + {"tf_half_pixel_for_nn", [=](std::size_t, std::size_t, std::size_t idx, double scale) { + return (idx + 0.5) / scale; + }}}; + + if(!contains(idx_ops, mode)) + { + MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode + " not supported!"); + } + + return idx_ops.at(mode); +} + +static std::vector +calc_neighbor_points(const std::vector>>& vvv_ind, + int i_dim, + const std::vector>& vec_dims, + const shape& in_s) +{ + if(i_dim == vvv_ind.size()) + { + std::vector vec_ind; + vec_ind.resize(vec_dims.size()); + std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) { + return static_cast(in_s.index(idx)); + }); + + return vec_ind; + } + + const auto& vv_ind = vvv_ind[i_dim]; + const auto& vv_lo = vv_ind.at(0); + std::vector> vec_dims1; + for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) + { + std::transform(vv_lo.begin(), + vv_lo.end(), + vec_dims.begin() + start, + std::back_inserter(vec_dims1), + [](auto i, auto dim) { + dim.push_back(i); + return dim; + }); + } + + const auto& vv_hi = vv_ind.at(1); + for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) + { + std::transform(vv_hi.begin(), + vv_hi.end(), + vec_dims.begin() + start, + std::back_inserter(vec_dims1), + [](auto i, auto dim) { + dim.push_back(i); + return dim; + }); + } + + return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s); +} + +static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) +{ + std::string coord_trans_mode = "half_pixel"; + if(contains(attr, "coordinate_transformation_mode")) + { + coord_trans_mode = attr.at("coordinate_transformation_mode").s(); + // does not support transformation mode "tf_crop_and_resize" + if(coord_trans_mode == "tf_crop_and_resize") + { + MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!"); + } + } + + return coord_trans_mode; +} + +static std::string get_mode(const onnx_parser::attribute_map& attr) +{ + std::string mode = "nearest"; + if(contains(attr, "mode")) + { + mode = attr.at("mode").s(); + if(mode != "nearest" and mode != "linear") + { + MIGRAPHX_THROW("PARSE_RESIZE: only nearest and linear modes are supported!"); + } + } + + return mode; +} + +static std::string get_nearest_mode(const onnx_parser::attribute_map& attr) +{ + std::string nearest_mode = "round_prefer_floor"; + if(contains(attr, "nearest_mode")) + { + nearest_mode = attr.at("nearest_mode").s(); + } + + return nearest_mode; +} + +struct parse_resize : op_parser +{ + std::vector operators() const { return {{"Resize"}, {"Upsample"}}; } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + // coord transform mode + std::string coord_trans_mode = get_coord_trans_mode(info.attributes); + + // mode: only nearest and linear modes are supported for now + std::string mode = get_mode(info.attributes); + + // nearest mode + std::string nearest_mode = get_nearest_mode(info.attributes); + + // check exclude_outside, only support 0 + if(contains(info.attributes, "exclude_outside") and + info.attributes.at("exclude_outside").i() == 1) + { + MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!"); + } + + // input data shape info + auto in_s = args[0]->get_shape(); + auto in_lens = in_s.lens(); + + // output shape is explicitly specified + std::vector out_lens(in_lens.size()); + + // scale + std::vector vec_scale; + + for(const auto& arg : args) + { + if(arg->name() == "undefined" or arg == args.front()) + { + continue; + } + + // skipped empty input + auto lens = arg->get_shape().lens(); + if(lens.empty()) + { + continue; + } + + auto type = arg->get_shape().type(); + // output size + if(type == shape::int64_type) + { + auto arg_out_s = arg->eval(); + check_arg_empty(arg_out_s, + "PARSE_" + opd.op_name + ": dynamic output size is not supported!"); + arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); + + if(out_lens.size() != in_lens.size()) + { + MIGRAPHX_THROW("PARSE_" + opd.op_name + + ": specified output size does not match input size"); + } + + // compute the scale + vec_scale.resize(in_lens.size()); + std::transform(in_lens.begin(), + in_lens.end(), + out_lens.begin(), + vec_scale.begin(), + [](auto iss, auto oss) { return 1.0 * oss / iss; }); + } + else + { + + // scale input + if(lens[0] == in_lens.size()) + { + auto arg_scale = arg->eval(); + check_arg_empty(arg_scale, + "PARSE_" + opd.op_name + + ": dynamic input scale is not supported!"); + + arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); + if(in_lens.size() != vec_scale.size()) + { + MIGRAPHX_THROW("PARSE_" + opd.op_name + + ": ranks of input and scale are different!"); + } + + std::transform(in_lens.begin(), + in_lens.end(), + vec_scale.begin(), + out_lens.begin(), + [&](auto idx, auto scale) { + return static_cast(idx * scale); + }); + } + } + } + + shape out_s{in_s.type(), out_lens}; + std::size_t out_elements = out_s.elements(); + auto idx_op = get_original_idx_op(coord_trans_mode); + + // reshape input to one-dimension + std::vector rsp_lens = {static_cast(in_s.elements())}; + args[0] = info.make_contiguous(args[0]); + auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]); + + if(mode == "nearest") + { + std::vector ind(out_elements); + + // map out_idx to in_idx + auto nearest_op = get_nearest_op(nearest_mode); + shape_for_each(out_s, [&](auto idx) { + auto in_idx = idx; + for(auto ii = 0; ii < in_lens.size(); ++ii) + { + auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); + in_idx[ii] = nearest_op(in_lens[ii], idx_val); + } + + ind[out_s.index(idx)] = static_cast(in_s.index(in_idx)); + }); + + shape ind_s{shape::int32_type, out_lens}; + auto ins_ind = info.add_literal(literal(ind_s, ind)); + return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind); + } + // linear mode + else + { + auto nearest_floor = get_nearest_op("floor"); + auto nearest_ceil = get_nearest_op("ceil"); + + // get the number of dimensions + std::size_t n_dim = out_lens.size(); + std::vector> vv_ind(2, std::vector(out_elements)); + std::vector>> vvv_ind(n_dim, vv_ind); + std::vector> delta(n_dim, std::vector(out_elements)); + + shape_for_each(out_s, [&](auto idx) { + auto in_idx = idx; + auto out_idx = out_s.index(idx); + for(auto ii = 0; ii < in_lens.size(); ++ii) + { + auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); + vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val); + vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val); + delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx]; + } + }); + + std::vector> vec_dims(out_elements); + auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s); + auto ind_lens = out_lens; + ind_lens[0] *= (std::size_t{1} << n_dim); + shape ind_s{shape::int32_type, ind_lens}; + auto ins_ind = info.add_literal(literal(ind_s, ind)); + auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind); + + auto dim_lens = out_lens; + dim_lens[0] *= (std::size_t{1} << (n_dim - 1)); + for(std::size_t i = 0; i < n_dim; ++i) + { + shape dim_s{shape::float_type, dim_lens}; + const auto& dim_delta = delta[n_dim - i - 1]; + std::vector delta_data; + for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j) + { + delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end()); + } + auto ins_delta = info.add_literal(dim_s, delta_data); + + // slice the data + int64_t slc_stride = dim_lens[0]; + auto low = info.add_instruction( + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {slc_stride}}}), + data); + auto hi = info.add_instruction( + make_op("slice", + {{"axes", {0}}, {"starts", {slc_stride}}, {"ends", {2 * slc_stride}}}), + data); + auto diff = info.add_instruction(make_op("sub"), hi, low); + auto ddf = info.add_instruction(make_op("mul"), diff, ins_delta); + data = info.add_instruction(make_op("add"), ddf, low); + dim_lens[0] /= 2; + } + + return data; + } + } +}; + +} // namespace onnx + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_reversesequence.cpp b/src/onnx/parse_reversesequence.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1e9fced7c26b5a2c3dad4555ae483b96c37aff4 --- /dev/null +++ b/src/onnx/parse_reversesequence.cpp @@ -0,0 +1,125 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +//! Parser for ReverseSequence ONNX operator. +/*! + Reverses the data along the time axis for the batches along the batch axis. + The sequence lengths can be given to reverse up to the given length for each batch, keeping the + rest of the sequence in the original order. Variable sequence_lens is not supported in this + version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The + batch axis and time axis must be [0, 1] and not the same. +*/ +struct parse_reversesequence : op_parser +{ + std::vector operators() const { return {{"ReverseSequence"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + int batch_axis = 1; + if(contains(info.attributes, "batch_axis")) + { + batch_axis = info.attributes.at("batch_axis").i(); + } + if(batch_axis != 0 and batch_axis != 1) + { + MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1"); + } + + int time_axis = 0; + if(contains(info.attributes, "time_axis")) + { + time_axis = info.attributes.at("time_axis").i(); + } + if(time_axis != 0 and time_axis != 1) + { + MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1"); + } + + if(time_axis == batch_axis) + { + MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same"); + } + + auto input = args[0]; + auto input_lens = input->get_shape().lens(); + if(input_lens.size() < 2) + { + MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2"); + } + + std::vector sequence_lens; + if(args.size() == 2) + { + migraphx::argument seq_lens_arg = args.back()->eval(); + check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens"); + seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "sequence_lens")) + { + literal s = parser.parse_value(info.attributes.at("sequence_lens")); + s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); }); + } + auto batch_size = input_lens[batch_axis]; + auto time_size = input_lens[time_axis]; + + // this condition may still work if sequence_len's shape was incorrect + if(sequence_lens.size() != batch_size) + { + MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape"); + } + + instruction_ref ret; + + auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) { + return info.add_instruction(make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, t_start}}, + {"ends", {b + 1, t_end}}}), + input); + }; + + for(int b = 0; b < batch_size; ++b) + { + instruction_ref s0; + if(sequence_lens[b] > 1) + { + s0 = add_slice(b, 0, sequence_lens[b]); + s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0); + + // if reversed less than whole batch, concat rest of batch + if(sequence_lens[b] < time_size) + { + auto s1 = add_slice(b, sequence_lens[b], time_size); + s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1); + } + } + else + { // cases where nothing changes + s0 = add_slice(b, 0, time_size); + } + if(b == 0) + { + ret = s0; + } + else + { + ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + } + return ret; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_rnn.cpp b/src/onnx/parse_rnn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efdfa46fde61c130451f8f0e72ad27edb59c2bcb --- /dev/null +++ b/src/onnx/parse_rnn.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_rnn : op_parser +{ + std::vector operators() const { return {{"RNN"}}; } + + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + migraphx::shape input_shape = args[0]->get_shape(); + std::size_t hidden_size = args[1]->get_shape().lens()[1]; + + if(contains(info.attributes, "hidden_size")) + { + std::size_t hidden_size_att = + parser.parse_value(info.attributes.at("hidden_size")).at(); + if(hidden_size != hidden_size_att) + { + MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute"); + } + } + + // Handling of direction to be added later + std::string direction{"forward"}; + if(contains(info.attributes, "direction")) + { + direction = info.attributes.at("direction").s(); + } + + op::rnn_direction dirct = op::rnn_direction::forward; + if(direction == "bidirectional") + { + dirct = op::rnn_direction::bidirectional; + } + else if(direction == "reverse") + { + dirct = op::rnn_direction::reverse; + } + + std::vector vec_names{"tanh"}; + if(contains(info.attributes, "activations")) + { + auto names = info.attributes.at("activations").strings(); + vec_names.clear(); + vec_names.resize(names.size()); + std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { + return to_lower(name); + }); + } + + auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) { + return (map_activation_functions().count(name) == 0); + }); + if(name_it != vec_names.end()) + { + MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported"); + } + + // bidirectional case should have two activation functions. + // one is for forward, and the other is for reverse. + // if only one actv function is provided, we use it in both + // forward and reverse direction + if(dirct == op::rnn_direction::bidirectional) + { + if(vec_names.size() == 1) + { + vec_names.push_back(vec_names.at(0)); + } + } + + std::vector vec_actv_funcs(vec_names.size()); + std::transform(vec_names.begin(), + vec_names.end(), + vec_actv_funcs.begin(), + [&](const auto& fn) { return map_activation_functions().at(fn); }); + + // To be added later + float clip = 0.0; + if(contains(info.attributes, "clip")) + { + clip = parser.parse_value(info.attributes.at("clip")).at(); + } + + // if the number of arguments is less than 6, append + // undefined operator to have 6 arguments + if(args.size() < 6) + { + auto ins = info.add_instruction(make_op("undefined")); + args.insert(args.end(), (6 - args.size()), ins); + } + + // first output for the concatenation of hidden states + auto hidden_states = info.add_instruction(make_op("rnn", + {{"hidden_size", hidden_size}, + {"actv_func", to_value(vec_actv_funcs)}, + {"direction", dirct}, + {"clip", clip}}), + args); + + // second output for the last hidden state + auto last_output = info.add_instruction(make_op("rnn_last_hs_output"), hidden_states); + + return {hidden_states, last_output}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_roialign.cpp b/src/onnx/parse_roialign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4ec0b9c78817efa9d398a37c1e01cff5f178949 --- /dev/null +++ b/src/onnx/parse_roialign.cpp @@ -0,0 +1,78 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_roialign : op_parser +{ + std::vector operators() const { return {{"RoiAlign"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + const std::vector& args) const + { + std::string coord_trans_mode = "half_pixel"; + if(contains(info.attributes, "coordinate_transformation_mode")) + { + coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); + } + if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) + { + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + + "\": invalid value!"); + } + + migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average); + if(contains(info.attributes, "mode")) + { + // read mode; default is "avg" + if(info.attributes.at("mode").s() == "max") + { + rmode = migraphx::op::pooling_mode::max; + } + } + + int64_t output_height = 1; + if(contains(info.attributes, "output_height")) + { + output_height = info.attributes.at("output_height").i(); + } + + int64_t output_width = 1; + if(contains(info.attributes, "output_width")) + { + output_width = info.attributes.at("output_width").i(); + } + + int64_t sampling_ratio = 0; + if(contains(info.attributes, "sampling_ratio")) + { + sampling_ratio = info.attributes.at("sampling_ratio").i(); + } + + float spatial_scale = 1.0f; + if(contains(info.attributes, "spatial_scale")) + { + spatial_scale = info.attributes.at("spatial_scale").f(); + } + return info.add_instruction(make_op("roialign", + {{"coordinate_transformation_mode", coord_trans_mode}, + {"mode", rmode}, + {"output_height", output_height}, + {"output_width", output_width}, + {"sampling_ratio", sampling_ratio}, + {"spatial_scale", spatial_scale}}), + args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_scatter.cpp b/src/onnx/parse_scatter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b1c33eea405ba7e1ef7019819a4619e21636363 --- /dev/null +++ b/src/onnx/parse_scatter.cpp @@ -0,0 +1,44 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_scatter : op_parser +{ + std::vector operators() const { return {{"ScatterElements"}, {"Scatter"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + const std::vector& args) const + { + operation op; + + std::string op_name = "scatter_none"; + int axis = 0; + + if(contains(info.attributes, "axis")) + axis = info.attributes.at("axis").i(); + if(contains(info.attributes, "reduction")) + { + std::string reduction_att(info.attributes.at("reduction").s()); + // check for a valid reduction attribute. We have an operator for each one. + if(not contains({"none", "add", "mul"}, reduction_att)) + MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction_att); + // merge scatter with reduction attribute to specify which scatter operation. Future + // reduction op names should follow this pattern and should also be added to the check + // above. + op_name = std::string("scatter_") + reduction_att; + } + op = migraphx::make_op(op_name, {{"axis", axis}}); + return info.add_instruction(op, args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_scatternd.cpp b/src/onnx/parse_scatternd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a00370cf90f8830253009a7faf211666aa0f5975 --- /dev/null +++ b/src/onnx/parse_scatternd.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_scatternd : op_parser +{ + std::vector operators() const { return {{"ScatterND"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector& args) const + { + if(contains(info.attributes, "reduction")) + { + if(info.attributes.at("reduction").s() == "add") + return info.add_instruction(migraphx::make_op("scatternd_add"), args); + if(info.attributes.at("reduction").s() == "mul") + return info.add_instruction(migraphx::make_op("scatternd_mul"), args); + } + + return info.add_instruction(migraphx::make_op("scatternd_none"), args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_selu.cpp b/src/onnx/parse_selu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..422d9ade43942e045b7150ad1342eaee78ede52c --- /dev/null +++ b/src/onnx/parse_selu.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_selu : op_parser +{ + std::vector operators() const { return {{"Selu"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + auto type = args[0]->get_shape().type(); + auto lens = args[0]->get_shape().lens(); + float alpha = 1.67326f; + if(contains(info.attributes, "alpha")) + { + alpha = info.attributes.at("alpha").f(); + } + + float gamma = 1.0507f; + if(contains(info.attributes, "gamma")) + { + gamma = info.attributes.at("gamma").f(); + } + + auto l_alpha = info.add_literal({{type, {1}}, {alpha}}); + auto l_gamma = info.add_literal({{type, {1}}, {gamma / 2.0f}}); + if(lens != std::vector{1}) + { + l_alpha = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha); + l_gamma = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma); + } + + auto sign_x = info.add_instruction(make_op("sign"), args[0]); + auto exp_x = info.add_instruction(make_op("exp"), args[0]); + + auto alpha_ex = info.add_instruction(make_op("mul"), l_alpha, exp_x); + auto aex_alpha = info.add_instruction(make_op("sub"), alpha_ex, l_alpha); + + auto ins1 = info.add_instruction(make_op("add"), aex_alpha, args[0]); + auto ins2 = info.add_instruction(make_op("sub"), aex_alpha, args[0]); + + auto sign2 = info.add_instruction(make_op("mul"), sign_x, ins2); + auto ins_sub = info.add_instruction(make_op("sub"), ins1, sign2); + + return info.add_instruction(make_op("mul"), ins_sub, l_gamma); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_shape.cpp b/src/onnx/parse_shape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3db641412d06ad31168f171c97c319666d3f21d3 --- /dev/null +++ b/src/onnx/parse_shape.cpp @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +// Use a literal instruction to replace the shape since, output of +// shape operator are literals in migraphx +struct parse_shape : op_parser +{ + std::vector operators() const { return {{"Shape"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + if(args.size() != 1) + MIGRAPHX_THROW("Shape: operator should have 1 operand"); + std::vector arg_shape = args[0]->get_shape().lens(); + std::vector vec_shape(arg_shape.size()); + migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()}); + std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { + return int64_t(i); + }); + return info.add_literal(migraphx::literal{s, vec_shape}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_size.cpp b/src/onnx/parse_size.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed352dc7b11a3bdc01599266c4df6ff53b02e94d --- /dev/null +++ b/src/onnx/parse_size.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_size : op_parser +{ + std::vector operators() const { return {{"Size"}}; } + + instruction_ref parse(const op_desc&, + const onnx_parser&, + const onnx_parser::node_info& info, + std::vector args) const + { + return info.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type}, + {args[0]->get_shape().elements()}}); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_slice.cpp b/src/onnx/parse_slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d6b34c9549aabef23f08b16727546535790520e7 --- /dev/null +++ b/src/onnx/parse_slice.cpp @@ -0,0 +1,114 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_slice : op_parser +{ + std::vector operators() const { return {{"Slice"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + op::slice op; + + std::vector steps; + + // slice can have up to 5 inputs, we first check the 5th one + // to decide whether MIGRAPHX can handle this slice + if(args.size() == 5) + { + migraphx::argument step_arg = args.back()->eval(); + check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice"); + step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); }); + } + + if(args.size() >= 4) + { + migraphx::argument axes_arg = args.at(3)->eval(); + check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice"); + axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "axes")) + { + literal s = parser.parse_value(info.attributes.at("axes")); + s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); + } + + if(args.size() >= 3) + { + migraphx::argument end_arg = args.at(2)->eval(); + check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice"); + end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "ends")) + { + literal s = parser.parse_value(info.attributes.at("ends")); + s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); }); + } + + if(args.size() >= 2) + { + migraphx::argument start_arg = args.at(1)->eval(); + check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice"); + start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "starts")) + { + literal s = parser.parse_value(info.attributes.at("starts")); + s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); + } + + if(op.axes.empty()) + { + std::vector axes(args[0]->get_shape().lens().size()); + std::iota(axes.begin(), axes.end(), int64_t{0}); + op.axes = axes; + } + + std::vector raxes; + + assert(steps.empty() or steps.size() == op.axes.size()); + assert(op.axes.size() == op.starts.size()); + assert(op.axes.size() == op.ends.size()); + + for(auto i : range(steps.size())) + { + if(steps[i] >= 0) + continue; + op.starts[i] += 1; + if(op.starts[i] == 0) + op.starts[i] = INT_MAX; + op.ends[i] += 1; + raxes.push_back(op.axes[i]); + std::swap(op.starts[i], op.ends[i]); + } + + auto ins = info.add_instruction(op, args[0]); + if(not raxes.empty()) + ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); + if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; })) + { + std::vector nsteps; + std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) { + return std::abs(s); + }); + return ins = info.add_instruction( + make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins); + } + else + return ins; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_softmax.cpp b/src/onnx/parse_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..635871339af7e3dc7230440cafcc07d16f6d3b96 --- /dev/null +++ b/src/onnx/parse_softmax.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_softmax : op_parser +{ + std::vector operators() const + { + return {{"Softmax", "softmax"}, {"LogSoftmax", "logsoftmax"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + const std::vector& args) const + { + // default axis value is -1 for opset 13 + int64_t axis = -1; + + // axis value is 1 for previous opset versions + if(parser.opset_version < 13) + { + axis = 1; + } + + if(contains(info.attributes, "axis")) + { + axis = parser.parse_value(info.attributes.at("axis")).at(); + } + + return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_softplus.cpp b/src/onnx/parse_softplus.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9267c02fcfde2934f37fcb47ea426414cda617bb --- /dev/null +++ b/src/onnx/parse_softplus.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_softplus : op_parser +{ + std::vector operators() const { return {{"Softplus"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + // Apply pointwise formula: y = ln(exp(x) + 1) + auto mb_ones = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}), + info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}})); + auto exp = info.add_instruction(migraphx::make_op("exp"), args[0]); + auto add = info.add_instruction(migraphx::make_op("add"), exp, mb_ones); + return info.add_instruction(migraphx::make_op("log"), add); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_softsign.cpp b/src/onnx/parse_softsign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f59699fa63caff857f04f5523ebaf8a3dfb1382 --- /dev/null +++ b/src/onnx/parse_softsign.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_softsign : op_parser +{ + std::vector operators() const { return {{"Softsign"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + // Apply pointwise formula: y = x / (1 + |x|) + auto mb_ones = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}), + info.add_literal(migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {1}})); + auto abs = info.add_instruction(migraphx::make_op("abs"), args[0]); + auto add = info.add_instruction(migraphx::make_op("add"), abs, mb_ones); + return info.add_instruction(migraphx::make_op("div"), args[0], add); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_spacetodepth.cpp b/src/onnx/parse_spacetodepth.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4572a3dc7aa5bdca6840787fd1f7c68cbebac007 --- /dev/null +++ b/src/onnx/parse_spacetodepth.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_spacetodepth : op_parser +{ + std::vector operators() const { return {{"SpaceToDepth"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto s = args[0]->get_shape(); + // blocksize attribute of SpaceToDepth + int blocksize = 1; // if blockSize of 1 then, this is a no-op + if(contains(info.attributes, "blocksize")) + { + blocksize = info.attributes.at("blocksize").i(); + } + if(blocksize < 1) + { + // blockSize less than 1 would rather result in DepthToSpace instead of SpaceToDepth + MIGRAPHX_THROW("SpaceToDepth: blocksize is less than 1"); + } + // calculate dimensions + auto res_lens = s.lens(); // {N, C, H, W} + if(((res_lens[2] % blocksize) == 0) and ((res_lens[3] % blocksize) == 0)) + { + // Co = C * (blocksize ^ 2) + res_lens[1] = res_lens[1] * blocksize * blocksize; + // Ho = (H / blocksize) + res_lens[2] = res_lens[2] / blocksize; + // Wo = (W / blocksize) + res_lens[3] = res_lens[3] / blocksize; + } // res_shape = (N, Co, Ho, Wo) + else + MIGRAPHX_THROW("SpaceToDepth: div by blocksize quotient not int "); + + auto trans_lens = s.lens(); // {N, C, H, W} + trans_lens[2] = res_lens[2]; + trans_lens[3] = blocksize; + trans_lens.push_back(res_lens[3]); + trans_lens.push_back(blocksize); // {N, C, Ho, blocksize, Wo, blocksize} + std::vector perm = {0, 3, 5, 1, 2, 4}; + auto temp1 = info.add_instruction(make_op("reshape", {{"dims", trans_lens}}), args[0]); + auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1); + return info.add_instruction(make_op("reshape", {{"dims", res_lens}}), + info.make_contiguous(temp2)); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_split.cpp b/src/onnx/parse_split.cpp new file mode 100644 index 0000000000000000000000000000000000000000..334401b6ca9a16bda3ad83892033ac89ace21854 --- /dev/null +++ b/src/onnx/parse_split.cpp @@ -0,0 +1,70 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_split : op_parser +{ + std::vector operators() const { return {{"Split"}}; } + + std::vector parse(const op_desc& opd, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + int64_t axis = 0; + if(contains(info.attributes, "axis")) + { + axis = parser.parse_value(info.attributes.at("axis")).at(); + } + + auto lens = args[0]->get_shape().lens(); + int64_t n_rank = lens.size(); + int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name); + + std::vector vec_splits; + if(contains(info.attributes, "split")) + { + literal s = parser.parse_value(info.attributes.at("split")); + s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); }); + + if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) != + static_cast(lens[tuned_axis])) + { + MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!"); + } + } + // no split attribute, input is equally divided + else + { + if((lens[tuned_axis] % info.num_outputs) != 0) + { + MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " + + std::to_string(info.num_outputs) + " splits!"); + } + auto dl = lens[tuned_axis] / info.num_outputs; + vec_splits.resize(info.num_outputs, dl); + } + + std::vector ret_ins; + int64_t start = 0; + for(auto sl : vec_splits) + { + ret_ins.push_back(info.add_instruction( + make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}), + args[0])); + start += sl; + } + + return ret_ins; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_squeeze.cpp b/src/onnx/parse_squeeze.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2268547b9730b66d62aed956d1a4ce6d471a42ec --- /dev/null +++ b/src/onnx/parse_squeeze.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_squeeze : op_parser +{ + std::vector operators() const + { + return {{"Squeeze", "squeeze"}, {"Unsqueeze", "unsqueeze"}}; + } + + operation assign_axes(operation& op, const std::vector& axes) const + { + auto v = op.to_value(); + v["axes"] = axes; + op.from_value(v); + + return op; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + auto op = parser.load(opd.op_name, info); + if(args.size() == 2) + { + auto arg_axes = args.at(1)->eval(); + check_arg_empty(arg_axes, "PARSE_" + opd.op_name + ": cannot handle variable axes!"); + std::vector axes; + arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); }); + op = assign_axes(op, axes); + } + + auto arg = info.make_contiguous(args.front()); + return info.add_instruction(op, arg); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_thresholdedrelu.cpp b/src/onnx/parse_thresholdedrelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d9709dffbc822d8690b6d6a0f7ec37b2bb6ac4c --- /dev/null +++ b/src/onnx/parse_thresholdedrelu.cpp @@ -0,0 +1,41 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_thresholdedrelu : op_parser +{ + std::vector operators() const { return {{"ThresholdedRelu"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + float alpha = 1.0; + if(contains(info.attributes, "alpha")) + alpha = parser.parse_value(info.attributes.at("alpha")).at(); + + auto x_shape = args[0]->get_shape(); + + auto lit_zero = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}}); + auto lit_alpha = + info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {alpha}}); + auto mb_zero = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_zero); + auto mb_alpha = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_alpha); + auto condition = info.add_instruction(migraphx::make_op("greater"), args[0], mb_alpha); + + return info.add_instruction(migraphx::make_op("where"), condition, args[0], mb_zero); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_tile.cpp b/src/onnx/parse_tile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d6d466bafdbc28c15c533104572a10434dafbdd --- /dev/null +++ b/src/onnx/parse_tile.cpp @@ -0,0 +1,40 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_tile : op_parser +{ + std::vector operators() const { return {{"Tile"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + migraphx::argument arg_s = args[1]->eval(); + check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported"); + std::vector repeats; + arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); }); + + auto l0 = args[0]; + for(int i = 0; i < repeats.size(); i++) + { + auto l1 = l0; + for(int j = 1; j < repeats[i]; j++) + { + l0 = info.add_instruction(make_op("concat", {{"axis", i}}), l0, l1); + } + } + return l0; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_topk.cpp b/src/onnx/parse_topk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da183a303fb270137df549abf80a55e48833621f --- /dev/null +++ b/src/onnx/parse_topk.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_topk : op_parser +{ + std::vector operators() const { return {{"TopK"}}; } + + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + int64_t k = 0; + if(args.size() == 2) + { + auto arg_k = args.at(1)->eval(); + check_arg_empty(arg_k, "PARSE_TopK: k input must be constant"); + k = arg_k.at(); + } + else if(contains(info.attributes, "k")) + { + k = info.attributes.at("k").i(); + } + + bool largest = true; + if(contains(info.attributes, "largest")) + { + largest = static_cast(info.attributes.at("largest").i()); + } + + int64_t axis = -1; + if(contains(info.attributes, "axis")) + { + axis = parser.parse_value(info.attributes.at("axis")).at(); + } + + auto topk_ret = info.add_instruction( + make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); + + auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret); + auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret); + + return {ret_val, ret_ind}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_transpose.cpp b/src/onnx/parse_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a21f6763d847d7595125a0fe519cfe68578db6a3 --- /dev/null +++ b/src/onnx/parse_transpose.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_transpose : op_parser +{ + std::vector operators() const { return {{"Transpose"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + std::vector args) const + { + std::vector perm{}; + if(contains(info.attributes, "perm")) + { + auto&& perm_vals = info.attributes["perm"].ints(); + perm = std::vector(perm_vals.begin(), perm_vals.end()); + } + + // if perm is empty, use the default value + auto n_dim = args.front()->get_shape().lens().size(); + if(perm.empty()) + { + perm.resize(n_dim); + std::iota(perm.rbegin(), perm.rend(), 0); + } + + if(perm.size() != n_dim) + { + MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have diffferent number of dims!"); + } + + return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front()); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_variadic_op.cpp b/src/onnx/parse_variadic_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..954a74537b2125574623aad8aa04382eff8623c9 --- /dev/null +++ b/src/onnx/parse_variadic_op.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_variadic_op : op_parser +{ + std::vector operators() const + { + return {{"Sum", "add"}, {"Max", "max"}, {"Min", "min"}}; + } + + instruction_ref parse(const op_desc& opd, + const onnx_parser&, + onnx_parser::node_info info, + std::vector args) const + { + return std::accumulate(std::next(args.begin()), + args.end(), + args.front(), + [&](instruction_ref a, instruction_ref b) { + return info.add_broadcastable_binary_op(opd.op_name, a, b); + }); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_where.cpp b/src/onnx/parse_where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94bbed372e64601b6c5d7431edfb233e38f5c303 --- /dev/null +++ b/src/onnx/parse_where.cpp @@ -0,0 +1,47 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +struct parse_where : op_parser +{ + std::vector operators() const { return {{"Where"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + const onnx_parser::node_info& info, + std::vector args) const + { + auto lens = + compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); + lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); + if(args[0]->get_shape().lens() != lens) + { + args[0] = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]); + } + + if(args[1]->get_shape().lens() != lens) + { + args[1] = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]); + } + + if(args[2]->get_shape().lens() != lens) + { + args[2] = + info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]); + } + + return info.add_instruction(make_op("where"), args[0], args[1], args[2]); + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/softmax.hpp b/src/onnx/softmax.hpp deleted file mode 100644 index 7c28e35faa7f837ba446c0c8fa9b81bbfb2dbe4f..0000000000000000000000000000000000000000 --- a/src/onnx/softmax.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#include -#include -#include - -template -std::vector softmax(const std::vector& p) -{ - size_t n = p.size(); - std::vector result(n); - std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); }); - T s = std::accumulate(result.begin(), result.end(), 0.0f, std::plus()); - std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; }); - return result; -} diff --git a/src/op_enums.cpp b/src/op_enums.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1bbf57932819f095155cdaa1e575079edfee1027 --- /dev/null +++ b/src/op_enums.cpp @@ -0,0 +1,32 @@ +// +// Supporting functions for enum values used in operator parameters. +// These values are declared as "enum class" and should include << streaming operators +// to be able to write their values in human-readable format so users can +// save and edit model files. +// +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +std::ostream& operator<<(std::ostream& os, pooling_mode v) +{ + // the strings for the enum are the same as the values used for onnx parsing + // but this enum is not onnx-specific: strings must be converted when parsing tf + static const std::vector pooling_mode_str = {"average", "max", "lpnorm"}; + os << pooling_mode_str[static_cast::type>(v)]; + return os; +} +std::ostream& operator<<(std::ostream& os, rnn_direction v) +{ + static const std::vector rnn_direction_str = { + "forward", "reverse", "bidirectional"}; + os << rnn_direction_str[static_cast::type>(v)]; + return os; +} + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/operation.cpp b/src/operation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fc72180033bdf728c048402a7b5e69a033b9318 --- /dev/null +++ b/src/operation.cpp @@ -0,0 +1,18 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void migraphx_to_value(value& v, const operation& op) +{ + v["name"] = op.name(); + v["operator"] = op.to_value(); +} +void migraphx_from_value(const value& v, operation& op) +{ + op = make_op(v.at("name").to(), v.at("operator")); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/opt/memory_coloring.cpp b/src/opt/memory_coloring.cpp index 453d2b34dd77611b59428cec5bc2946149d0e0fe..f13e3846767d9c6c2af4cb617129bae18aded5b4 100644 --- a/src/opt/memory_coloring.cpp +++ b/src/opt/memory_coloring.cpp @@ -4,11 +4,11 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void memory_coloring::apply(program& p) const +void memory_coloring::apply(module& m) const { if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) { - memory_coloring_impl opt(&p, allocation_op, verify); + memory_coloring_impl opt(&m, allocation_op, verify); opt.run(); } } diff --git a/src/opt/memory_coloring_impl.cpp b/src/opt/memory_coloring_impl.cpp index 5c68f5b7bffabdce74743597dadc25dfd3012938..4dee9e0175d0d98d271a92d54a43177be5f07e51 100644 --- a/src/opt/memory_coloring_impl.cpp +++ b/src/opt/memory_coloring_impl.cpp @@ -1,4 +1,7 @@ -#include +#include + +#include + #include "memory_coloring_impl.hpp" namespace migraphx { @@ -6,8 +9,11 @@ inline namespace MIGRAPHX_INLINE_NS { void memory_coloring_impl::run() { + // calc implicit depdendencies + mod_implicit_deps = p_mod->calc_implicit_deps(); + MIGRAPHX_DEBUG(dump("---Before memory coloring---")); - MIGRAPHX_DEBUG(dump_program()); + MIGRAPHX_DEBUG(dump_module()); build(); if(num_of_lives != 0) { @@ -19,7 +25,10 @@ void memory_coloring_impl::run() allocate(interval); alloc_queue.pop(); } + + // rewrite happens after all modules are processed rewrite(); + if(enable_verify) verify(); } @@ -31,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) std::size_t size = s.bytes(); if(size == 0) return false; - std::size_t element_size = size / s.elements(); + std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements())); live_range& segment = interval->segment; int vn = segment.vn; std::priority_queue, ordering> conflict_queue; @@ -41,7 +50,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) if(conflict_table.find(vn) != conflict_table.end()) { std::set& vn_set = conflict_table[vn]; - for(auto& iter : vn_set) + for(const auto& iter : vn_set) { live_range* range = live_ranges[iter]; long long offset = range->offset; @@ -96,13 +105,13 @@ bool memory_coloring_impl::allocate(interval_ptr interval) void memory_coloring_impl::build() { - std::size_t num_of_instrs = p_program->size(); + std::size_t num_of_instrs = p_mod->size(); if(num_of_instrs == 0) return; auto cur_points = num_of_instrs * 2; - instruction_ref iter = p_program->end(); - instruction_ref begin = p_program->begin(); + instruction_ref iter = p_mod->end(); + instruction_ref begin = p_mod->begin(); std::vector dead_instrs; std::set live_set; // Build live intervals. @@ -134,8 +143,19 @@ void memory_coloring_impl::build() { is_dead = true; } - for(auto&& arg : iter->inputs()) + + auto inputs = iter->inputs(); + if(contains(mod_implicit_deps, iter)) { + const auto& impl_deps = mod_implicit_deps.at(iter); + inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end()); + } + + for(auto&& arg : inputs) + { + if(not p_mod->has_instruction(arg)) + continue; + if(is_param(arg) || is_outline(arg)) { if(is_output_param(arg)) @@ -182,8 +202,8 @@ void memory_coloring_impl::rewrite() std::vector dims; dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float)); shape s = {shape::float_type, dims}; - instruction_ref scratch_param = p_program->add_parameter("scratch", s); - for(auto ins : iterator_for(*p_program)) + instruction_ref scratch_param = p_mod->add_parameter("scratch", s); + for(auto ins : iterator_for(*p_mod)) { const instruction* p_iter = &(*ins); if(instr2_live.find(p_iter) != instr2_live.end()) @@ -207,13 +227,15 @@ void memory_coloring_impl::rewrite() if(is_allocate(ins)) { - p_program->replace_instruction( - ins, op::load{ins->get_shape(), offset}, scratch_param); + p_mod->replace_instruction( + ins, + make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}), + scratch_param); } } } MIGRAPHX_DEBUG(dump("---After rewrite---")); - MIGRAPHX_DEBUG(dump_program()); + MIGRAPHX_DEBUG(dump_module()); } void memory_coloring_impl::verify() @@ -227,8 +249,8 @@ void memory_coloring_impl::verify() if(segment.begin == invalid_offset) { - if(!interval.is_live_on_entry) - MIGRAPHX_THROW("interval is not live on entry"); + // if(!interval.is_live_on_entry) + // MIGRAPHX_THROW("interval is not live on entry"); continue; } @@ -240,7 +262,7 @@ void memory_coloring_impl::verify() if(conflict_table.find(vn) != conflict_table.end()) { std::set& vn_set = conflict_table[vn]; - for(auto& iter : vn_set) + for(const auto& iter : vn_set) { live_range* range = live_ranges[iter]; if(range->offset == invalid_offset) @@ -257,7 +279,7 @@ void memory_coloring_impl::verify() void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; } -void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; } +void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; } void memory_coloring_impl::dump_intervals() { diff --git a/src/opt/memory_coloring_impl.hpp b/src/opt/memory_coloring_impl.hpp old mode 100644 new mode 100755 index b564c6f72d48cb94697eca202948e3e06b29946e..4f264796ec287e6f54c8358b1df9dc9b2b76a006 --- a/src/opt/memory_coloring_impl.hpp +++ b/src/opt/memory_coloring_impl.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -39,10 +40,6 @@ struct live_interval { live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0}) { - id = invalid_offset; - def_point = invalid_offset; - is_literal = false; - is_live_on_entry = false; } void add_use(std::size_t use) { use_points.push_front(use); } @@ -55,35 +52,27 @@ struct live_interval #endif live_range segment; - std::size_t id; - std::list use_points; - std::size_t def_point; - shape result; - bool is_literal; - bool is_live_on_entry; + std::size_t id = invalid_offset; + std::list use_points{}; + std::size_t def_point = invalid_offset; + shape result{}; + bool is_literal = false; + bool is_live_on_entry = false; }; using interval_ptr = live_interval*; struct memory_coloring_impl { - memory_coloring_impl(program* p, std::string alloc_op, bool p_verify) - : p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) + memory_coloring_impl(module* p, std::string alloc_op, bool p_verify) + : p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) { - instr2_live.clear(); - live_ranges.clear(); - conflict_table.clear(); - num_of_lives = 0; - max_value_number = -1; - required_bytes = 0; - earliest_end_point = -1; - latest_end_point = -1; - unify_literals = false; } + bool allocate(interval_ptr); - void add_conflicts(std::set& live_set, int val) + void add_conflicts(const std::set& live_set, int val) { - for(auto& iter : live_set) + for(const auto& iter : live_set) { conflict_table[iter].insert(val); conflict_table[val].insert(iter); @@ -97,7 +86,11 @@ struct memory_coloring_impl static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; } static bool is_output_param(const instruction_ref ins) { - return is_param(ins) && any_cast(ins->get_operator()).parameter == "output"; + if(not is_param(ins)) + return false; + + auto param_name = any_cast(ins->get_operator()).parameter; + return contains(param_name, "#output_"); } bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; } static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; } @@ -118,12 +111,12 @@ struct memory_coloring_impl void verify(); #ifdef MIGRAPHX_DEBUG_OPT void dump(const std::string&); - void dump_program(); + void dump_module(); void dump_intervals(); #endif struct ordering { - bool operator()(const interval_ptr i1, const interval_ptr i2) const + bool operator()(const interval_ptr& i1, const interval_ptr& i2) const { auto len1 = i1->get_end() - i1->get_begin(); auto len2 = i2->get_end() - i2->get_begin(); @@ -145,28 +138,31 @@ struct memory_coloring_impl return (i1->offset > i2->offset); } }; - program* p_program; + + module* p_mod; std::unordered_map instr2_live; // universe of live intervals. - std::vector live_intervals; + std::vector live_intervals = {}; // Map live range value number to live range. - std::unordered_map live_ranges; + std::unordered_map live_ranges = {}; // Map live range value number to a set of conflicting live ranges' value numbers. - std::unordered_map> conflict_table; + std::unordered_map> conflict_table = {}; // Priority queue for coloring. - std::priority_queue, ordering> alloc_queue; + std::priority_queue, ordering> alloc_queue{}; - int num_of_lives; - int max_value_number; - std::size_t required_bytes; + int num_of_lives = 0; + int max_value_number = -1; + std::size_t required_bytes = 0; // The earliest program point where an live interval ends. - int earliest_end_point; + int earliest_end_point = -1; // The latest program point where an live interval ends. - int latest_end_point; + int latest_end_point = -1; // Whether to unify literals into coloring. - bool unify_literals; + bool unify_literals = false; std::string allocation_op{}; bool enable_verify; + + ins_dep_map mod_implicit_deps; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp old mode 100644 new mode 100755 index 0516913221bb615bfc92f908d27934d6adb97049..1a51f0a665b59e89743711c3c8ca06350f41b664 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -15,25 +15,97 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void run_passes(program& prog, const std::vector& passes, tracer trace) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES); + +void validate_pass(module& mod, const pass& p, tracer trace) { - for(auto& p : passes) + (void)mod; + (void)p; + (void)trace; +#ifndef NDEBUG + trace("Validate ..."); + auto invalid = mod.validate(); + if(invalid != mod.end()) { - trace("Pass: ", p.name()); - p.apply(prog); - trace(prog); + auto index = std::distance(mod.begin(), invalid); + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + + std::to_string(index) + ": " + invalid->name()); + } + trace(); +#endif +} +void run_pass(program& prog, const pass& p, tracer trace) +{ + trace("Pass: ", p.name()); + p.apply(prog); + trace(prog); +} -#ifndef NDEBUG - trace("Validate ..."); - auto invalid = prog.validate(); - if(invalid != prog.end()) +struct module_pm : module_pass_manager +{ + module* mod; + program* prog; + tracer* t; + + module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr) + : mod(pmod), prog(pprog), t(pt) + { + } + + template + void trace(Ts&&... xs) const + { + assert(t); + (*t)(xs...); + } + + virtual module& get_module() override + { + assert(mod); + return *mod; + } + virtual module* create_module(const std::string& name) override + { + assert(prog); + return prog->create_module(name); + } + virtual void run_pass(const pass& p) override + { + assert(mod); + trace("Module: ", mod->name(), ", Pass: ", p.name()); + assert(mod->validate() == mod->end()); + p.apply(*this); + trace(*mod); + validate_pass(*mod, p, *t); + } +}; + +module& get_module(module_pass_manager& mpm) { return mpm.get_module(); } + +void run_passes(module& mod, const std::vector& passes, tracer trace) +{ + if(enabled(MIGRAPHX_TRACE_PASSES{})) + trace = tracer{std::cout}; + for(const auto& p : passes) + { + module_pm{&mod, nullptr, &trace}.run_pass(p); + } +} + +void run_passes(program& prog, const std::vector& passes, tracer trace) +{ + if(enabled(MIGRAPHX_TRACE_PASSES{})) + trace = tracer{std::cout}; + for(const auto& p : passes) + { + auto mods = prog.get_modules(); + for(const auto& mod : reverse(mods)) { - auto index = std::distance(prog.begin(), invalid); - MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + - std::to_string(index) + ": " + invalid->name()); + if(mod->bypass()) + continue; + module_pm{mod, &prog, &trace}.run_pass(p); } - trace(); -#endif + run_pass(prog, p, trace); } } diff --git a/src/permutation.cpp b/src/permutation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a77407e034a84f8d899bccc7dbeb65eaa4cd96a --- /dev/null +++ b/src/permutation.cpp @@ -0,0 +1,55 @@ + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +shape reorder_shape(const shape& s, const std::vector& permutation) +{ + return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)}; +} + +std::vector invert_permutation(const std::vector& permutation) +{ + return sort_permutation(permutation, std::less<>{}); +} + +std::vector find_permutation(const shape& s) +{ + std::vector result(s.lens().size()); + std::iota(result.begin(), result.end(), 0); + std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) { + return std::make_tuple(s.strides()[x], s.lens()[x]); + })); + return result; +} + +std::vector find_permutation(const std::vector& shapes) +{ + if(shapes.empty()) + return {}; + std::map, std::size_t> count; + for(auto&& s : shapes) + { + if(s.broadcasted()) + continue; + count[find_permutation(s)]++; + } + if(count.empty()) + { + std::vector r(shapes.front().lens().size()); + std::iota(r.begin(), r.end(), 0); + return r; + } + auto it = std::max_element( + count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; })); + assert(it != count.end()); + return it->first; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/preallocate_param.cpp b/src/preallocate_param.cpp new file mode 100755 index 0000000000000000000000000000000000000000..086c1f435cab7f20c949a535842d8b4c29a6a173 --- /dev/null +++ b/src/preallocate_param.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void preallocate_param::apply(module& m) const +{ + auto last = std::prev(m.end()); + for(auto ins : iterator_for(m)) + { + if(ins->name() != "@param") + continue; + if(param != any_cast(ins->get_operator()).parameter) + continue; + std::string id = m.name() + ":" + param; + auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id)); + m.replace_instruction(ins, r); + m.move_instruction(ins, m.end()); + } + m.remove_instructions(std::next(last), m.end()); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/process.cpp b/src/process.cpp new file mode 100755 index 0000000000000000000000000000000000000000..69fec5e9b62f4c737e9f68f719c23c91fd3389cc --- /dev/null +++ b/src/process.cpp @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE) + +std::function redirect_to(std::ostream& os) +{ + return [&](const char* x) { os << x; }; +} + +int exec(const std::string& cmd, const std::function& std_out) +{ + int ec = 0; + if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) + std::cout << cmd << std::endl; + auto closer = [&](FILE* stream) { + auto status = pclose(stream); + ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT + }; + { + // TODO: Use execve instead of popen + std::unique_ptr pipe(popen(cmd.c_str(), "r"), closer); // NOLINT + if(!pipe) + MIGRAPHX_THROW("popen() failed: " + cmd); + std::array buffer; + while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) + std_out(buffer.data()); + } + return ec; +} + +struct process_impl +{ + std::string command{}; + fs::path cwd{}; + + std::string get_command() const + { + std::string result; + if(not cwd.empty()) + result += "cd " + cwd.string() + "; "; + result += command; + return result; + } +}; + +process::process(const std::string& cmd) : impl(std::make_unique()) +{ + impl->command = cmd; +} + +process::process(process&&) noexcept = default; + +process& process::operator=(process rhs) +{ + std::swap(impl, rhs.impl); + return *this; +} + +process::~process() noexcept = default; + +process& process::cwd(const fs::path& p) +{ + impl->cwd = p; + return *this; +} + +void process::exec() +{ + auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout)); + if(ec != 0) + MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " + + std::to_string(ec)); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/program.cpp b/src/program.cpp index 1d5d5db02d9a6d3fbc68168e6d8501ed789be981..965f5057b917efd61f4102ccbdf3c881e20de396 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -6,87 +6,38 @@ #include #include #include -#include #include +#include +#include +#include +#include +#include +#include +#include #include #include #include +#include #include +#include +#include +#include + namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +using milliseconds = std::chrono::duration; + struct program_impl { - // A list is used to keep references to an instruction stable - std::list instructions; + // A map is used to keep references to modules of the program + std::unordered_map modules; context ctx; + std::string target_name; }; -const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } - -static void print_instruction(std::ostream& os, - instruction_ref ins, - const std::unordered_map& names) -{ - os << names.at(ins) << " = "; - - os << ins->get_operator(); - - if(ins->name() == "@literal") - { - if(ins->get_literal().get_shape().elements() > 10) - os << "{ ... }"; - else - os << "{" << ins->get_literal() << "}"; - } - - if(!ins->inputs().empty()) - { - char delim = '('; - for(auto&& arg : ins->inputs()) - { - os << delim << names.at(arg); - delim = ','; - } - os << ")"; - } - - os << " -> " << ins->get_shape(); -} - -template -static void print_program(const program& p, F print_func) -{ - std::unordered_map names; - int count = 0; - - for(auto ins : iterator_for(p)) - { - std::string var_name; - if(ins->name() == "@param") - { - var_name = any_cast(ins->get_operator()).parameter; - } - else - { - var_name = "@" + std::to_string(count); - count++; - } - names.emplace(ins, var_name); - - // TODO: Use all_of - for(auto&& arg : ins->inputs()) - { - assert(p.has_instruction(arg) && "Instruction not found"); - (void)arg; - } - - print_func(ins, names); - } -} - -program::program() : impl(std::make_unique()) {} +program::program() : impl(std::make_unique()) { this->create_module("main"); } program::program(program&&) noexcept = default; program::~program() noexcept = default; @@ -103,287 +54,204 @@ program& program::operator=(program p) void program::assign(const program& p) { - // clean the current program if(!impl) { impl = std::make_unique(); } - else if(!impl->instructions.empty()) + else if(!impl->modules.empty()) { - impl->instructions.clear(); + impl->modules.clear(); } - impl->ctx = p.impl->ctx; - std::unordered_map ins_map; - for(auto ins : iterator_for(p)) - { - instruction_ref copy_ins{}; - if(ins->name() == "@literal") - { - auto l = ins->get_literal(); - copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l}); - } - else if(ins->name() == "@param") - { - auto&& name = any_cast(ins->get_operator()).parameter; - auto s = ins->get_shape(); - copy_ins = impl->instructions.insert(impl->instructions.end(), - {builtin::param{name}, std::move(s), {}}); - } - else if(ins->name() == "@outline") - { - auto s = ins->get_shape(); - copy_ins = - impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}}); - } - else - { - // retrieve its mapped input - auto inputs = ins->inputs(); - // ensure all inputs have its corresponding copy instructions - assert(std::all_of( - inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; })); - std::vector copy_inputs(inputs.size()); - std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) { - return ins_map[i]; - }); - copy_ins = add_instruction(ins->get_operator(), copy_inputs); - } - - ins_map[ins] = copy_ins; - } -} - -instruction_ref program::add_instruction(const operation& op, std::vector args) -{ - return insert_instruction(impl->instructions.end(), op, std::move(args)); -} -instruction_ref program::insert_instruction(instruction_ref ins, - const operation& op, - std::vector args) -{ - assert(std::all_of( - args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) && - "Argument is not an exisiting instruction"); - assert(not starts_with(op.name(), "@")); - shape r = compute_shape(op, args); - auto result = impl->instructions.insert(ins, {op, r, std::move(args)}); - instruction::backreference(result); - assert(result->valid(begin())); - return result; -} + impl->ctx = p.impl->ctx; + impl->target_name = p.impl->target_name; + impl->modules = p.impl->modules; -instruction_ref program::replace_instruction(instruction_ref ins, - const operation& op, - std::vector args) -{ - assert(std::all_of( - args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) && - "Argument is not an exisiting instruction"); - assert(not starts_with(op.name(), "@")); - - shape r = compute_shape(op, args); - instruction::replace(ins, op, r, std::move(args)); - assert(ins->valid(begin())); - return ins; -} + // build a map from old ins to new ins + // Build a map from old module to new module + std::unordered_map mod_map; + std::transform( + impl->modules.begin(), + impl->modules.end(), + std::inserter(mod_map, mod_map.begin()), + [&](auto&& xp) { return std::make_pair(&p.impl->modules.at(xp.first), &xp.second); }); -instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep) -{ - assert(has_instruction(ins)); - assert(has_instruction(rep)); - assert(ins != rep); - - if(ins == std::prev(this->end())) + std::unordered_map ins_map; + for(auto&& pp : mod_map) { - return replace_instruction(ins, op::identity{}, rep); + auto old_ins = iterator_for(*pp.first); + auto new_ins = iterator_for(*pp.second); + std::transform(old_ins.begin(), + old_ins.end(), + new_ins.begin(), + std::inserter(ins_map, ins_map.begin()), + [](auto x, auto y) { return std::make_pair(x, y); }); } - // TODO: Should it be an error if the output is empty? - if(ins->outputs().empty()) - { - return rep; - } - // Make a copy of outputs which can be changed when calling replace_argument - auto outputs = ins->outputs(); - for(auto out : outputs) + // Update all references from all modules + for(auto&& mp : impl->modules) { - // TODO: Check for possible cycles - if(out != rep) - { - instruction::replace_argument(out, ins, rep); - } - assert(out->valid(begin())); + for(auto ins : iterator_for(mp.second)) + instruction::replace_refs(ins, ins_map, mod_map); } - // Replacement should not be dead code unless its the last instruction - assert(!rep->outputs().empty() or rep == std::prev(end())); - // Output of the original instruction should only be the replacement or empty - assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(), - ins->outputs().end(), - [&](auto i) { return i == rep; })); - assert(ins->valid(begin())); - assert(rep->valid(begin())); - return rep; -} - -instruction_ref program::remove_instruction(instruction_ref ins) -{ - assert(has_instruction(ins)); - assert(ins->outputs().empty()); - ins->clear_arguments(); - return impl->instructions.erase(ins); -} - -instruction_ref program::remove_instructions(instruction_ref first, instruction_ref last) -{ - if(first == last) - return first; - // TODO: Check every element - assert(has_instruction(first)); - std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); }); - assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); })); - return impl->instructions.erase(first, last); } -instruction_ref program::move_instruction(instruction_ref src, instruction_ref dst) -{ - impl->instructions.splice(dst, impl->instructions, src); - return src; -} - -instruction_ref program::add_literal(literal l) -{ - impl->instructions.emplace_front(std::move(l)); - return impl->instructions.begin(); -} - -instruction_ref program::add_outline(const shape& s) +shape program::get_parameter_shape(std::string name) const { - impl->instructions.push_front({builtin::outline{s}, s, {}}); - return impl->instructions.begin(); + const auto* mm = this->get_main_module(); + return mm->get_parameter_shape(std::move(name)); } -instruction_ref program::add_parameter(std::string name, shape s) +std::vector program::get_parameter_names() const { - assert(get_parameter_shape(name) == shape{}); - impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}}); - return impl->instructions.begin(); -} - -shape program::get_parameter_shape(std::string name) const -{ - auto ins = std::find_if( - impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { - if(x.name() == "@param") - { - return any_cast(x.get_operator()).parameter == name; - } - else - { - return false; - } - }); - if(ins != this->end()) - return ins->get_shape(); - else - return {}; + const auto* mm = this->get_main_module(); + return mm->get_parameter_names(); } instruction_ref program::get_parameter(std::string name) const { - auto ins = std::find_if( - impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { - if(x.name() == "@param") - { - return any_cast(x.get_operator()).parameter == name; - } - else - { - return false; - } - }); - if(ins != this->end()) - return ins; - else - return this->end(); + const auto* mm = this->get_main_module(); + return mm->get_parameter(std::move(name)); } std::unordered_map program::get_parameter_shapes() const { - std::unordered_map result; - for(auto&& ins : impl->instructions) - { - if(ins.name() == "@param") - { - auto&& name = any_cast(ins.get_operator()).parameter; - result[name] = ins.get_shape(); - } - } - return result; + const auto* mm = this->get_main_module(); + return mm->get_parameter_shapes(); } -bool program::has_instruction(instruction_ref ins) const +std::size_t program::size() const { return impl->modules.size(); } + +std::vector program::get_output_shapes() const { - return std::find_if( - impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { - return std::addressof(*ins) == std::addressof(x); - }) != impl->instructions.end(); + const auto* mm = this->get_main_module(); + return mm->get_output_shapes(); } -std::size_t program::size() const { return impl->instructions.size(); } -instruction_ref program::begin() const { return impl->instructions.begin(); } -instruction_ref program::end() const { return impl->instructions.end(); } - -shape program::get_shape() const { return impl->instructions.back().get_shape(); } - context& program::get_context() const { return impl->ctx; } instruction_ref program::validate() const { - return std::find_if(impl->instructions.begin(), - impl->instructions.end(), - [&](const instruction& i) { return !i.valid(impl->instructions.begin()); }); + const auto* mm = this->get_main_module(); + return mm->validate(); } +bool program::is_compiled() const { return not this->impl->target_name.empty(); } + void program::compile(const target& t, compile_options options) { - assert(this->validate() == impl->instructions.end()); - this->impl->ctx = t.get_context(); + assert(not this->is_compiled()); + this->impl->target_name = t.name(); + this->impl->ctx = t.get_context(); if(enabled(MIGRAPHX_TRACE_COMPILE{})) options.trace = tracer{std::cout}; + options.trace(*this); options.trace(); - run_passes(*this, t.get_passes(this->impl->ctx, options), options.trace); - auto invalid = this->validate(); - if(invalid != impl->instructions.end()) + + auto&& passes = t.get_passes(this->impl->ctx, options); + run_passes(*this, passes, options.trace); + + auto mods = this->get_modules(); + + // Validate and finalize + for(const auto& mod : reverse(mods)) { - auto index = std::distance(impl->instructions.begin(), invalid); - MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index)); + auto invalid = mod->validate(); + if(invalid != mod->end()) + { + MIGRAPHX_THROW("Invalid module " + mod->name() + " from compilation at instruction " + + std::to_string(std::distance(mod->begin(), invalid))); + } + auto dangling = mod->find_dangling_reference(); + if(dangling != mod->end()) + { + auto index = std::distance(mod->begin(), dangling); + MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " + + std::to_string(index)); + } + mod->finalize(this->impl->ctx); } - this->finalize(); } void program::finalize() { - for(auto ins : iterator_for(*this)) + auto* mm = this->get_main_module(); + mm->finalize(this->impl->ctx); +} + +template +std::string classify(T x) +{ + switch(std::fpclassify(x)) { - ins->finalize(this->impl->ctx); + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; } } +std::unordered_set classify_argument(const argument& a) +{ + std::unordered_set result; + a.visit( + [&](auto t) { + for(const auto& x : t) + result.insert(classify(x)); + }, + [&](const auto& xs) { + for(const auto& x : xs) + { + auto r = classify_argument(x); + result.insert(r.begin(), r.end()); + } + }); + return result; +} + +void preview_argument(std::ostream& os, const argument& a) +{ + a.visit( + [&](auto t) { + if(t.size() <= 10) + { + os << t; + } + else + { + os << to_string_range(t.begin(), t.begin() + 5); + os << ", ..., "; + os << to_string_range(t.end() - 5, t.end()); + } + }, + [&](const auto& xs) { + for(const auto& x : xs) + { + os << '{'; + preview_argument(os, x); + os << '}'; + } + }); +} + template -argument generic_eval(const program& p, - context& ctx, - std::unordered_map params, - F trace) -{ - assert(p.validate() == p.end()); - std::unordered_map results; - results.reserve(p.size() * 2); +std::vector generic_eval(const module* mod, + context& ctx, + std::unordered_map params, + std::unordered_map results, + F make_trace) +{ + assert(mod->validate() == mod->end()); + results.reserve(mod->size() * 2); std::vector values; values.reserve(16); - for(auto ins : iterator_for(p)) + auto trace = make_trace(mod); + for(auto ins : iterator_for(*mod)) { + assert(results.find(ins) == results.end()); const auto& name = ins->name(); if(name == "@literal") { @@ -407,6 +275,19 @@ argument generic_eval(const program& p, { results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; })); } + else if(name == "@return") + { + std::vector prog_outputs; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(prog_outputs), + [&](instruction_ref i) { + assert(results.find(i) != results.end()); + return results[i]; + }); + + return prog_outputs; + } else { values.resize(ins->inputs().size()); @@ -415,52 +296,280 @@ argument generic_eval(const program& p, assert(results.find(i) != results.end()); return results[i]; }); + + const auto& mod_args = ins->module_inputs(); + auto module_eval = [&](module_ref smod, + const std::unordered_map& inputs) { + auto ssctx = ctx; + return generic_eval(smod, ssctx, inputs, results, make_trace); + }; + results.emplace(ins, trace(ins, [&] { - return ins->get_operator().compute(ctx, ins->get_shape(), values); + return ins->normalized_operator().compute( + ctx, ins->get_shape(), values, mod_args, module_eval); })); } assert(results.find(ins) != results.end()); + assert(results.at(ins).get_shape() == ins->get_shape()); } - return results.at(std::prev(p.end())); + return {results.at(std::prev(mod->end()))}; } -argument program::eval(std::unordered_map params) const +template +std::vector generic_eval(const program& p, + context& ctx, + std::unordered_map params, + F make_trace) +{ + const module* mm = p.get_main_module(); + return generic_eval(mm, ctx, params, {}, make_trace); +} + +std::vector program::eval(parameter_map params) const { auto& ctx = this->impl->ctx; #ifndef NDEBUG - auto sctx = ctx; - auto check_context = [&](auto f) { - assert(is_shared(ctx, sctx)); - auto x = f(); - sctx = ctx; - return x; + auto with_check_context = [&](auto f) { + return [=, &ctx](auto&&) { + auto sctx = std::make_shared(ctx); + auto check_context = [=, &ctx](auto g) { + assert(is_shared(ctx, *sctx)); + auto x = g(); + *sctx = ctx; + return x; + }; + return [=](auto&&... xs) { return f(xs..., check_context); }; + }; }; #else - auto check_context = [](auto f) { return f(); }; + auto with_check_context = [](auto f) { + return [=](auto&&) { + return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); }; + }; + }; #endif auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); if(trace_level > 0) { - return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) { - ctx.finish(); - std::cout << "Run instruction: "; - this->debug_print(ins); - auto result = check_context(f); - ctx.finish(); - if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load") - std::cout << "Ouput: " << result << std::endl; - return result; + std::unordered_map ins_out; + // get instruction names + this->print([&](auto x, auto ins_names) { + std::stringstream ss; + instruction::print(ss, x, ins_names); + ins_out[x] = ss.str(); }); + + return generic_eval(*this, + ctx, + std::move(params), + with_check_context([&](auto& ins, auto f, auto&& check_context) { + ctx.finish(); + std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; + timer t{}; + auto result = check_context(f); + double t1 = t.record(); + ctx.finish(); + double t2 = t.record(); + std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl; + if(trace_level > 1 and ins->name().front() != '@' and + ins->name() != "load" and not result.empty()) + { + target tgt = make_target(this->impl->target_name); + auto buffer = tgt.copy_from(result); + if(trace_level == 2) + { + std::cout << "Output has " + << to_string_range(classify_argument(buffer)) + << std::endl; + std::cout << "Output: "; + preview_argument(std::cout, buffer); + std::cout << std::endl; + } + else + { + std::cout << "Output: " << buffer << std::endl; + } + } + return result; + })); } else { - return generic_eval( - *this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); }); + return generic_eval(*this, + ctx, + std::move(params), + with_check_context([&](auto&, auto f, auto&& check_context) { + return check_context(f); + })); } } +const int program_file_version = 5; + +value program::to_value() const +{ + value result; + result["version"] = program_file_version; + result["target"] = this->impl->target_name; + if(not this->impl->target_name.empty()) + result["context"] = this->impl->ctx.to_value(); + + value module_vals = value::object{}; + std::unordered_map names; + for(auto& mod : this->get_modules()) + { + value mod_val; + value nodes; + mod_val["name"] = mod->name(); + names = mod->print( + [&](auto ins, auto ins_names) { + value node; + node["output"] = ins_names.at(ins); + node["name"] = ins->name(); + node["shape"] = migraphx::to_value(ins->get_shape()); + node["normalized"] = ins->is_normalized(); + if(ins->name() == "@literal") + node["literal"] = migraphx::to_value(ins->get_literal()); + node["operator"] = ins->get_operator().to_value(); + std::vector inputs; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(inputs), + [&](auto i) { + assert(contains(ins_names, i)); + return ins_names.at(i); + }); + node["inputs"] = inputs; + auto module_args = ins->module_inputs(); + if(not module_args.empty()) + { + std::vector module_inputs; + std::transform(module_args.begin(), + module_args.end(), + std::back_inserter(module_inputs), + [&](auto mod_ref) { return mod_ref->name(); }); + node["module_inputs"] = module_inputs; + } + + nodes.push_back(node); + }, + names); + mod_val["nodes"] = nodes; + + module_vals[mod->name()] = mod_val; + } + + result["modules"] = module_vals; + + return result; +} + +static void mod_from_val(module_ref mod, + const value& v, + std::unordered_map& instructions, + const std::unordered_map& map_mods) +{ + const auto& module_val = v.at(mod->name()); + for(const value& node : module_val.at("nodes")) + { + instruction_ref output; + auto name = node.at("name").to(); + auto fields = node.at("operator"); + auto normalized = node.at("normalized").to(); + + if(name == "@param") + { + output = mod->add_parameter(fields["parameter"].to(), + migraphx::from_value(node.at("shape"))); + } + else if(name == "@literal") + { + output = mod->add_literal(migraphx::from_value(node.at("literal"))); + } + else + { + auto op = make_op(name, fields); + std::vector inputs; + std::transform(node.at("inputs").begin(), + node.at("inputs").end(), + std::back_inserter(inputs), + [&](const value& i) { + auto i_name = i.to(); + assert(contains(instructions, i_name)); + return instructions.at(i_name); + }); + + std::vector module_inputs; + if(node.contains("module_inputs")) + { + std::transform(node.at("module_inputs").begin(), + node.at("module_inputs").end(), + std::back_inserter(module_inputs), + [&](const value& i) { return map_mods.at(i.to()); }); + + for(auto& smod : module_inputs) + { + mod_from_val(smod, v, instructions, map_mods); + } + } + + if(name == "@return") + { + output = mod->add_return(inputs); + } + else if(module_inputs.empty()) + { + output = mod->add_instruction(op, inputs); + } + else + { + output = mod->add_instruction(op, inputs, module_inputs); + } + } + output->set_normalized(normalized); + instructions[node.at("output").to()] = output; + } +} + +void program::from_value(const value& v) +{ + auto version = v.at("version").to(); + if(version != program_file_version) + { + MIGRAPHX_THROW("Warning: Program version mismatch"); + } + + this->impl->target_name = v.at("target").to(); + if(not this->impl->target_name.empty()) + { + target t = make_target(this->impl->target_name); + this->impl->ctx = t.get_context(); + this->impl->ctx.from_value(v.at("context")); + } + + auto module_vals = v.at("modules"); + for(const auto& vv : module_vals) + { + const auto& name = vv.get_key(); + if(name == "main") + continue; + impl->modules.emplace(name, name); + } + std::unordered_map map_mods; + std::transform(impl->modules.begin(), + impl->modules.end(), + std::inserter(map_mods, map_mods.end()), + [&](auto&& pp) { return std::make_pair(pp.first, &pp.second); }); + + std::unordered_map map_insts; + auto* mm = get_main_module(); + mod_from_val(mm, module_vals, map_insts, map_mods); + + this->finalize(); +} + double common_average(const std::vector& v) { std::size_t n = v.size() / 4; @@ -468,10 +577,38 @@ double common_average(const std::vector& v) return total / std::distance(v.begin() + n, v.end() - n); } -void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const +std::string perf_group(const operation& op) +{ + auto attr = op.attributes(); + if(attr.contains("group")) + return attr.at("group").to(); + return op.name(); +} + +void program::mark(const parameter_map& params, marker&& m) { - using milliseconds = std::chrono::duration; - auto& ctx = this->impl->ctx; + auto& ctx = this->impl->ctx; + // Run once by itself + eval(params); + ctx.finish(); + // Start marking + m.mark_start(*this); + generic_eval(*this, ctx, params, always([&](auto ins, auto f) { + argument result; + m.mark_start(ins); + result = f(); + m.mark_stop(ins); + return result; + })); + m.mark_stop(*this); +} + +void program::perf_report(std::ostream& os, + std::size_t n, + parameter_map params, + std::size_t batch) const +{ + auto& ctx = this->impl->ctx; // Run once by itself eval(params); ctx.finish(); @@ -488,21 +625,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) std::sort(total_vec.begin(), total_vec.end()); std::unordered_map> ins_vec; // Fill the map - generic_eval(*this, ctx, params, [&](auto ins, auto) { + generic_eval(*this, ctx, params, always([&](auto ins, auto) { ins_vec[ins].reserve(n); - return argument{}; - }); + return argument{ins->get_shape(), nullptr}; + })); + // Run and time each instruction for(std::size_t i = 0; i < n; i++) { - generic_eval(*this, ctx, params, [&](auto ins, auto f) { + generic_eval(*this, ctx, params, always([&](auto ins, auto f) { argument result; ins_vec[ins].push_back(time([&] { result = f(); ctx.finish(); })); return result; - }); + })); } for(auto&& p : ins_vec) std::sort(p.second.begin(), p.second.end()); @@ -523,14 +661,20 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) for(auto&& p : ins_vec) { double avg = common_average(p.second); - op_times[p.first->name()] += avg; + op_times[perf_group(p.first->get_operator())] += avg; total_instruction_time += avg; } double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; - print_program(*this, [&](auto ins, const auto& names) { - print_instruction(std::cout, ins, names); + std::unordered_map names; + this->print(names, [&](auto ins, auto ins_names) { + instruction::print(std::cout, ins, ins_names); + + // skip return instruction + if(ins->name() == "@return") + return; + double avg = common_average(ins_vec[ins]); double percent = std::ceil(100.0 * avg / total_instruction_time); os << ": " << avg << "ms, " << percent << "%"; @@ -555,7 +699,8 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) os << std::endl; - os << "Rate: " << rate << "/sec" << std::endl; + os << "Batch size: " << batch << std::endl; + os << "Rate: " << rate * batch << "/sec" << std::endl; os << "Total time: " << total_time << "ms" << std::endl; os << "Total instructions time: " << total_instruction_time << "ms" << std::endl; os << "Overhead time: " << overhead_time << "ms" @@ -567,86 +712,231 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) void program::debug_print() const { std::cout << *this << std::endl; } void program::debug_print(instruction_ref ins) const { - if(ins == this->end()) + std::unordered_map names; + if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) { + return is_end(pp.second.end(), ins); + })) { std::cout << "End instruction" << std::endl; return; } - if(not has_instruction(ins)) + else if(std::none_of(this->impl->modules.begin(), + this->impl->modules.end(), + [&](const auto& pp) { return pp.second.has_instruction(ins); })) { std::cout << "Instruction not part of program" << std::endl; return; } + std::stringstream ss; - print_program(*this, [&](auto x, const auto& names) { + this->print(names, [&](auto x, auto ins_names) { if(x == ins) { - print_instruction(std::cout, x, names); + instruction::print(std::cout, x, ins_names); std::cout << std::endl; } }); } -void program::debug_print(const std::vector& inss) const + +void program::print( + std::unordered_map& names, + const std::function)>& + print_func) const { - for(auto ins : inss) - debug_print(ins); - std::cout << std::endl; + for(const auto& pp : this->impl->modules) + { + names = pp.second.print(print_func, names); + } } -static std::string enclose_name(const std::string& name) +void program::print( + const std::function)>& print_func) const { - return '"' + replace_string(name, "\"", "\\\"") + '"'; + std::unordered_map names; + this->print(names, print_func); } void program::print_graph(std::ostream& os, bool brief) const { - os << "digraph {" << std::endl; - os << "\trankdir=LR;" << std::endl; - print_program(*this, [&](auto ins, const auto& names) { - std::string label; - if(brief) - label = ins->name(); - else - label = to_string(ins->get_operator()); - os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]"; - os << ";" << std::endl; - if(!ins->inputs().empty()) + const auto* mm = this->get_main_module(); + mm->print_graph(os, brief); +} + +void program::print_cpp(std::ostream& os) const +{ + auto vec_modules = this->get_modules(); + std::unordered_map names; + for(auto& mod : vec_modules) + { + os << "module: \"" << mod->name() << "\"" << std::endl; + names = mod->print_cpp(os, names); + os << std::endl; + } +} + +void program::dry_run(std::unordered_map params) const +{ + auto& ctx = this->impl->ctx; + generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) { + return argument{ins->get_shape(), nullptr}; + })); +} + +void program::annotate(std::ostream& os, const std::function& a) const +{ + for(auto& pp : this->impl->modules) + { + std::cout << pp.first << ":" << std::endl; + pp.second.annotate(os, a); + } +} + +const module* program::get_module(const std::string& name) const { return &impl->modules.at(name); } + +module* program::create_module(const std::string& name) +{ + assert(not contains(impl->modules, name)); + auto r = impl->modules.emplace(name, name); + return &(r.first->second); +} + +module* program::get_module(const std::string& name) { return &impl->modules.at(name); } + +module* program::get_main_module() { return get_module("main"); } + +const module* program::get_main_module() const { return get_module("main"); } + +template +std::vector generic_get_modules(T* mm) +{ + std::vector vec_modules; + vec_modules.push_back(mm); + auto sub_modules = mm->get_sub_modules(); + vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end()); + return vec_modules; +} + +template +void generic_get_unused_modules(Map& m, const std::vector& mods, OutputIterator out) +{ + std::unordered_set used; + std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) { + return mod->name(); + }); + transform_if( + m.begin(), + m.end(), + out, + [&](auto&& pp) { return not contains(used, pp.first); }, + [](auto&& pp) { return &pp.second; }); +} + +std::vector program::get_modules() const +{ + auto result = generic_get_modules(this->get_main_module()); + generic_get_unused_modules(impl->modules, result, std::back_inserter(result)); + return result; +} + +std::vector program::get_modules() +{ + auto result = generic_get_modules(this->get_main_module()); + generic_get_unused_modules(impl->modules, result, std::back_inserter(result)); + return result; +} + +template +bool is_unused_module(Map& m, const std::vector& mods, const std::string& name) +{ + bool is_unused = false; + generic_get_unused_modules(m, mods, make_function_output_iterator([&](auto* mod) { + if(mod->name() == name) + is_unused = true; + })); + return is_unused; +} + +template +bool references_instruction(Map& m, const instruction& ins, const std::string& name) +{ + return std::any_of(m.begin(), m.end(), [&](auto&& p) { + if(p.first == name) + return false; + return std::any_of(p.second.begin(), p.second.end(), [&](auto&& i) { + return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto&& j) { + return std::addressof(*j) == std::addressof(ins); + }); + }); + }); +} + +void program::remove_module(const std::string& name) +{ + // cppcheck-suppress assertWithSideEffect + assert(is_unused_module(impl->modules, generic_get_modules(this->get_main_module()), name) && + "Module used in program"); + assert(std::none_of( + impl->modules.at(name).begin(), + impl->modules.at(name).end(), + [&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) && + "Instruction referenced in another module"); + + // if an instruction has an input out side of the current module, need to remove + // the instruction from its input's outputs + auto& mod = impl->modules.at(name); + for(auto ins : iterator_for(mod)) + { + auto inputs = ins->inputs(); + for(auto in : inputs) { - for(auto&& arg : ins->inputs()) + if(not mod.has_instruction(in)) { - os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins)); - if(not brief) - os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]"; - os << ";" << std::endl; + in->remove_output(ins); } } - }); - os << "}" << std::endl; + } + + impl->modules.erase(name); } -void program::dry_run(std::unordered_map params) const +void program::remove_unused_modules() { - auto& ctx = this->impl->ctx; - generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; }); + std::vector unused; + generic_get_unused_modules( + impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused)); + for(auto* m : unused) + this->remove_module(m->name()); } -void program::annotate(std::ostream& os, std::function a) const +program& program::sort() { - print_program(*this, [&](auto ins, const auto& names) { - print_instruction(os, ins, names); - a(ins); - os << std::endl; - }); + for(auto& pp : this->impl->modules) + { + pp.second.sort(); + } + + return *this; } bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } std::ostream& operator<<(std::ostream& os, const program& p) { - print_program(p, [&](auto ins, const auto& names) { - print_instruction(os, ins, names); + auto vec_modules = p.get_modules(); + std::unordered_map names; + for(auto& mod : vec_modules) + { + os << "module: \"" << mod->name() << "\"" << std::endl; + names = mod->print( + [&](auto ins, auto ins_names) { + instruction::print(os, ins, ins_names); + os << std::endl; + }, + names); os << std::endl; - }); + } + return os; } diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index ab3285b6d77ed1cd37b81f84c61616505e623875..c65283b1e5f3dd18c75129ee6dfb738c75ca0fa7 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace migraphx { @@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins) return false; } -void propagate_constant::apply(program& p) const +bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } + +void propagate_constant::apply(module& m) const { - for(auto i : iterator_for(p)) + std::unordered_set const_instrs; + auto last = std::prev(m.end()); + + // Find instructions that can be evaluated to a literal + for(auto i : iterator_for(m)) { - if(i->name() != "@literal") + if(is_const(i) and i != last) continue; - if(i->outputs().empty()) - continue; - fix([&](auto self, auto ins) { - std::unordered_set children(ins->outputs().begin(), - ins->outputs().end()); - for(auto child : children) - { - if(child->name() == "@literal" or skip_propogate(child)) - { - self(child); - continue; - } - auto r = child->eval(); - if(not r.empty()) - { - assert(r.get_shape() == child->get_shape()); - auto l = p.add_literal(r.get_shape(), r.data()); - self(p.replace_instruction(child, l)); - } - } - })(i); + + std::copy_if( + i->inputs().begin(), + i->inputs().end(), + std::inserter(const_instrs, const_instrs.begin()), + [&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; }); + } + + // Compute literals in parallel + std::vector const_instrs_vec{const_instrs.begin(), const_instrs.end()}; + std::vector literals(const_instrs_vec.size()); + par_for(const_instrs_vec.size(), 1, [&](const auto i) { + literals[i] = const_instrs_vec[i]->eval(); + }); + + // Replace instructions in m + for(size_t i = 0; i < const_instrs_vec.size(); i++) + { + if(not literals[i].empty()) + { + assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape()); + auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); + m.replace_instruction(const_instrs_vec[i], l); + } } } diff --git a/src/py/CMakeLists.txt b/src/py/CMakeLists.txt index 69f376c38a89866871c8c9edc08cabfc9153c8a5..dac747838320e680f1d4b9c754520352bf22778a 100644 --- a/src/py/CMakeLists.txt +++ b/src/py/CMakeLists.txt @@ -1,21 +1,14 @@ option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) if(MIGRAPHX_ENABLE_PYTHON) - find_program(DEFAULT_PYTHON_EXE python) - if(DEFAULT_PYTHON_EXE) - set(PYTHON_EXECUTABLE ${DEFAULT_PYTHON_EXE} CACHE PATH "Path to python executable") - endif() - find_package(pybind11 REQUIRED) - pybind11_add_module(migraphx_py migraphx_py.cpp) - set_target_properties(migraphx_py PROPERTIES - OUTPUT_NAME migraphx - C_VISIBILITY_PRESET hidden - CXX_VISIBILITY_PRESET hidden - ) - target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu) - if(MIGRAPHX_ENABLE_GPU) - target_link_libraries(migraphx_py PRIVATE migraphx_gpu) - target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU) - endif() - rocm_install_targets(TARGETS migraphx_py) + include(PythonModules) + + add_custom_target(migraphx_py) + + foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) + py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx) + target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) + rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION}) + add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION}) + endforeach() endif() diff --git a/src/py/backend/__init__.py b/src/py/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/py/backend/__init__.py @@ -0,0 +1 @@ + diff --git a/src/py/backend/backend.py b/src/py/backend/backend.py new file mode 100755 index 0000000000000000000000000000000000000000..a2271f1ed6cfb435cb9df34e16fec5a06654ceda --- /dev/null +++ b/src/py/backend/backend.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Advanced Micro Devices. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +""" +Implements ONNX's backend API. +""" +import sys +if sys.version_info < (3, 0): + sys.exit() + +from onnx import ModelProto +from onnx.checker import check_model +from onnx.backend.base import Backend +import migraphx +from onnx_migraphx.backend_rep import MIGraphXBackendRep + + +def get_device(): + return ("CPU", "GPU") + + +class MIGraphXBackend(Backend): + _device = "GPU" + _input_names = [] + _prog_string = "" + + @classmethod + def set_device(cls, device): + cls._device = device + """ + Implements + `ONNX's backend API `_ + with *ONNX Runtime*. + The backend is mostly used when you need to switch between + multiple runtimes with the same API. + `Importing models from ONNX to Caffe2 `_ + shows how to use *caffe2* as a backend for a converted model. + Note: This is not the official Python API. + """ # noqa: E501 + + @classmethod + def get_program(cls): + return cls._prog_string + + @classmethod + def is_compatible(cls, model, device=None, **kwargs): + """ + Return whether the model is compatible with the backend. + + :param model: unused + :param device: None to use the default device or a string (ex: `'CPU'`) + :return: boolean + """ + device = cls._device + return cls.supports_device(device) + + @classmethod + def supports_device(cls, device): + """ + Check whether the backend is compiled with particular device support. + In particular it's used in the testing suite. + """ + return device in get_device() + + @classmethod + def prepare(cls, model, device=None, **kwargs): + """ + Load the model and creates a :class:`migraphx.program` + ready to be used as a backend. + + :param model: ModelProto (returned by `onnx.load`), + string for a filename or bytes for a serialized model + :param device: requested device for the computation, + None means the default one which depends on + the compilation settings + :param kwargs: see :class:`onnxruntime.SessionOptions` + :return: :class:`migraphx.program` + """ + if isinstance(model, MIGraphXBackendRep): + return model + elif isinstance(model, migraphx.program): + return MIGraphXBackendRep(model, cls._input_names) + elif isinstance(model, (str, bytes)): + if device is not None and not cls.supports_device(device): + raise RuntimeError( + "Incompatible device expected '{0}', got '{1}'".format( + device, get_device())) + inf = migraphx.parse_onnx_buffer(model) + cls._prog_string = str("\nProgram =\n{}".format(inf)) + device = cls._device + cls._input_names = inf.get_parameter_names() + inf.compile(migraphx.get_target(device.lower())) + cls._prog_string = cls._prog_string + str( + "\nCompiled program =\n{}".format(inf)) + + return cls.prepare(inf, device, **kwargs) + else: + # type: ModelProto + check_model(model) + bin = model.SerializeToString() + return cls.prepare(bin, device, **kwargs) + + @classmethod + def run_model(cls, model, inputs, device=None, **kwargs): + """ + Compute the prediction. + + :param model: :class:`migraphx.program` returned + by function *prepare* + :param inputs: inputs + :param device: requested device for the computation, + None means the default one which depends on + the compilation settings + :param kwargs: see :class:`migraphx.program` + :return: predictions + """ + rep = cls.prepare(model, device, **kwargs) + return rep.run(inputs, **kwargs) + + @classmethod + def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): + ''' + This method is not implemented as it is much more efficient + to run a whole model than every node independently. + ''' + raise NotImplementedError( + "It is much more efficient to run a whole model than every node independently." + ) + + +is_compatible = MIGraphXBackend.is_compatible +prepare = MIGraphXBackend.prepare +run = MIGraphXBackend.run_model +supports_device = MIGraphXBackend.supports_device diff --git a/src/py/backend/backend_rep.py b/src/py/backend/backend_rep.py new file mode 100755 index 0000000000000000000000000000000000000000..64a5701dd186bc08e3cbed6078519be514e525df --- /dev/null +++ b/src/py/backend/backend_rep.py @@ -0,0 +1,54 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Advanced Micro Device Inc. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +""" +Implements ONNX's backend API. +""" +import sys +if sys.version_info < (3, 0): + sys.exit() + +import migraphx +from onnx.backend.base import BackendRep +import numpy as np +from typing import Any, Tuple + + +class MIGraphXBackendRep(BackendRep): + """ + Computes the prediction for a pipeline converted into + an :class:`onnxruntime.InferenceSession` node. + """ + def __init__(self, prog, input_names): + """ + :param session: :class:`migraphx.program` + """ + self._program = prog + self._input_names = input_names + + def run(self, inputs, **kwargs): # type: (Any, **Any) -> Tuple[Any, ...] + """ + Computes the prediction. + See :meth:`migraphx.program.run`. + """ + + if isinstance(inputs, list): + inps = {} + for i, name in enumerate(self._input_names): + inps[name] = migraphx.argument(inputs[i]) + mgx_outputs = self._program.run(inps) + outs = [] + for out in mgx_outputs: + outs.append(np.array(out)) + return outs + else: + inp = self._program.get_parameter_shapes().keys() + if len(inp) != 1: + raise RuntimeError("Model expect {0} inputs".format(len(inp))) + inps = {inp[0]: migraphx.argument(inputs)} + mgx_outputs = self._program.run(inps) + outs = [] + for out in mgx_outputs: + outs.append(np.array(out)) + return self._program.run(inps) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 91ceb3fa7a3e89366ff45ee67b4ace28d733f6a7..eba9c085c245bab99071b24be8e20c25c86a0a0e 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -1,76 +1,140 @@ #include #include +#include #include +#include +#include #include #include -#include +#include +#include #include #include #include #include +#include +#include +#include +#include #ifdef HAVE_GPU -#include #include #endif +using half = half_float::half; namespace py = pybind11; -template -struct throw_half -{ - F f; +#ifdef __clang__ +#define MIGRAPHX_PUSH_UNUSED_WARNING \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"") +#define MIGRAPHX_POP_WARNING _Pragma("clang diagnostic pop") +#else +#define MIGRAPHX_PUSH_UNUSED_WARNING +#define MIGRAPHX_POP_WARNING +#endif +#define MIGRAPHX_PYBIND11_MODULE(...) \ + MIGRAPHX_PUSH_UNUSED_WARNING \ + PYBIND11_MODULE(__VA_ARGS__) \ + MIGRAPHX_POP_WARNING - template - void operator()(A a) const +namespace migraphx { + +migraphx::value to_value(py::kwargs kwargs); +migraphx::value to_value(py::list lst); + +template +void visit_py(T x, F f) +{ + if(py::isinstance(x)) { - f(a); + f(to_value(x.template cast())); } - - void operator()(migraphx::shape::as) const + else if(py::isinstance(x)) + { + f(to_value(x.template cast())); + } + else if(py::isinstance(x)) + { + f(x.template cast()); + } + else if(py::isinstance(x)) { - throw std::runtime_error("Half not supported in python yet."); + f(x.template cast()); } + else if(py::isinstance(x)) + { + f(x.template cast()); + } + else if(py::isinstance(x)) + { + f(x.template cast()); + } + else + { + MIGRAPHX_THROW("VISIT_PY: Unsupported data type!"); + } +} - void operator()(migraphx::tensor_view) const +migraphx::value to_value(py::list lst) +{ + migraphx::value v = migraphx::value::array{}; + for(auto val : lst) { - throw std::runtime_error("Half not supported in python yet."); + visit_py(val, [&](auto py_val) { v.push_back(py_val); }); } -}; -template -struct skip_half + return v; +} + +migraphx::value to_value(py::kwargs kwargs) { - F f; + migraphx::value v = migraphx::value::object{}; - template - void operator()(A a) const + for(auto arg : kwargs) { - f(a); + auto&& key = py::str(arg.first); + auto&& val = arg.second; + visit_py(val, [&](auto py_val) { v[key] = py_val; }); } + return v; +} +} // namespace migraphx - void operator()(migraphx::shape::as) const {} +namespace pybind11 { +namespace detail { - void operator()(migraphx::tensor_view) const {} +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // following: https://docs.python.org/3/library/struct.html#format-characters + return "e"; + } + static constexpr auto name() { return _("half"); } }; +} // namespace detail +} // namespace pybind11 + template void visit_type(const migraphx::shape& s, F f) { - s.visit_type(throw_half{f}); + s.visit_type(f); } template void visit(const migraphx::raw_data& x, F f) { - x.visit(throw_half{f}); + x.visit(f); } template void visit_types(F f) { - migraphx::shape::visit_types(skip_half{f}); + migraphx::shape::visit_types(f); } template @@ -82,12 +146,26 @@ py::buffer_info to_buffer_info(T& x) strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); }); py::buffer_info b; visit_type(s, [&](auto as) { - b = py::buffer_info(x.data(), - as.size(), - py::format_descriptor::format(), - s.lens().size(), - s.lens(), - strides); + // migraphx use int8_t data to store bool type, we need to + // explicitly specify the data type as bool for python + if(s.type() == migraphx::shape::bool_type) + { + b = py::buffer_info(x.data(), + as.size(), + py::format_descriptor::format(), + s.lens().size(), + s.lens(), + strides); + } + else + { + b = py::buffer_info(x.data(), + as.size(), + py::format_descriptor::format(), + s.lens().size(), + s.lens(), + strides); + } }); return b; } @@ -97,34 +175,59 @@ migraphx::shape to_shape(const py::buffer_info& info) migraphx::shape::type_t t; std::size_t n = 0; visit_types([&](auto as) { - if(info.format == py::format_descriptor::format()) + if(info.format == py::format_descriptor::format() or + (info.format == "l" and py::format_descriptor::format() == "q") or + (info.format == "L" and py::format_descriptor::format() == "Q")) { t = as.type_enum(); n = sizeof(as()); } + else if(info.format == "?" and py::format_descriptor::format() == "b") + { + t = migraphx::shape::bool_type; + n = sizeof(bool); + } }); if(n == 0) { - MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format); + MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format); } auto strides = info.strides; std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { return n > 0 ? i / n : 0; }); - return migraphx::shape{t, info.shape, strides}; + + // scalar support + if(info.shape.empty()) + { + return migraphx::shape{t}; + } + else + { + return migraphx::shape{t, info.shape, strides}; + } } -PYBIND11_MODULE(migraphx, m) +MIGRAPHX_PYBIND11_MODULE(migraphx, m) { py::class_(m, "shape") - .def(py::init<>()) + .def(py::init([](py::kwargs kwargs) { + auto v = migraphx::to_value(kwargs); + auto t = migraphx::shape::parse_type(v.get("type", "float")); + auto lens = v.get("lens", {1}); + if(v.contains("strides")) + return migraphx::shape(t, lens, v.at("strides").to_vector()); + else + return migraphx::shape(t, lens); + })) .def("type", &migraphx::shape::type) .def("lens", &migraphx::shape::lens) .def("strides", &migraphx::shape::strides) .def("elements", &migraphx::shape::elements) .def("bytes", &migraphx::shape::bytes) + .def("type_string", &migraphx::shape::type_string) .def("type_size", &migraphx::shape::type_size) .def("packed", &migraphx::shape::packed) .def("transposed", &migraphx::shape::transposed) @@ -155,41 +258,183 @@ PYBIND11_MODULE(migraphx, m) py::class_(m, "target"); + py::class_(m, "instruction_ref"); + + py::class_>(m, "module") + .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; }) + .def( + "add_instruction", + [](migraphx::module& mm, + const migraphx::operation& op, + std::vector& args, + std::vector& mod_args) { + return mm.add_instruction(op, args, mod_args); + }, + py::arg("op"), + py::arg("args"), + py::arg("mod_args") = std::vector{}) + .def( + "add_literal", + [](migraphx::module& mm, py::buffer data) { + py::buffer_info info = data.request(); + auto literal_shape = to_shape(info); + return mm.add_literal(literal_shape, reinterpret_cast(info.ptr)); + }, + py::arg("data")) + .def( + "add_parameter", + [](migraphx::module& mm, const std::string& name, const migraphx::shape shape) { + return mm.add_parameter(name, shape); + }, + py::arg("name"), + py::arg("shape")) + .def( + "add_return", + [](migraphx::module& mm, std::vector& args) { + return mm.add_return(args); + }, + py::arg("args")) + .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); }); + py::class_(m, "program") - .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) + .def(py::init([]() { return migraphx::program(); })) + .def("get_parameter_names", &migraphx::program::get_parameter_names) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) - .def("get_shape", &migraphx::program::get_shape) - .def("compile", - [](migraphx::program& p, const migraphx::target& t, bool offload_copy) { - migraphx::compile_options options; - options.offload_copy = offload_copy; - p.compile(t, options); - }, - py::arg("t"), - py::arg("offload_copy") = true) - .def("run", &migraphx::program::eval) + .def("get_output_shapes", &migraphx::program::get_output_shapes) + .def( + "compile", + [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) { + migraphx::compile_options options; + options.offload_copy = offload_copy; + options.fast_math = fast_math; + p.compile(t, options); + }, + py::arg("t"), + py::arg("offload_copy") = true, + py::arg("fast_math") = true) + .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); }) + .def( + "create_module", + [](migraphx::program& p, const std::string& name) { return p.create_module(name); }, + py::arg("name")) + .def("run", + [](migraphx::program& p, py::dict params) { + migraphx::parameter_map pm; + for(auto x : params) + { + std::string key = x.first.cast(); + py::buffer b = x.second.cast(); + py::buffer_info info = b.request(); + pm[key] = migraphx::argument(to_shape(info), info.ptr); + } + return p.eval(pm); + }) + .def("sort", &migraphx::program::sort) + .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("__eq__", std::equal_to{}) .def("__ne__", std::not_equal_to{}) .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); - m.def("parse_tf", - &migraphx::parse_tf, - "Parse tf protobuf (default format is nhwc)", - py::arg("filename"), - py::arg("is_nhwc") = true); - m.def("parse_onnx", &migraphx::parse_onnx); + py::class_(m, "op") + .def(py::init([](const std::string& name, py::kwargs kwargs) { + migraphx::value v = migraphx::value::object{}; + if(kwargs) + { + v = migraphx::to_value(kwargs); + } + return migraphx::make_op(name, v); + })) - m.def("get_target", [](const std::string& name) -> migraphx::target { - if(name == "cpu") - return migraphx::cpu::target{}; -#ifdef HAVE_GPU - if(name == "gpu") - return migraphx::gpu::target{}; -#endif - throw std::runtime_error("Target not found: " + name); - }); + .def("name", &migraphx::operation::name); + + m.def( + "parse_tf", + [](const std::string& filename, + bool is_nhwc, + unsigned int batch_size, + std::unordered_map> map_input_dims, + std::vector output_names) { + return migraphx::parse_tf( + filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names}); + }, + "Parse tf protobuf (default format is nhwc)", + py::arg("filename"), + py::arg("is_nhwc") = true, + py::arg("batch_size") = 1, + py::arg("map_input_dims") = std::unordered_map>(), + py::arg("output_names") = std::vector()); + + m.def( + "parse_onnx", + [](const std::string& filename, + unsigned int default_dim_value, + std::unordered_map> map_input_dims, + bool skip_unknown_operators, + bool print_program_on_error, + int64_t max_loop_iterations) { + migraphx::onnx_options options; + options.default_dim_value = default_dim_value; + options.map_input_dims = map_input_dims; + options.skip_unknown_operators = skip_unknown_operators; + options.print_program_on_error = print_program_on_error; + options.max_loop_iterations = max_loop_iterations; + return migraphx::parse_onnx(filename, options); + }, + "Parse onnx file", + py::arg("filename"), + py::arg("default_dim_value") = 1, + py::arg("map_input_dims") = std::unordered_map>(), + py::arg("skip_unknown_operators") = false, + py::arg("print_program_on_error") = false, + py::arg("max_loop_iterations") = 10); + + m.def( + "parse_onnx_buffer", + [](const std::string& onnx_buffer, + unsigned int default_dim_value, + std::unordered_map> map_input_dims, + bool skip_unknown_operators, + bool print_program_on_error) { + migraphx::onnx_options options; + options.default_dim_value = default_dim_value; + options.map_input_dims = map_input_dims; + options.skip_unknown_operators = skip_unknown_operators; + options.print_program_on_error = print_program_on_error; + return migraphx::parse_onnx_buffer(onnx_buffer, options); + }, + "Parse onnx file", + py::arg("filename"), + py::arg("default_dim_value") = 1, + py::arg("map_input_dims") = std::unordered_map>(), + py::arg("skip_unknown_operators") = false, + py::arg("print_program_on_error") = false); + + m.def( + "load", + [](const std::string& name, const std::string& format) { + migraphx::file_options options; + options.format = format; + return migraphx::load(name, options); + }, + "Load MIGraphX program", + py::arg("filename"), + py::arg("format") = "msgpack"); + + m.def( + "save", + [](const migraphx::program& p, const std::string& name, const std::string& format) { + migraphx::file_options options; + options.format = format; + return migraphx::save(p, name, options); + }, + "Save MIGraphX program", + py::arg("p"), + py::arg("filename"), + py::arg("format") = "msgpack"); + m.def("get_target", &migraphx::make_target); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); + m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value")); m.def("quantize_fp16", &migraphx::quantize_fp16, py::arg("prog"), @@ -198,14 +443,14 @@ PYBIND11_MODULE(migraphx, m) &migraphx::quantize_int8, py::arg("prog"), py::arg("t"), - py::arg("calibration") = std::vector{}, + py::arg("calibration") = std::vector{}, py::arg("ins_names") = std::vector{"dot", "convolution"}); #ifdef HAVE_GPU m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false); m.def("from_gpu", &migraphx::gpu::from_gpu); - m.def("gpu_sync", &migraphx::gpu::gpu_sync); + m.def("gpu_sync", [] { migraphx::gpu::gpu_sync(); }); #endif #ifdef VERSION_INFO diff --git a/src/quantization.cpp b/src/quantization.cpp index 1fef91efd1eaa7847dc910af9e9b486f794f98a0..acb0ad71e4f789e0ec497d8f2eebc75ee1a48186 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -1,100 +1,28 @@ +#include +#include #include +#include +#include +#include +#include +#include +#include #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include #include #include -#include +#include +#include #include -#include -#include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) -instruction_ref insert_quant_ins(program& prog, - instruction_ref& ins, - shape::type_t type, - std::unordered_map& map_ins, - float scale = 1.0f, - float shift = 0.0f) -{ - if(map_ins.count(ins) > 0) - { - return map_ins[ins]; - } - - if(ins->name() == "undefined") - { - return ins; - } - - assert(ins->get_shape().type() == shape::float_type or - ins->get_shape().type() == shape::double_type or - ins->get_shape().type() == shape::int32_type or - ins->get_shape().type() == shape::half_type); - instruction_ref quant_ins{}; - auto insert_loc = std::next(ins); - if(type == shape::int8_type) - { - auto scaled_ins = ins; - if(scale != 1.0f) - { - auto float_ins = scaled_ins; - if(scaled_ins->get_shape().type() != shape::float_type) - { - float_ins = - prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins); - } - std::vector vec_scale(scaled_ins->get_shape().elements(), scale); - auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale)); - scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins); - } - - auto shifted_ins = scaled_ins; - if(shift != 0.0f) - { - auto float_ins = shifted_ins; - if(shifted_ins->get_shape().type() != shape::float_type) - { - float_ins = prog.insert_instruction( - insert_loc, op::convert{shape::float_type}, shifted_ins); - } - std::vector vec_shift(shifted_ins->get_shape().elements(), shift); - auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift)); - shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); - } - - auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins); - auto clipped_ins = - prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins); - quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins); - } - else - { - quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins); - } - - map_ins[ins] = quant_ins; - - return quant_ins; -} - // This function is to convert any instructions specified in the input // from double or float to float16 by inserting a convert operator. // For the conversion, there could be cases of overflowing, but it @@ -102,402 +30,29 @@ instruction_ref insert_quant_ins(program& prog, // truncate of the input to get the fp16. void quantize_fp16(program& prog, const std::vector& ins_names) { - std::unordered_map map_fp16; - for(auto ins : iterator_for(prog)) - { - // all indicates every instruction is converted - if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name()))) - { - continue; - } - - shape::type_t orig_type = ins->get_shape().type(); - // process all inputs, if input is a fp32 or fp64, convert it - // to a fp16 by adding a convert operator. - auto inputs = ins->inputs(); - std::vector converted_inputs; - for(auto input : inputs) - { - auto s = input->get_shape(); - if(s.type() == shape::float_type || s.type() == shape::double_type) - { - // if the input is a convert operator, uses its input - // as its current input - instruction_ref input_fp16{}; - if(input->name() == "convert" and - input->inputs().front()->get_shape().type() == shape::half_type) - { - input_fp16 = input->inputs().front(); - } - else - { - input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16); - } - converted_inputs.push_back(input_fp16); - } - else - { - converted_inputs.push_back(input); - } - } - - // no change for the input, go to the next instruction - if(inputs == converted_inputs) - { - continue; - } - - auto op = ins->get_operator(); - auto ins_shape = compute_shape(op, converted_inputs); - if(ins_shape.type() != orig_type) - { - // check the dead code case to avoid assert - bool output_empty = ins->outputs().empty(); - auto ins_orig_type = - prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins); - if(!output_empty) - { - prog.replace_instruction(ins, ins_orig_type); - } - } - - prog.replace_instruction(ins, op, converted_inputs); - } -} - -static void ins_quantize_int8(program& prog, - instruction_ref ins, - std::vector& converted_inputs, - const std::vector>& ins_quant_params) -{ - auto orig_type = ins->get_shape().type(); - auto inputs = ins->inputs(); - if(ins->name() == "dot") - { - auto dot_op = any_cast(ins->get_operator()); - float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first); - float new_beta = dot_op.beta; - // We need additional checking about the quant_alpha value. If - // abs(quant_alpha) > 50 (some tmp value set here), we can convert - // it to an integer as the new_alpha in the quant_dot - float threshold = 50.0f; - if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold) - { - int32_t quant_alpha = static_cast(std::round(new_alpha)); - int32_t quant_beta = static_cast(std::round(new_beta)); - if(shape::int32_type == orig_type) - { - prog.replace_instruction( - ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); - } - else - { - auto quant_dot = prog.insert_instruction( - ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); - prog.replace_instruction(ins, op::convert{orig_type}, quant_dot); - } - } - // either alpha or beta cannot be quantized because of too big - // relative rounding error - else - { - if(converted_inputs.size() == 3) - { - converted_inputs.pop_back(); - } - auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); - auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot); - auto c_shape = q_dot->get_shape(); - std::vector vec_alpha(c_shape.elements(), new_alpha); - auto l_alpha = - prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha)); - - if(inputs.size() == 3 and dot_op.beta != 0.0f) - { - auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); - std::vector vec_beta(c_shape.elements(), dot_op.beta); - auto l_beta = - prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta)); - instruction_ref beta_c{}; - if(orig_type != shape::float_type) - { - auto fp32_c = - prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back()); - beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c); - } - else - { - beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); - } - - if(orig_type == shape::float_type) - { - prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); - } - else - { - auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c); - prog.replace_instruction(ins, op::convert{orig_type}, f_res); - } - } - else - { - if(orig_type == shape::float_type) - { - prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot); - } - else - { - auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot); - prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab); - } - } - } - } - else if(ins->name() == "convolution") - { - // Current MIOpen convolution does not support alpha and beta, - // so we need a separate multiply to adjust the output - auto conv_op = any_cast(ins->get_operator()); - auto padding = conv_op.padding; - auto stride = conv_op.stride; - auto dilation = conv_op.dilation; - auto padding_mode = conv_op.padding_mode; - auto group = conv_op.group; - auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first); - - auto quant_conv = prog.insert_instruction( - ins, - op::quant_convolution{padding, stride, dilation, padding_mode, group}, - converted_inputs); - float threshold = 50.0f; - std::vector vec_factor(quant_conv->get_shape().elements(), adjust_factor); - if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold) - { - auto l_factor = prog.add_literal( - literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end())); - prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor); - } - // convert quant_conv output to float type, multiply the factor and - // conver back to original type - else - { - auto float_conv = - prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv); - auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor)); - if(orig_type == shape::float_type) - { - prog.replace_instruction(ins, op::mul{}, l_factor, float_conv); - } - else - { - auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv); - prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv); - } - } - } - else - { - MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name()); - } -} - -// int8 quantization is different from fp16 since int8 can only handle value -// -128 ~ 127. To convert the float or double to int8, we need a scale and -// a shift, then the convert can be done as v_int8 = fp * scale + shift. -// To simplify the changes, we consider shift as 0.0f for now. -void quantize_int8_impl(program& prog, - const std::vector>& quant_params, - const std::vector& ins_names) -{ - if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{})) - { - for(std::size_t i = 0; i < quant_params.size(); ++i) - { - auto param = quant_params.at(i); - std::cout << "ins_index = " << i << ", scale = " << param.first - << ", shift = " << param.second << std::endl; - } - std::cout << std::endl; - } - - // For now, we only support the int8 quantization of gemm and convolution - std::set op_names = {"convolution", "dot"}; - std::set input_ins_names(ins_names.begin(), ins_names.end()); - if(!std::includes( - op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end())) - { - MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); - } - - std::size_t quant_param_index = 0; - std::unordered_map map_quant_ins; - std::unordered_map map_ins_index; - for(auto ins : iterator_for(prog)) - { - if(not contains(ins_names, ins->name())) - { - continue; - } - - // for the dot operator, there could be 2 or 3 input arguments - // if the 3rd argument is available, convert it to an int32. - std::vector converted_inputs; - - // process all inputs, if input is a fp32 or fp64, convert it - // to a int8 type by adding a convert operator and replace - // the operator with the corresponding int8 version - auto inputs = ins->inputs(); - std::vector> ins_quant_params; - for(auto input : inputs) - { - // calculate the index of each instruction to be quantized - std::size_t ins_index = - (map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++; - map_ins_index[input] = ins_index; - - auto param = quant_params[map_ins_index[input]]; - ins_quant_params.push_back(param); - - // In general, the target_type is int8, but for the dot - // operation, if it has 3 inputs, then the last one should - // be converted to int32_type - shape::type_t quant_type = shape::int8_type; - if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back())) - { - quant_type = shape::int32_type; - } - - auto s = input->get_shape(); - if((s.type() == shape::float_type or s.type() == shape::double_type or - s.type() == shape::half_type or s.type() == shape::int32_type) and - s.type() != quant_type) - { - // if the input is a convert operator, uses its input - // as its current input - instruction_ref quant_input{}; - if(input->name() == "convert" and - input->inputs().front()->get_shape().type() == quant_type) - { - quant_input = input->inputs().front(); - // the scale in this case is not used, so tune the scale - // to 1.0f for this parameter - ins_quant_params.back() = std::pair(1.0f, 0.0f); - } - else - { - quant_input = insert_quant_ins( - prog, input, quant_type, map_quant_ins, param.first, param.second); - } - converted_inputs.push_back(quant_input); - } - else - { - converted_inputs.push_back(input); - } - } - - // no change for the input, go to the next instruction - if(inputs == converted_inputs) - { - continue; - } - - ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params); - } - - if(quant_param_index != quant_params.size()) - { - MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match"); - } + run_passes(prog, + {quantize_fp16_pass{ins_names}, + eliminate_common_subexpression{}, + dead_code_elimination{}, + simplify_reshapes{}, + dead_code_elimination{}, + simplify_qdq{}, + dead_code_elimination{}}); } void quantize_int8(program& prog, const target& t, - const std::vector& calibration, + const std::vector& calibration, const std::vector& ins_names) { - // insert capture operator - auto cap_prog = prog; - auto int8_quant_params = capture_arguments(cap_prog, t, ins_names); - - // use the calibration data to compute the quantization scale - cap_prog.compile(t); - - // use all calibration data to run the program to calculate the - // quantization scale and shift - for(auto&& arg : calibration) - { - program::parameter_map m; - for(auto&& x : cap_prog.get_parameter_shapes()) - { - if(arg.count(x.first) > 0) - { - assert(x.second == arg.at(x.first).get_shape()); - m[x.first] = t.copy_to(arg.at(x.first)); - } - else - { - m[x.first] = t.allocate(x.second); - } - } - cap_prog.eval(m); - } - - quantize_int8_impl(prog, *int8_quant_params, ins_names); -} - -// For the input of each input argument, we need to insert a -// capture operator to compute the scale and shift -std::size_t capture_arguments(program& prog, - const std::vector& ins_names, - const std::function)>& func) -{ - - size_t num_quant_params = 0; - // the int8 quantization only support dot and convolution - std::set op_names = {"dot", "convolution"}; + std::set op_names = {"convolution", "dot"}; std::set input_ins_names(ins_names.begin(), ins_names.end()); if(!std::includes( op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end())) { - MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported"); - } - - std::unordered_map ins_map; - for(auto ins : iterator_for(prog)) - { - if(not contains(ins_names, ins->name())) - { - continue; - } - - auto inputs = ins->inputs(); - std::vector new_args; - for(auto input : inputs) - { - instruction_ref new_ins{}; - if(ins_map.count(input) > 0) - { - new_ins = ins_map[input]; - } - else - { - new_ins = prog.insert_instruction( - std::next(input), op::capture{num_quant_params++, func}, input); - ins_map[input] = new_ins; - } - new_args.push_back(new_ins); - } - instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args); + MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); } - return num_quant_params; -} - -std::shared_ptr>> -capture_arguments_impl(program& prog, const target& t, const std::vector& ins_names) -{ std::shared_ptr>> int8_quant_params = std::make_shared>>(); std::shared_ptr> max_abs_vals = std::make_shared>(); @@ -505,7 +60,6 @@ capture_arguments_impl(program& prog, const target& t, const std::vector args) { std::pair param_pair{64.0f, 0.0f}; - // scale and shift is need for only int8 type, and we do not // consider shift, so set shift to 0 std::vector vec_val; @@ -528,12 +82,56 @@ capture_arguments_impl(program& prog, const target& t, const std::vectorat(ins_index) = param_pair; }; - auto num_params = capture_arguments(prog, ins_names, calc_quant_params); + // pass to add capture argument op + std::size_t param_num = 0; + run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, ¶m_num}}); + int8_quant_params->resize(param_num, std::pair(64.0f, 0.0f)); + max_abs_vals->resize(param_num, 0.0f); + + // use the calibration data to compute the quantization scale + auto capture_prog = prog; + capture_prog.compile(t); + + // use all calibration data to run the program to calculate the + // quantization scale and shift + for(auto&& arg : calibration) + { + parameter_map m; + for(auto&& x : capture_prog.get_parameter_shapes()) + { + if(arg.count(x.first) > 0) + { + assert(x.second == arg.at(x.first).get_shape()); + m[x.first] = t.copy_to(arg.at(x.first)); + } + else + { + m[x.first] = t.allocate(x.second); + } + } + capture_prog.eval(m); + } - int8_quant_params->resize(num_params, std::pair(64.0f, 0.0f)); - max_abs_vals->resize(num_params, 0.0f); + // print the quantization parameters in only the main module + if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{})) + { + for(std::size_t i = 0; i < int8_quant_params->size(); ++i) + { + auto param = int8_quant_params->at(i); + std::cout << "ins_index = " << i << ", scale = " << param.first + << ", shift = " << param.second << std::endl; + } + std::cout << std::endl; + } - return int8_quant_params; + run_passes(prog, + {quantize_int8_pass{ins_names, *int8_quant_params}, + eliminate_common_subexpression{}, + dead_code_elimination{}, + simplify_reshapes{}, + dead_code_elimination{}, + simplify_qdq{}, + dead_code_elimination{}}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/quantize_fp16.cpp b/src/quantize_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..772b60191ac690602cb1a601de914019bf7e4522 --- /dev/null +++ b/src/quantize_fp16.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +static void quantize_module(module& m, const std::vector& ins_names) +{ + for(auto ins : iterator_for(m)) + { + // instructions are not in the set to be quantized + if(not(contains(ins_names, ins->name()) or contains(ins_names, "all"))) + continue; + + // skip return and convert instructions + if(contains({"@return", "convert"}, ins->name())) + continue; + + if(ins->inputs().empty()) + continue; + + auto mod_inputs = ins->module_inputs(); + auto s = ins->get_shape(); + // Convert back to original type before quantizing the inputs + if(mod_inputs.empty()) + { + auto r = m.insert_instruction( + std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins); + m.replace_instruction(ins, r); + } + + // Convert each of the inputs that are floating point to fp16 + auto inputs = ins->inputs(); + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { + auto input_type = input->get_shape().type(); + if(input_type != shape::float_type and input_type != shape::double_type) + return input; + return m.insert_instruction( + ins, make_op("convert", {{"target_type", shape::half_type}}), input); + }); + + // Replace inputs + m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs); + } +} + +void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); } + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/quantize_int8.cpp b/src/quantize_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..37f2348d87b5e97d7c44d31ea4a2345447056cfa --- /dev/null +++ b/src/quantize_int8.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) + +static std::vector& get_quantizable_type() +{ + static std::vector quantable_types = { + shape::float_type, shape::double_type, shape::half_type}; + return quantable_types; +} + +void quantize_int8_pass::apply(module& m) const // NOLINT +{ + const auto& quantizable_types = get_quantizable_type(); + for(auto ins : iterator_for(m)) + { + if(ins->name() != "capture") + continue; + + auto op_val = ins->get_operator().to_value(); + assert(op_val.contains("ins_index")); + + auto param_index = op_val.at("ins_index").to(); + auto param = quant_params[param_index]; + + auto input = ins->inputs().front(); + auto s = input->get_shape(); + if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type) + { + auto zero_point = m.add_literal(static_cast(param.second)); + auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first})); + const auto& lens = s.lens(); + scale = + m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale); + zero_point = m.insert_instruction( + ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point); + auto q_in = + m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point); + auto dq_in = + m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point); + m.replace_instruction(ins, dq_in); + } + } +} + +void capture_arguments_pass::apply(module& m) const // NOLINT +{ + assert(param_index != nullptr); + for(auto ins : iterator_for(m)) + { + if(not contains(ins_names, ins->name())) + { + continue; + } + + auto inputs = ins->inputs(); + std::vector new_args; + for(auto input : inputs) + { + auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input); + new_args.push_back(new_in); + } + m.replace_instruction(ins, ins->get_operator(), new_args); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/reduce_dims.cpp b/src/reduce_dims.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eac19a661cdf63a72d2841dd9a4969fb2195c476 --- /dev/null +++ b/src/reduce_dims.cpp @@ -0,0 +1,127 @@ +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +bool reduce_dim(std::vector& shapes, std::size_t n) +{ + std::vector new_lens; + for(const auto& s : shapes) + { + assert(n < s.lens().size()); + if((n + 1) >= s.lens().size()) + return false; + auto astride = s.strides()[n]; + auto alen = s.lens()[n]; + auto bstride = s.strides()[n + 1]; + auto blen = s.lens()[n + 1]; + + if(astride == bstride * blen or alen == 1) + new_lens.push_back(alen * blen); + } + if(new_lens.size() != shapes.size()) + return false; + std::size_t i = 0; + for(auto& s : shapes) + { + auto lens = s.lens(); + auto strides = s.strides(); + lens.erase(lens.begin() + n); + strides.erase(strides.begin() + n); + lens[n] = new_lens[i]; + s = shape{s.type(), lens, strides}; + i++; + } + return true; +} + +void reduce_dim1(std::vector& shapes) +{ + if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) { + return s.lens().size() < 2 or s.lens().back() != 1; + })) + return; + for(auto& s : shapes) + { + auto lens = s.lens(); + auto strides = s.strides(); + lens.pop_back(); + strides.pop_back(); + s = shape{s.type(), lens, strides}; + } +} + +std::size_t reduce_dim_all(std::vector& shapes, std::size_t n) +{ + while(reduce_dim(shapes, n) and n < shapes.size()) {} + return n + 1; +} +void reduce_dim_all(std::vector& shapes) +{ + std::size_t n = 0; + while(n < shapes.front().lens().size() - 1) + n = reduce_dim_all(shapes, n); + reduce_dim1(shapes); +} + +std::vector base_lens(const std::vector& shapes) +{ + return std::accumulate( + shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) { + std::vector result; + const auto* x = &s.lens(); + const auto* y = &lens; + if(x->size() > y->size()) + std::swap(x, y); + std::transform( + x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) { + return std::max(a, b); + }); + return result; + }); +} + +shape mask_shape(const shape& s, const std::vector& lens) +{ + assert(s.lens().size() == lens.size()); + std::vector rstrides(lens.size()); + std::size_t stride = 1; + for(std::size_t i = lens.size() - 1; i < lens.size(); i--) + { + if(lens[i] == s.lens()[i]) + { + rstrides[i] = stride; + stride *= lens[i]; + } + else if(lens[i] != 1 and s.lens()[i] != 1) + { + return shape{}; + } + } + return shape{s.type(), lens, rstrides}; +} + +std::vector reduce_dims(const std::vector& shapes) +{ + if(shapes.empty()) + return {}; + auto result = shapes; + auto base = base_lens(shapes); + for(auto&& s : shapes) + { + if(s.lens().size() != base.size()) + return shapes; + if(s.lens() == base) + continue; + auto mshape = mask_shape(s, base); + if(mshape.lens().size() != base.size()) + return shapes; + result.push_back(mshape); + } + reduce_dim_all(result); + result.erase(result.begin() + shapes.size(), result.end()); + return result; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/register_op.cpp b/src/register_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ebaa20001ae0cc130bd50f06e29a8e469eb9255b --- /dev/null +++ b/src/register_op.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::unordered_map& op_map() +{ + static std::unordered_map m; // NOLINT + return m; +} +void register_op(const operation& op) { op_map()[op.name()] = op; } +operation load_op(const std::string& name) +{ + return at(op_map(), name, "Operator not found: " + name); +} + +bool has_op(const std::string& name) { return op_map().count(name) == 1; } + +std::vector get_operators() +{ + std::vector result; + std::transform(op_map().begin(), op_map().end(), std::back_inserter(result), [&](auto&& p) { + return p.first; + }); + std::sort(result.begin(), result.end()); + return result; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/register_target.cpp b/src/register_target.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b059674eb3caa165fbc2a89a5020ce2c4c6a135c --- /dev/null +++ b/src/register_target.cpp @@ -0,0 +1,37 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::unordered_map& target_map() +{ + static std::unordered_map m; // NOLINT + return m; +} + +void register_target(const target& t) { target_map()[t.name()] = t; } + +target make_target(const std::string& name) +{ + const auto it = target_map().find(name); + if(it == target_map().end()) + { + MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported"); + } + return it->second; +} + +std::vector get_targets() +{ + std::vector result; + std::transform(target_map().begin(), + target_map().end(), + std::back_inserter(result), + [&](auto&& p) { return p.first; }); + std::sort(result.begin(), result.end()); + return result; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/replace_allocate.cpp b/src/replace_allocate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d79a8d9279ab04c38cdf37bba34eae943fab70d2 --- /dev/null +++ b/src/replace_allocate.cpp @@ -0,0 +1,101 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::unordered_map create_output_names(const module& mod) +{ + std::unordered_map mod_output_names{}; + auto last = std::prev(mod.end()); + if(last->name() == "@return") + { + const auto& prog_outputs = last->inputs(); + std::vector outputs_alias(prog_outputs.size()); + + std::transform(prog_outputs.begin(), + prog_outputs.end(), + outputs_alias.begin(), + [](const auto& i) { return instruction::get_output_alias(i); }); + + std::size_t index = 0; + for(auto ins : outputs_alias) + { + mod_output_names[ins] = mod.name() + ":#output_" + std::to_string(index++); + } + } + else + { + auto ins = instruction::get_output_alias(last); + mod_output_names[ins] = "output"; + } + return mod_output_names; +} + +void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model) +{ + std::vector inputs = ins->inputs(); + std::vector mod_args = ins->module_inputs(); + + std::map name_shapes; + for(const auto& smod : mod_args) + { + auto ps = smod->get_parameter_shapes(); + name_shapes.insert(ps.begin(), ps.end()); + } + + for(auto& pn : name_shapes) + { + const auto& s = pn.second; + instruction_ref output{}; + output = mod.insert_instruction(ins, model.allocate(s)); + inputs.push_back(output); + } + + mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args); +} + +void replace_allocate::apply(module& m) const +{ + auto mod_output_names = create_output_names(m); + bool main_offload_copy = m.name() == "main" ? this->offload_copy : false; + for(auto ins : iterator_for(m)) + { + auto op = ins->get_operator(); + auto op_name = op.name(); + + // check if allocations from submodules need to be inserted + // for now, only the "if" operator is affected + if(op_name == "if") + { + insert_submod_allocations(ins, m, model); + continue; + } + if(op_name != "allocate") + continue; + + auto s = ins->get_shape(); + + if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins)) + { + + auto out_param = m.add_parameter(mod_output_names[ins], s); + m.replace_instruction(ins, out_param); + continue; + } + + m.replace_instruction( + ins, + m.insert_instruction(ins, + make_op(model.name(), migraphx::value{{"shape", to_value(s)}}))); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/rewrite_batchnorm.cpp b/src/rewrite_batchnorm.cpp index c76fce0b4c78a2ddef9dac95c4bf4779140fdc2d..a613e955a39fe12537bc08a177ee5fbda3553398 100644 --- a/src/rewrite_batchnorm.cpp +++ b/src/rewrite_batchnorm.cpp @@ -1,20 +1,22 @@ #include #include #include -#include +#include #include #include #include #include #include +#include + #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void rewrite_batchnorm::apply(program& p) const +void rewrite_batchnorm::apply(module& m) const { - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(m)) { if(ins->name() != "batch_norm_inference") continue; @@ -26,7 +28,8 @@ void rewrite_batchnorm::apply(program& p) const if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); })) continue; - auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}}; + std::vector lens = ins->inputs()[1]->get_shape().lens(); + shape s{ins->get_shape().type(), lens}; // Get epsilon auto bn_op = any_cast(ins->get_operator()); auto epsilon = bn_op.epsilon; @@ -43,13 +46,13 @@ void rewrite_batchnorm::apply(program& p) const }); auto broadcast = op::broadcast{1, ins->get_shape().lens()}; - auto a_ins = p.add_literal({a.get_shape(), a.data()}); - auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins); - auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast); - auto b_ins = p.add_literal({b.get_shape(), b.data()}); - auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins); - auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast); - p.replace_instruction(ins, add); + auto a_ins = m.add_literal({a.get_shape(), a.data()}); + auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins); + auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast); + auto b_ins = m.add_literal({b.get_shape(), b.data()}); + auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins); + auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast); + m.replace_instruction(ins, add); } } diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 5e0b663a20e9d6da64d139f638dac59a3d3933af..7084f6881f93eaaf34d7321f01b0def76db3e0fe 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -4,39 +4,54 @@ #include #include #include +#include +#include + #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void rewrite_pooling::apply(program& prog) const +void rewrite_pooling::apply(module& m) const { - for(auto ins : iterator_for(prog)) + for(auto ins : iterator_for(m)) { if(ins->name() != "pooling") continue; - if(ins->get_shape().lens().size() != 4) - continue; if(ins->inputs().empty()) continue; auto&& s = ins->inputs().front()->get_shape(); if(not s.standard()) continue; auto&& op = any_cast(ins->get_operator()); - if(op.mode != "average") - continue; - if(op.padding[0] != 0 and op.padding[1] != 0) + if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) continue; - if(op.stride[0] != 1 and op.stride[1] != 1) + if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; })) continue; - if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1]) + auto lens = s.lens(); + if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) continue; std::int64_t n = s.lens()[0]; std::int64_t c = s.lens()[1]; - auto reshape = - prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front()); - auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape); - prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling); + auto reshape = m.insert_instruction( + ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front()); + instruction_ref pooling{}; + + // average pooling + if(op.mode == op::pooling_mode::average) + { + pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); + } + // max pooling + else + { + pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); + } + + std::vector rsp_lens(lens.size(), 1); + rsp_lens[0] = n; + rsp_lens[1] = c; + m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling); } } diff --git a/src/rewrite_quantization.cpp b/src/rewrite_quantization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d409b88d438da55a07f6d5c899f74b918ba5e919 --- /dev/null +++ b/src/rewrite_quantization.cpp @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void apply_quantizelinear(module& m, instruction_ref ins) +{ + assert(ins->name() == "quantizelinear"); + auto x = ins->inputs()[0]; + auto y_scale = ins->inputs()[1]; + + if(x->get_shape().type() != y_scale->get_shape().type()) + { + x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x); + } + auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); + auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); + + if(ins->inputs().size() == 3) + { + auto zero_point = m.insert_instruction( + ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]); + add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); + } + + int64_t max_quant = 0; + int64_t min_quant = 0; + ins->get_shape().visit_type([&](auto qt) { + max_quant = qt.max(); + min_quant = qt.min(); + }); + auto s = add_zero_point->get_shape(); + std::vector min_data(s.elements(), min_quant); + std::vector max_data(s.elements(), max_quant); + auto min_arg = m.add_literal(literal(s, min_data)); + auto max_arg = m.add_literal(literal(s, max_data)); + + auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg); + m.replace_instruction( + ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); +} + +void apply_dequantizelinear(module& m, instruction_ref ins) +{ + assert(ins->name() == "dequantizelinear"); + auto x = m.insert_instruction( + ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]); + auto x_scale = ins->inputs()[1]; + + if(ins->inputs().size() == 3) + { + auto x_zero_point = m.insert_instruction( + ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]); + x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point); + } + + m.replace_instruction(ins, make_op("mul"), x, x_scale); +} + +void rewrite_quantization::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + if(ins->name() == "quantizelinear") + { + apply_quantizelinear(m, ins); + } + + else if(ins->name() == "dequantizelinear") + { + apply_dequantizelinear(m, ins); + } + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/rewrite_rnn.cpp b/src/rewrite_rnn.cpp index 32f0769871666ff55ba7cf1769cc950e70a6f416..437399c1b575f0ba3042089f3303d6f03f470013 100644 --- a/src/rewrite_rnn.cpp +++ b/src/rewrite_rnn.cpp @@ -9,45 +9,54 @@ #include #include #include -#include #include #include #include #include #include +#include +#include +#include +#include +#include + #include #include +#include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -void rewrite_rnn::apply(program& prog) const +void rewrite_rnn::apply(module& m) const { - for(auto ins : iterator_for(prog)) + for(auto ins : iterator_for(m)) { if(ins->name() == "rnn") { - apply_vanilla_rnn(prog, ins); + apply_vanilla_rnn(m, ins); } else if(ins->name() == "gru") { - apply_gru(prog, ins); + apply_gru(m, ins); } else if(ins->name() == "lstm") { - apply_lstm(prog, ins); + apply_lstm(m, ins); } } } -void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const { assert(ins->name() == "rnn"); // could be 3 to 6 inputs, but the parse_rnn function will // append undefined operators to make 6 arguments when parsing // an onnx file. Another case is user can have num of arguments - // when writing their program. + // when writing their module. auto args = ins->inputs(); shape seq_shape = args[0]->get_shape(); @@ -59,25 +68,41 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const auto actv_funcs = vanilla_rnn_actv_funcs(ins); auto rnn_op = any_cast(ins->get_operator()); - op::rnn_direction dicrt = rnn_op.direction; + op::rnn_direction dirct = rnn_op.direction; + + // process sequence length + instruction_ref seq_lens = m.end(); + if((args.size() >= 5) && args[4]->name() != "undefined") + { + seq_lens = args[4]; + } + + bool variable_seq_len = is_variable_seq_lens(m, seq_lens); + instruction_ref last_output{}; - if(dicrt == op::rnn_direction::bidirectional) + if(dirct == op::rnn_direction::bidirectional) { // input weight matrix - auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); - auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); + auto w_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]); + auto w_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]); // hidden state weight matrix - auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); - auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); + auto r_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]); + auto r_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); // process bias - instruction_ref bias_forward = prog.end(); - instruction_ref bias_reverse = prog.end(); + instruction_ref bias_forward = m.end(); + instruction_ref bias_reverse = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { - bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); - bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); + bias_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); + bias_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]); } // process intial hidden state, it could be the 6th argument @@ -86,57 +111,62 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const instruction_ref ih_reverse{}; if(args.size() == 6 && args[5]->name() != "undefined") { - ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); - ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); + ih_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); + ih_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]); } else { - ih_forward = prog.add_literal(migraphx::literal{ih_shape, data}); - ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data}); - } - - auto ret_forward = vanilla_rnn_cell(true, - prog, - ins, - args[0], - w_forward, - r_forward, - bias_forward, - ih_forward, - actv_funcs.at(0)); - auto ret_reverse = vanilla_rnn_cell(false, - prog, - ins, - args[0], - w_reverse, - r_reverse, - bias_reverse, - ih_reverse, - actv_funcs.at(1)); - - auto concat_output = - prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); + ih_forward = m.add_literal(migraphx::literal{ih_shape, data}); + ih_reverse = m.add_literal(migraphx::literal{ih_shape, data}); + } + + auto ret_forward = + vanilla_rnn_cell(true, + m, + ins, + {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward}, + actv_funcs.at(0)); + + if(variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); + } + + auto ret_reverse = + vanilla_rnn_cell(false, + m, + ins, + {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse}, + actv_funcs.at(1)); + + auto concat_output = m.insert_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]); + last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output); // The following logic is to ensure the last instruction rewritten from // rnn operator is a concat instruction // sequence len is 1 - if(ret_forward[0] == prog.end()) + if(ret_forward[0] == m.end()) { - prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); + m.replace_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]); } else { - ret_forward[0] = - prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); - ret_reverse[0] = - prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); - prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); + ret_forward[0] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]); + ret_reverse[0] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]); + m.replace_instruction( + ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]}); } } else { - bool is_forward = (dicrt == op::rnn_direction::forward); + bool is_forward = (dirct == op::rnn_direction::forward); // input weight matrix auto w = args[1]; @@ -144,7 +174,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const auto r = args[2]; // process bias and initial hidden state - instruction_ref bias = prog.end(); + instruction_ref bias = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { bias = args[3]; @@ -158,104 +188,110 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const } else { - ih = prog.add_literal(migraphx::literal{ih_shape, data}); + ih = m.add_literal(migraphx::literal{ih_shape, data}); } - auto ret = - vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0)); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); + if(!is_forward and variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); + } + + auto ret = vanilla_rnn_cell( + is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0)); + last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]); // following logic is to ensure the last instruction is a // concat instruction // sequence len is 1 - if(ret[0] == prog.end()) + if(ret[0] == m.end()) { - prog.replace_instruction(ins, op::concat{0}, ret[1]); + m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]); } else { auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg1 = is_forward ? ret[1] : ret[0]; - prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); + m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1); } } - // search its output to find if there are rnn_last_output operator - // while loop to handle case of multiple rnn_last_output operators - auto last_output_it = ins->outputs().begin(); - while(last_output_it != ins->outputs().end()) - { - last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) { - return i->name() == "rnn_last_output"; - }); - - if(last_output_it != ins->outputs().end()) - { - prog.replace_instruction(*last_output_it, last_output); - last_output_it++; - } - } + // in case of all sequences are of the same lengths and shorter than the + // max sequence length, need to pad 0's at the end for output hidden states + ins = pad_hidden_states(m, args[0], seq_lens, ins); + replace_last_hs_output(m, ins, seq_lens, last_output, dirct); } std::vector rewrite_rnn::vanilla_rnn_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, - instruction_ref input, - instruction_ref w, - instruction_ref r, - instruction_ref bias, - instruction_ref ih, + std::vector inputs, operation& actv_func) const { + assert(inputs.size() == 6); + auto seq = inputs.at(0); + auto w = inputs.at(1); + auto r = inputs.at(2); + auto bias = inputs.at(3); + auto seq_lens = inputs.at(4); + auto ih = inputs.at(5); + // squeeze and transpose w std::vector perm{1, 0}; - auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); - auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw); + auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); + auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw); // squeeze and transpose r - auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); - auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr); + auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); + auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr); // initial hidden state - auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); + auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih_lens = sih->get_shape().lens(); // bias instruction_ref bb{}; - if(bias != prog.end()) + if(bias != m.end()) { long hs = static_cast(r->get_shape().lens()[2]); - auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); - auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); - auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); - auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb); - bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb); + auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias); + auto wb = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias); + auto rb = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias); + auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb); + bb = m.insert_instruction( + ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb); } - instruction_ref hidden_out = prog.end(); + instruction_ref hidden_out = m.end(); instruction_ref last_out{}; - last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); - std::size_t seq_len = input->get_shape().lens()[0]; - for(std::size_t i = 0; i < seq_len; i++) + last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih); + long seq_len = get_seq_len(m, seq, seq_lens); + for(long i = 0; i < seq_len; i++) { long seq_index = is_forward ? i : (seq_len - 1 - i); - auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input); - xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); - auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw); - auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr); - if(bias != prog.end()) + auto xt = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}), + seq); + auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt); + xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt); + auto xt_wi = m.insert_instruction(ins, make_op("dot"), xt, tran_sw); + auto ht_ri = m.insert_instruction(ins, make_op("dot"), sih, tran_sr); + if(bias != m.end()) { - xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb); + xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb); } - auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); + auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri); // apply activation function - auto ht = prog.insert_instruction(ins, actv_func, xt_ht); + auto ht = m.insert_instruction(ins, actv_func, xt_ht); sih = ht; // add the dimensions of sequence length (axis 0 for sequence length, // axis 1 for num_directions - last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht); + last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht); // concatenation for the last last_out is performed in the apply() // function to ensure the last instruction is concat, then we have @@ -264,17 +300,17 @@ std::vector rewrite_rnn::vanilla_rnn_cell(bool is_forward, { if(is_forward) { - hidden_out = - (seq_index == 0) - ? last_out - : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out); + hidden_out = (seq_index == 0) + ? last_out + : m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out); } else { - hidden_out = - (seq_index == seq_len - 1) - ? last_out - : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out); + hidden_out = (seq_index == seq_len - 1) + ? last_out + : m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out); } } } @@ -294,7 +330,7 @@ std::vector rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) if(rnn_op.actv_funcs.empty()) { // default is tanh - return {op::tanh{}, op::tanh{}}; + return {make_op("tanh"), make_op("tanh")}; } else if(rnn_op.actv_funcs.size() == 1) { @@ -310,7 +346,7 @@ std::vector rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) if(rnn_op.actv_funcs.empty()) { // default is tanh - return {op::tanh{}}; + return {make_op("tanh")}; } else { @@ -319,7 +355,8 @@ std::vector rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) } } -void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const { assert(ins->name() == "gru"); const auto actv_funcs = gru_actv_funcs(ins); @@ -337,25 +374,41 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const std::vector data(ih_shape.elements(), 0.0); auto gru_op = any_cast(ins->get_operator()); - op::rnn_direction dicrt = gru_op.direction; + op::rnn_direction dirct = gru_op.direction; + + // process sequence length + instruction_ref seq_lens = m.end(); + if((args.size() >= 5) && args[4]->name() != "undefined") + { + seq_lens = args[4]; + } + + bool variable_seq_len = is_variable_seq_lens(m, seq_lens); + instruction_ref last_output{}; - if(dicrt == op::rnn_direction::bidirectional) + if(dirct == op::rnn_direction::bidirectional) { // w weight matrix - auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); - auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); + auto w_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]); + auto w_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]); // r weight matrix - auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); - auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); + auto r_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]); + auto r_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); // bias - instruction_ref bias_forward = prog.end(); - instruction_ref bias_reverse = prog.end(); + instruction_ref bias_forward = m.end(); + instruction_ref bias_reverse = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { - bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); - bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); + bias_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); + bias_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]); } // intial hidden state @@ -363,59 +416,71 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const instruction_ref ih_reverse{}; if(args.size() == 6 && args[5]->name() != "undefined") { - ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); - ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); + ih_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); + ih_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]); } else { - ih_forward = prog.add_literal(migraphx::literal{ih_shape, data}); - ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data}); + ih_forward = m.add_literal(migraphx::literal{ih_shape, data}); + ih_reverse = m.add_literal(migraphx::literal{ih_shape, data}); } - auto ret_forward = gru_cell(true, - prog, - ins, - {args[0], w_forward, r_forward, bias_forward, ih_forward}, - gru_op.linear_before_reset, - actv_funcs.at(0), - actv_funcs.at(1)); + auto ret_forward = + gru_cell(true, + m, + ins, + {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward}, + gru_op.linear_before_reset, + actv_funcs.at(0), + actv_funcs.at(1)); - auto ret_reverse = gru_cell(false, - prog, - ins, - {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse}, - gru_op.linear_before_reset, - actv_funcs.at(2), - actv_funcs.at(3)); + if(variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); + } + + auto ret_reverse = + gru_cell(false, + m, + ins, + {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse}, + gru_op.linear_before_reset, + actv_funcs.at(2), + actv_funcs.at(3)); - auto concat_output = - prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); + auto concat_output = m.insert_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]); + last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output); // The following logic is to ensure the last instruction rewritten // from gru operator is a concat - if(ret_forward[0] == prog.end()) + if(ret_forward[0] == m.end()) { - prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); + m.replace_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]); } else { - ret_forward[0] = - prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); - ret_reverse[0] = - prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); - prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); + ret_forward[0] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]); + ret_reverse[0] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]); + m.replace_instruction( + ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]}); } } else { - bool is_forward = (dicrt == op::rnn_direction::forward); + bool is_forward = (dirct == op::rnn_direction::forward); // weight matrix auto w = args[1]; auto r = args[2]; // bias - instruction_ref bias = prog.end(); + instruction_ref bias = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { bias = args[3]; @@ -429,167 +494,189 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const } else { - ih = prog.add_literal(migraphx::literal{ih_shape, data}); + ih = m.add_literal(migraphx::literal{ih_shape, data}); + } + + if(!is_forward and variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); } auto ret = gru_cell(is_forward, - prog, + m, ins, - {args[0], w, r, bias, ih}, + {args[0], w, r, bias, seq_lens, ih}, gru_op.linear_before_reset, actv_funcs.at(0), actv_funcs.at(1)); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); + last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]); - if(ret[0] == prog.end()) + if(ret[0] == m.end()) { - prog.replace_instruction(ins, op::concat{0}, ret[1]); + m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]); } else { auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg1 = is_forward ? ret[1] : ret[0]; - prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); + m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1); } } - // replace the corresponding rnn_last_output instruction - // with the last_output, if rnn_last_output exists - // while loop to handle case of multiple rnn_last_output operators - auto last_output_it = ins->outputs().begin(); - while(last_output_it != ins->outputs().end()) - { - last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) { - return i->name() == "rnn_last_output"; - }); - - if(last_output_it != ins->outputs().end()) - { - prog.replace_instruction(*last_output_it, last_output); - last_output_it++; - } - } + // in case of all sequences are of the same lengths and shorter than the + // max sequence length, need to pad 0's at the end for output hidden states + ins = pad_hidden_states(m, args[0], seq_lens, ins); + replace_last_hs_output(m, ins, seq_lens, last_output, dirct); } +// NOLINTNEXTLINE(readability-function-cognitive-complexity) std::vector rewrite_rnn::gru_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, std::vector inputs, int linear_before_reset, const operation& actv_func1, const operation& actv_func2) const { - assert(inputs.size() == 5); - auto seq = inputs.at(0); - auto w = inputs.at(1); - auto r = inputs.at(2); - auto bias = inputs.at(3); - auto ih = inputs.at(4); - - instruction_ref hidden_states = prog.end(); + assert(inputs.size() == 6); + auto seq = inputs.at(0); + auto w = inputs.at(1); + auto r = inputs.at(2); + auto bias = inputs.at(3); + auto seq_lens = inputs.at(4); + auto ih = inputs.at(5); + + instruction_ref hidden_states = m.end(); instruction_ref last_output{}; migraphx::shape seq_shape = seq->get_shape(); migraphx::shape r_shape = r->get_shape(); - long seq_len = static_cast(seq_shape.lens()[0]); - long hs = static_cast(r_shape.lens()[2]); + long hs = r_shape.lens()[2]; - migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]}); - std::vector data(s.elements(), 1.0f); - auto l1 = prog.add_literal(migraphx::literal{s, data}); + migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]}); + std::vector data(ss.elements(), 1.0f); + auto l1 = m.add_literal(migraphx::literal{ss, data}); // w matrix squeeze to 2-dim and do a transpose std::vector perm{1, 0}; - auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); - auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw); + auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); + auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw); // r slide to two part, zr and h - auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); - auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr); - auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr); + auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); + auto rzr = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr); + auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr); - auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr); - auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh); + auto rh = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr); + auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh); // initial states - auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); + auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); size_t bs = ih->get_shape().lens()[1]; // bias instruction_ref bwb{}; instruction_ref brb_zr{}; instruction_ref brb_h{}; - if(bias != prog.end()) + if(bias != m.end()) { - auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); - auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias); - bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast(3 * hs)}}, wb); - - auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias); - auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); - brb_zr = prog.insert_instruction( - ins, op::broadcast{1, {bs, static_cast(2 * hs)}}, rb_zr); - brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast(hs)}}, rb_h); + auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias); + auto wb = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias); + bwb = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast(3 * hs)}}}), + wb); + + auto rb_zr = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}), + sbias); + auto rb_h = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}), + sbias); + brb_zr = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast(2 * hs)}}}), + rb_zr); + brb_h = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast(hs)}}}), + rb_h); } + long seq_len = get_seq_len(m, seq, seq_lens); for(long i = 0; i < seq_len; i++) { long seq_index = is_forward ? i : (seq_len - 1 - i); - auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); - xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); - - auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw); - auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr); - if(bias != prog.end()) + auto xt = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}), + seq); + auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt); + xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt); + + auto xt_w = m.insert_instruction(ins, make_op("dot"), xt, tw); + auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr); + if(bias != m.end()) { - xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb); - ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr); + xt_w = m.insert_instruction(ins, make_op("add"), xt_w, bwb); + ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr); } - auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w); - auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w); - auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w); + auto xw_z = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w); + auto xw_r = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w); + auto xw_h = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w); - auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr); - auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr); + auto hr_z = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr); + auto hr_r = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr); - auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z); - auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z); + auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z); + auto zt = m.insert_instruction(ins, actv_func1, xw_hr_z); - auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r); - auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r); + auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r); + auto rt = m.insert_instruction(ins, actv_func1, xw_hr_r); instruction_ref hr_h{}; if(linear_before_reset == 0) { // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) - auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); - hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh); - if(bias != prog.end()) + auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih); + hr_h = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh); + if(bias != m.end()) { - hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h); + hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h); } } else { // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) - auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh); - if(bias != prog.end()) + auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh); + if(bias != m.end()) { - ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h); + ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h); } - hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); + hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh); } - auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h); - auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h); + auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h); + auto ht = m.insert_instruction(ins, actv_func2, xw_hr_h); // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 - auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); - auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht); - auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih); - sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1); - last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); + auto one_minus_zt = m.insert_instruction(ins, make_op("sub"), l1, zt); + auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht); + auto zt_ht1 = m.insert_instruction(ins, make_op("mul"), zt, sih); + sih = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1); + last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih); if(i < seq_len - 1) { @@ -598,14 +685,16 @@ std::vector rewrite_rnn::gru_cell(bool is_forward, hidden_states = (seq_index == 0) ? last_output - : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output); + : m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output); } else { hidden_states = (seq_index == seq_len - 1) ? last_output - : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states); + : m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states); } } } @@ -623,7 +712,7 @@ std::vector rewrite_rnn::gru_actv_funcs(instruction_ref ins) const if(gru_op.direction == op::rnn_direction::bidirectional) { if(gru_op.actv_funcs.empty()) - return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}}; + return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")}; else if(gru_op.actv_funcs.size() == 1) return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0), @@ -645,7 +734,7 @@ std::vector rewrite_rnn::gru_actv_funcs(instruction_ref ins) const else { if(gru_op.actv_funcs.empty()) - return {op::sigmoid{}, op::tanh{}}; + return {make_op("sigmoid"), make_op("tanh")}; else if(gru_op.actv_funcs.size() == 1) return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)}; else @@ -654,7 +743,8 @@ std::vector rewrite_rnn::gru_actv_funcs(instruction_ref ins) const } // for lstm operators -void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const { assert(ins->name() == "lstm"); auto args = ins->inputs(); @@ -672,26 +762,43 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const auto lstm_op = any_cast(ins->get_operator()); op::rnn_direction dirct = lstm_op.direction; - instruction_ref last_output{}; + // process sequence length + instruction_ref seq_lens = m.end(); + if((args.size() >= 5) && args[4]->name() != "undefined") + { + seq_lens = args[4]; + } + + bool variable_seq_len = is_variable_seq_lens(m, seq_lens); + + instruction_ref last_hs_output{}; instruction_ref last_cell_output{}; + instruction_ref hidden_state{}; + instruction_ref cell_outputs{}; if(dirct == op::rnn_direction::bidirectional) { // input weight matrix // input weight matrix - auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); - auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); + auto w_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]); + auto w_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]); // hidden state weight matrix - auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); - auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); + auto r_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]); + auto r_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); // process bias - instruction_ref bias_forward = prog.end(); - instruction_ref bias_reverse = prog.end(); + instruction_ref bias_forward = m.end(); + instruction_ref bias_reverse = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { - bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); - bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); + bias_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); + bias_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]); } // process intial hidden state, it is the 6th argument @@ -699,13 +806,15 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const instruction_ref ih_reverse{}; if(args.size() >= 6 && args[5]->name() != "undefined") { - ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); - ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); + ih_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); + ih_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]); } else { - ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); - ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); } // process initial cell value @@ -713,63 +822,94 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const instruction_ref ic_reverse{}; if(args.size() >= 7 && args[6]->name() != "undefined") { - ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]); - ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]); + ic_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]); + ic_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]); } else { - ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); - ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); } // process weight of the peephole - instruction_ref pph_forward = prog.end(); - instruction_ref pph_reverse = prog.end(); + instruction_ref pph_forward = m.end(); + instruction_ref pph_reverse = m.end(); if(args.size() == 8 && args[7]->name() != "undefined") { - pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); - pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); + pph_forward = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]); + pph_reverse = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]); } - auto ret_forward = lstm_cell( - true, - prog, - ins, - {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward}, - actv_funcs.at(0), - actv_funcs.at(1), - actv_funcs.at(2)); - - auto ret_reverse = lstm_cell( - false, - prog, - ins, - {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse}, - actv_funcs.at(3), - actv_funcs.at(4), - actv_funcs.at(5)); - - auto concat_output = - prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); - - // last cell output + auto ret_forward = lstm_cell(true, + m, + ins, + {args[0], + w_forward, + r_forward, + bias_forward, + seq_lens, + ih_forward, + ic_forward, + pph_forward}, + actv_funcs.at(0), + actv_funcs.at(1), + actv_funcs.at(2)); + + if(variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); + } + auto ret_reverse = lstm_cell(false, + m, + ins, + {args[0], + w_reverse, + r_reverse, + bias_reverse, + seq_lens, + ih_reverse, + ic_reverse, + pph_reverse}, + actv_funcs.at(3), + actv_funcs.at(4), + actv_funcs.at(5)); + + auto concat_hs_output = m.insert_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]); + auto concat_cell_output = m.insert_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]); + last_hs_output = + m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output); last_cell_output = - prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]); + m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output); // the following logic is to ensure the last instruction is a concat - if(ret_forward[0] == prog.end()) + if(ret_forward[0] == m.end()) { - prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); + cell_outputs = concat_cell_output; } else { - ret_forward[0] = - prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); - ret_reverse[0] = - prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); - prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); + ret_forward[1] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]); + ret_reverse[1] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]); + + ret_forward[3] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]); + ret_reverse[3] = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]); + cell_outputs = m.insert_instruction( + ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]); } + + hidden_state = m.replace_instruction( + ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]}); } else { @@ -779,7 +919,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const auto r = args[2]; // bias - instruction_ref bias = prog.end(); + instruction_ref bias = m.end(); if(args.size() >= 4 && args[3]->name() != "undefined") { bias = args[3]; @@ -793,7 +933,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const } else { - ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); } // initial cell value @@ -804,74 +944,65 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const } else { - ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data}); + ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data}); } // process weight of the peephole - instruction_ref pph = prog.end(); + instruction_ref pph = m.end(); if(args.size() == 8 && args[7]->name() != "undefined") { pph = args[7]; } + if(!is_forward and variable_seq_len) + { + args[0] = + m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens); + } auto ret = lstm_cell(is_forward, - prog, + m, ins, - {args[0], w, r, bias, ih, ic, pph}, + {args[0], w, r, bias, seq_lens, ih, ic, pph}, actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)); - last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); - last_cell_output = ret[2]; - if(ret[0] == prog.end()) + last_hs_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]); + last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]); + + if(ret[0] == m.end()) { - prog.replace_instruction(ins, op::concat{0}, ret[1]); + cell_outputs = ret[3]; + hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]); } else { + auto concat_cell_arg0 = is_forward ? ret[2] : ret[3]; + auto concat_cell_arg1 = is_forward ? ret[3] : ret[2]; + cell_outputs = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1); + auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg1 = is_forward ? ret[1] : ret[0]; - prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); + hidden_state = m.replace_instruction( + ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1); } } - // replace the corresponding lstm_last_output instruction - // with the last_output, and the lstm_last_cell_output with - // the last_cell_output. The while loop is to handle the case - // of multiple lstm_last_output and lstm_last_cell_output - // operators - auto last_output_it = ins->outputs().begin(); - while(last_output_it != ins->outputs().end()) - { - last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) { - return i->name() == "rnn_last_output"; - }); - - if(last_output_it != ins->outputs().end()) - { - prog.replace_instruction(*last_output_it, last_output); - last_output_it++; - } - } + // in case of all sequences are of the same lengths and shorter than the + // max sequence length, need to pad 0's at the end for output hidden states + hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state); - auto last_cell_output_it = ins->outputs().begin(); - while(last_cell_output_it != ins->outputs().end()) - { - last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) { - return i->name() == "lstm_last_cell_output"; - }); + // replace last hidden states with corresponding instructions + ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct); - if(last_cell_output_it != ins->outputs().end()) - { - prog.replace_instruction(*last_cell_output_it, last_cell_output); - last_cell_output_it++; - } - } + // replace last cell outputs with corresponding instructions + replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct); } +// NOLINTNEXTLINE(readability-function-cognitive-complexity) std::vector rewrite_rnn::lstm_cell(bool is_forward, - program& prog, + module& m, instruction_ref ins, std::vector inputs, const operation& actv_func1, @@ -879,146 +1010,175 @@ std::vector rewrite_rnn::lstm_cell(bool is_forward, const operation& actv_func3) const { // must have 7 args in the input vector - assert(inputs.size() == 7); - auto seq = inputs.at(0); - auto w = inputs.at(1); - auto r = inputs.at(2); - auto bias = inputs.at(3); - auto ih = inputs.at(4); - auto ic = inputs.at(5); - auto pph = inputs.at(6); - - instruction_ref hidden_states = prog.end(); - instruction_ref last_output{}; + assert(inputs.size() == 8); + auto seq = inputs.at(0); + auto w = inputs.at(1); + auto r = inputs.at(2); + auto bias = inputs.at(3); + auto seq_lens = inputs.at(4); + auto ih = inputs.at(5); + auto ic = inputs.at(6); + auto pph = inputs.at(7); + + instruction_ref hidden_states = m.end(); + instruction_ref cell_outputs = m.end(); + + instruction_ref last_hs_output{}; instruction_ref last_cell_output{}; - migraphx::shape seq_shape = seq->get_shape(); - migraphx::shape r_shape = r->get_shape(); - long seq_len = static_cast(seq_shape.lens()[0]); - long hs = static_cast(r_shape.lens()[2]); - auto bs = ih->get_shape().lens()[1]; + migraphx::shape r_shape = r->get_shape(); + long hs = r_shape.lens()[2]; + auto bs = ih->get_shape().lens()[1]; std::vector perm{1, 0}; // w matrix, squeeze and transpose - auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); - auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw); + auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); + auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw); // r matrix, squeeze and transpose - auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); - auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr); + auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); + auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr); // initial hidden state - auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); + auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); // initial cell state - auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic); + auto sic = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic); auto ic_lens = sic->get_shape().lens(); // bias instruction_ref wrb{}; - if(bias != prog.end()) + if(bias != m.end()) { - auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); - auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias); - auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias); - auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb); + auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias); + auto ub_wb = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias); + auto ub_rb = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}), + sbias); + auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb); - wrb = prog.insert_instruction( - ins, op::broadcast{1, {bs, 4 * static_cast(hs)}}, ub_wrb); + wrb = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast(hs)}}}), + ub_wrb); } // peep hole instruction_ref pphi_brcst{}; instruction_ref ppho_brcst{}; instruction_ref pphf_brcst{}; - if(pph != prog.end()) + if(pph != m.end()) { - auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); - auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); - pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi); - - auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); - ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho); - - auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); - pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf); + auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph); + auto pphi = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph); + pphi_brcst = m.insert_instruction( + ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi); + + auto ppho = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph); + ppho_brcst = m.insert_instruction( + ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho); + + auto pphf = m.insert_instruction( + ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph); + pphf_brcst = m.insert_instruction( + ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf); } + long seq_len = get_seq_len(m, seq, seq_lens); for(long i = 0; i < seq_len; ++i) { long seq_index = is_forward ? i : (seq_len - 1 - i); - auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); - xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); - - auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw); - auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr); - auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr); - if(bias != prog.end()) + auto xt = m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}), + seq); + auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt); + xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt); + + auto xt_tsw = m.insert_instruction(ins, make_op("dot"), xt, tsw); + auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr); + auto xt_sih = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr); + if(bias != m.end()) { - xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb); + xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb); } - auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih); - auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih); - auto ft_before_actv = - prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih); - auto ct_before_actv = - prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih); + auto it_before_actv = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih); + auto ot_before_actv = m.insert_instruction( + ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih); + auto ft_before_actv = m.insert_instruction( + ins, + make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), + xt_sih); + auto ct_before_actv = m.insert_instruction( + ins, + make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}), + xt_sih); - if(pph != prog.end()) + if(pph != m.end()) { - auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); - it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); + auto pphi_ct = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic); + it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct); - auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic); - ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); + auto pphf_ct = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic); + ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct); } - auto it = prog.insert_instruction(ins, actv_func1, it_before_actv); - auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv); - auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); + auto it = m.insert_instruction(ins, actv_func1, it_before_actv); + auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv); + auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv); // equation Ct = ft (.) Ct-1 + it (.) ct - auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic); - auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct); - auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct); - last_cell_output = cellt; + auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic); + auto it_ct = m.insert_instruction(ins, make_op("mul"), it, ct); + auto cellt = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct); - if(pph != prog.end()) + if(pph != m.end()) { - auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); - ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); + auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt); + ot_before_actv = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt); } - auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv); + auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv); // Ht = ot (.) h(Ct) - auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt); - auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt); + auto h_cellt = m.insert_instruction(ins, actv_func3, cellt); + auto ht = m.insert_instruction(ins, make_op("mul"), ot, h_cellt); sic = cellt; sih = ht; - last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht); + last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht); + last_cell_output = + m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt); if(i < seq_len - 1) { if(i == 0) { - hidden_states = last_output; + hidden_states = last_hs_output; + cell_outputs = last_cell_output; } else { - auto concat_arg0 = is_forward ? hidden_states : last_output; - auto concat_arg1 = is_forward ? last_output : hidden_states; - hidden_states = - prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); + auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output; + auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states; + hidden_states = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1); + + auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output; + auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs; + cell_outputs = m.insert_instruction( + ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1); } } } - last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output); - - return {hidden_states, last_output, last_cell_output}; + return {hidden_states, last_hs_output, cell_outputs, last_cell_output}; } std::vector rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const @@ -1035,7 +1195,12 @@ std::vector rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const switch(num_actv_funcs) { case 0: - return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}}; + return {make_op("sigmoid"), + make_op("tanh"), + make_op("tanh"), + make_op("sigmoid"), + make_op("tanh"), + make_op("tanh")}; case 1: return {actv_funcs.at(0), @@ -1084,7 +1249,7 @@ std::vector rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const { switch(num_actv_funcs) { - case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}}; + case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")}; case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)}; @@ -1095,14 +1260,164 @@ std::vector rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const } } -namespace op { -std::ostream& operator<<(std::ostream& os, rnn_direction v) +bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const { - std::vector rnn_direction_str = {"forward", "reverse", "bidirectional"}; - os << rnn_direction_str[static_cast::type>(v)]; - return os; + bool is_var_lens = false; + if(seq_lens != m.end()) + { + if(seq_lens->can_eval()) + { + auto arg_lens = seq_lens->eval(); + std::vector vec_lens; + arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); }); + int64_t l = 0; + if(!vec_lens.empty()) + { + l = vec_lens[0]; + } + if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; })) + { + is_var_lens = true; + } + } + else + { + is_var_lens = true; + } + } + + return is_var_lens; +} + +std::size_t +rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const +{ + bool is_var_lens = is_variable_seq_lens(m, seq_lens); + auto input_shape = input->get_shape(); + auto length = input_shape.lens()[0]; + if(!is_var_lens and seq_lens != m.end()) + { + auto arg_len = seq_lens->eval(); + std::vector vec_lens; + arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); }); + length = vec_lens.empty() ? length : vec_lens[0]; + } + + return length; +} + +instruction_ref rewrite_rnn::replace_last_hs_output(module& m, + instruction_ref ins, + instruction_ref seq_lens, + instruction_ref last_hs_output, + op::rnn_direction dirct) const +{ + bool variable_seq_len = is_variable_seq_lens(m, seq_lens); + instruction_ref result_ins{}; + if(variable_seq_len) + { + result_ins = + m.insert_instruction(std::next(ins), + make_op("rnn_var_sl_shift_output", + {{"output_name", "hidden_states"}, {"direction", dirct}}), + ins, + seq_lens); + m.replace_instruction(ins, result_ins); + auto hs_outputs = find_all(result_ins->outputs(), + [&](auto i) { return i->name() == "rnn_last_hs_output"; }); + + for(auto& hs_out : hs_outputs) + { + auto inputs = hs_out->inputs(); + m.replace_instruction(hs_out, + make_op("rnn_var_sl_last_output", {{"direction", dirct}}), + inputs.front(), + seq_lens); + } + } + else + { + auto hs_outputs = + find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; }); + + for(auto& hs_out : hs_outputs) + { + m.replace_instruction(hs_out, last_hs_output); + } + + result_ins = ins; + } + + return result_ins; +} + +void rewrite_rnn::replace_last_cell_output(module& m, + instruction_ref ins, + instruction_ref seq_lens, + instruction_ref cell_outputs, + instruction_ref last_cell_output, + op::rnn_direction dirct) const +{ + bool variable_seq_len = is_variable_seq_lens(m, seq_lens); + auto ins_outputs = + find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; }); + + if(variable_seq_len) + { + if(!ins_outputs.empty()) + { + cell_outputs = m.insert_instruction( + std::next(ins), + make_op("rnn_var_sl_shift_output", + {{"output_name", "cell_outputs"}, {"direction", dirct}}), + cell_outputs, + seq_lens); + } + + for(auto co : ins_outputs) + { + m.replace_instruction(co, + make_op("rnn_var_sl_last_output", {{"direction", dirct}}), + cell_outputs, + seq_lens); + } + } + // replace the rnn_last_cell_output with the last_cell_output. The while + // loop is to handle the case of multiple rnn_last_cell_output operators + else + { + for(auto co : ins_outputs) + { + m.replace_instruction(co, last_cell_output); + } + } +} + +instruction_ref rewrite_rnn::pad_hidden_states(module& m, + instruction_ref seq, + instruction_ref seq_lens, + instruction_ref hs) const +{ + auto max_seq_len = seq->get_shape().lens()[0]; + auto seq_len = get_seq_len(m, seq, seq_lens); + + // condition of all sequence are of the same length and + // less than max_seq_len, we need to append the hs outputs + auto hs_padded = hs; + if(seq_len < max_seq_len) + { + auto s = hs->get_shape(); + auto pad_lens = s.lens(); + pad_lens[0] = static_cast(max_seq_len - seq_len); + shape pad_s{s.type(), pad_lens}; + std::vector pad_data(pad_s.elements(), 0.0f); + auto pl = m.add_literal(pad_s, pad_data.begin(), pad_data.end()); + hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl); + m.replace_instruction(hs, hs_padded); + } + + return hs_padded; } -} // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/schedule.cpp b/src/schedule.cpp index 69de68b4c95542b03169abb13384438974a9e8d4..0572b3952474350d69003d74a4b58c749c3401db 100644 --- a/src/schedule.cpp +++ b/src/schedule.cpp @@ -1,20 +1,24 @@ #include #include #include -#include #include +#include #include #include #include #include +#include #include #include #include #include #include +#include + #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -36,6 +40,9 @@ struct stream_info std::unordered_map ins2stream; std::unordered_map weights; std::unordered_map iweights; + ins_dep_map mod_implicit_deps; + + void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); } void accumulate_weights(instruction_ref last, const schedule_model& model) { @@ -46,12 +53,21 @@ struct stream_info auto&& op = ins->get_operator(); if(not is_context_free(op) and op.name()[0] != '@') weight = model.weight(op); + // This will ensure a stream will be assigned to return + if(op.name() == "@return") + weight = 1; iweights[ins] = weight; - weights[ins] = - std::accumulate(ins->inputs().begin(), - ins->inputs().end(), - weight, - [&](std::size_t w, instruction_ref i) { return w + self(i); }); + auto inputs = ins->inputs(); + if(contains(mod_implicit_deps, ins)) + { + const auto& impl_deps = mod_implicit_deps.at(ins); + inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end()); + } + + weights[ins] = std::accumulate( + inputs.begin(), inputs.end(), weight, [&](std::size_t w, instruction_ref i) { + return w + self(i); + }); } return weights[ins]; })(last); @@ -75,7 +91,7 @@ struct stream_info return args.end(); } - const std::size_t min_partition_threshold = 1; + const std::size_t min_partition_threshold = 2; sort_args_by_weight(args, std::greater<>{}); auto it = std::lower_bound(std::next(args.begin()), @@ -100,17 +116,19 @@ struct stream_info } }; - std::size_t assign_streams(program& p, std::size_t n) + std::size_t assign_streams(module& m, std::size_t n) { assert(n > 0); partition critical; std::unordered_map> partitions; partitions.reserve(weights.size()); fix([&](auto self, auto ins, auto& part) { - assert(ins != p.end()); + assert(not is_end(ins, m.end())); + if(not m.has_instruction(ins)) + return; if(contains(partitions, ins)) return; - assert(p.has_instruction(ins)); + // Add an entry so we know the instruction was visited partitions[ins]; part.add(ins, this->iweights[ins]); @@ -133,8 +151,8 @@ struct stream_info } } // Sort instructions - p.move_instruction(ins, p.end()); - })(std::prev(p.end()), critical); + m.move_instruction(ins, m.end()); + })(std::prev(m.end()), critical); // Set the critical partition to stream 0 set_stream(critical, 0); @@ -179,13 +197,13 @@ struct stream_info } }; - void sort(program& p, std::size_t) const + void sort(module& m, std::size_t) { std::set children; std::unordered_map visited; - auto last = std::prev(p.end()); + auto last = std::prev(m.end()); auto mw = this->weights.at(last); - auto nw = mw / (p.size() + 1); + auto nw = mw / (m.size() + 1); auto add_child = [&](auto ins) { auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1); auto w = x * this->iweights.at(ins); @@ -204,12 +222,34 @@ struct stream_info // Pop the first element auto top = children.begin()->second; children.erase(children.begin()); - - p.move_instruction(top, p.begin()); + m.move_instruction(top, m.begin()); for(auto ins : top->inputs()) { + if(not m.has_instruction(ins)) + continue; add_child(ins); } + + if(contains(mod_implicit_deps, top)) + { + for(auto ins : mod_implicit_deps.at(top)) + { + assert(m.has_instruction(ins)); + add_child(ins); + } + } + } + + // move dangling parameter to the front so as not be removed + auto ins = std::next(last); + while(ins != m.end()) + { + auto next = std::next(ins); + if(ins->name() == "@param") + { + m.move_instruction(ins, m.begin()); + } + ins = next; } } @@ -263,20 +303,12 @@ struct stream_info { return [=](auto f) { return fix([&](auto self, auto ins) { - for(auto i : select(ins)) - { + return all_of(select(ins), [&](auto i) { if(iweights.at(i) == 0) - { - if(not self(i)) - return false; - } + return self(i); else - { - if(not f(this->get_stream(i))) - return false; - } - } - return true; + return f(this->get_stream(i)); + }); })(start); }; } @@ -332,25 +364,33 @@ struct stream_info } std::unordered_map>> - find_concurrent_instructions(program& p) + find_concurrent_instructions(module& m) const { std::unordered_map>> result; std::unordered_map> merge_from; - result.reserve(p.size()); - merge_from.reserve(p.size()); - for(auto ins : reverse_iterator_for(p)) + dominator_info di = compute_dominator(m); + result.reserve(m.size()); + merge_from.reserve(m.size()); + for(auto ins : reverse_iterator_for(m)) { for(auto&& arg : ins->outputs()) { + if(not m.has_instruction(arg)) + continue; if(is_merge_point(arg)) merge_from[ins].insert(arg); merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end()); } - auto streams = this->get_streams(ins); + if(is_split_point(ins)) + { + erase_if(merge_from[ins], + [&](auto merge) { return di.strictly_dominate(ins, merge); }); + } + auto streams = this->get_streams(ins); // Collect concur instructions for each merge point. - for(auto& merge : merge_from[ins]) + for(const auto& merge : merge_from[ins]) { for(auto stream : streams) { @@ -375,12 +415,19 @@ struct stream_info } std::unordered_map> - get_conflicts(program& p) + get_conflicts(module& m) { + using conflict_table_type = std::unordered_map>; conflict_table_type conflict_table; - auto concur_ins = this->find_concurrent_instructions(p); + auto concur_ins = this->find_concurrent_instructions(m); + + // Compute an index for each instruction + std::unordered_map ins2index; + std::size_t index_total = 0; + for(auto ins : iterator_for(m)) + ins2index[ins] = index_total++; std::vector thread_conflict_tables( std::thread::hardware_concurrency()); @@ -423,14 +470,13 @@ struct stream_info for(auto ins1 : ins1_set) { - auto p1 = std::distance(ins1, merge_first); + auto p1 = ins2index.at(ins1); for(auto ins2 : ins2_set) { if(ins1 == ins2) continue; - auto p2 = std::distance(ins2, merge_first); - // The smaller distance means the instruction occurs later - if(p1 > p2) + auto p2 = ins2index.at(ins2); + if(p2 > p1) thrd_table[ins2].insert(ins1); else thrd_table[ins1].insert(ins2); @@ -461,19 +507,24 @@ struct stream_info } }; -void schedule::apply(program& p) const +void schedule::apply(module& m) const { if(not enable) return; + stream_info si; - auto last = std::prev(p.end()); + si.calc_implicit_deps(m); + auto last = std::prev(m.end()); si.accumulate_weights(last, model); - auto nstreams = si.assign_streams(p, model.concurrency()); - si.sort(p, model.concurrency()); + auto nstreams = si.assign_streams(m, model.concurrency()); + si.sort(m, model.concurrency()); if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) { - p.annotate(std::cout, [&](auto ins) { + m.annotate(std::cout, [&](auto ins) { + if(ins->name() == "@param" and not contains(si.weights, ins)) + return; + std::cout << ":"; std::cout << " weight=" << si.weights.at(ins); std::cout << " input={"; @@ -497,9 +548,9 @@ void schedule::apply(program& p) const std::unordered_map ins2wait; std::unordered_map> waited_for; std::unordered_map> ins2waited; - ins2wait.reserve(p.size()); - ins2waited.reserve(p.size()); - for(auto ins : iterator_for(p)) + ins2wait.reserve(m.size()); + ins2waited.reserve(m.size()); + for(auto ins : iterator_for(m)) { // Only schedule instructions that have a stream if(not si.has_stream(ins)) @@ -508,29 +559,27 @@ void schedule::apply(program& p) const // Schedule instruction on the stream auto stream = si.get_stream(ins); assert(stream < model.concurrency()); - model.sched(p, ins, stream); + model.sched(m, ins, stream); // Insert wait instructions if(si.is_merge_point(ins, stream)) { for(auto i : si.get_recorded_instructions(ins)) { - if(not si.has_stream(i)) - continue; - auto istream = si.get_stream(i); - if(stream == istream) + if(not si.has_stream(i) or si.get_stream(i) == stream) continue; + // Create a new event if it hasn't been recorded if(not contains(ins2wait, i)) { ins2wait[i] = wait_id; - model.record(p, i, wait_id); + model.record(m, i, wait_id); wait_id++; } auto w = ins2wait.at(i); // If we already waited for the event on this stream then dont // insert another wait event if(not contains(waited_for[stream], w)) - model.wait(p, ins, w); + model.wait(m, ins, w); // Store the event as waited waited_for[stream].insert(w); // Store all wait events that have been waited on prior to the recorded instruction @@ -545,7 +594,7 @@ void schedule::apply(program& p) const } // Add memory conflicts - auto conflict_table = si.get_conflicts(p); + auto conflict_table = si.get_conflicts(m); for(auto&& ip : conflict_table) { if(ip.second.empty()) @@ -553,7 +602,7 @@ void schedule::apply(program& p) const std::vector args; args.push_back(ip.first); args.insert(args.end(), ip.second.begin(), ip.second.end()); - p.insert_instruction(std::next(ip.first), op::identity{}, args); + m.insert_instruction(std::next(ip.first), make_op("identity"), args); } } diff --git a/src/serialize.cpp b/src/serialize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b53cc2ba2aeb6f000a9bc5299f884fc701384352 --- /dev/null +++ b/src/serialize.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +void raw_data_to_value(value& v, const RawData& rd) +{ + value result; + result["shape"] = migraphx::to_value(rd.get_shape()); + if(rd.get_shape().type() == shape::tuple_type) + result["sub"] = migraphx::to_value(rd.get_sub_objects()); + else + result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); + v = result; +} + +void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); } +void migraphx_from_value(const value& v, literal& l) +{ + auto s = migraphx::from_value(v.at("shape")); + l = literal(s, v.at("data").get_binary().data()); +} + +void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); } +void migraphx_from_value(const value& v, argument& a) +{ + if(v.contains("data")) + { + literal l = migraphx::from_value(v); + a = l.get_argument(); + } + else + { + a = migraphx::from_value>(v.at("sub")); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/shape.cpp b/src/shape.cpp index aa86e23471391c8b374c217f475ac4393d3e2477..9daf451674d6c86652c2d392fe0e5e9655db33b8 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1,9 +1,12 @@ #include #include +#include +#include #include #include #include +#include #include namespace migraphx { @@ -13,34 +16,40 @@ struct shape_impl { static std::shared_ptr default_shape() { - static std::shared_ptr result = std::make_shared(); + static const std::shared_ptr result = std::make_shared(); return result; } - shape_impl() : m_type(shape::float_type), m_standard(false) {} + shape_impl() : m_type(shape::float_type) {} - shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {} + shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) + { + assert(t != shape::tuple_type); + } shape_impl(shape::type_t t, std::vector l) : m_type(t), m_lens(std::move(l)), m_standard(true) { + assert(t != shape::tuple_type); this->calculate_strides(); assert(m_lens.size() == m_strides.size()); } shape_impl(shape::type_t t, std::vector l, std::vector s) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) { + assert(t != shape::tuple_type); assert(m_lens.size() == m_strides.size()); // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and // "At least one stride must be non-zero"); - m_standard = - this->elements() == this->element_space() and - std::is_sorted(m_strides.rbegin(), m_strides.rend()) and - std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 0; }); + m_standard = this->elements() == this->element_space() and + std::is_sorted(m_strides.rbegin(), m_strides.rend()); } + + shape_impl(const std::vector& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape::type_t m_type; - std::vector m_lens; - std::vector m_strides; - bool m_standard; + std::vector m_lens = {}; + std::vector m_strides = {}; + std::vector m_shapes = {}; + bool m_standard = false; void calculate_strides() { @@ -77,8 +86,43 @@ struct shape_impl return std::accumulate( m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies()); } + + std::shared_ptr copy() const { return std::make_shared(*this); } }; +const std::vector& shape::types() +{ + static const std::vector result = { +#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x, + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type}; + return result; +} + +std::string shape::name(shape::type_t t) +{ + switch(t) + { + case tuple_type: return "tuple_type"; +#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \ + case x: return #x; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE) +#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE + } + MIGRAPHX_THROW("Invalid type"); +} +std::string shape::cpp_type(shape::type_t t) +{ + switch(t) + { + case tuple_type: MIGRAPHX_THROW("No C++ type for tuple"); +#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \ + case x: return #t; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE) +#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE + } + MIGRAPHX_THROW("Invalid type"); +} + shape::shape() : impl(shape_impl::default_shape()) {} shape::shape(type_t t) : impl(std::make_shared(t)) {} @@ -91,20 +135,45 @@ shape::shape(type_t t, std::vector l, std::vector s) { } +shape::shape(const std::vector& subs) : impl(std::make_shared(subs)) {} + +shape::shape(std::shared_ptr pimpl) : impl(std::move(pimpl)) {} + +shape shape::from_permutation(type_t t, + const std::vector& l, + const std::vector& perm) +{ + auto new_lens = reorder_dims(l, perm); + shape result = reorder_shape({t, new_lens}, invert_permutation(perm)); + assert(result.lens() == l); + return result; +} + shape::type_t shape::type() const { return impl->m_type; } const std::vector& shape::lens() const { return impl->m_lens; } const std::vector& shape::strides() const { return impl->m_strides; } std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::bytes() const { - std::size_t n = 0; - this->visit_type([&](auto as) { n = as.size(); }); - return n * this->element_space(); + if(this->sub_shapes().empty()) + { + std::size_t n = 0; + this->visit_type([&](auto as) { n = as.size(); }); + return n * this->element_space(); + } + else + { + return std::accumulate(this->sub_shapes().begin(), + this->sub_shapes().end(), + std::size_t{0}, + [&](auto x, auto y) { return x + y.bytes(); }); + } } std::size_t shape::type_size() const { std::size_t n = 0; - this->visit_type([&](auto as) { n = as.size(); }); + if(this->sub_shapes().empty()) + this->visit_type([&](auto as) { n = as.size(); }); return n; } std::size_t shape::index(std::initializer_list l) const @@ -146,19 +215,30 @@ std::vector shape::multi(std::size_t i) const assert(this->standard()); std::vector indices(lens().size()); + multi_copy(i, indices.data(), indices.data() + lens().size()); + + return indices; +} + +void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const +{ + assert(this->standard()); + (void)end; + assert(lens().size() <= (end - start)); std::transform(strides().begin(), strides().end(), lens().begin(), - indices.begin(), + start, [&](std::size_t stride, std::size_t len) { assert(len > 0 and stride > 0); return (i / stride) % len; }); - - return indices; } -bool shape::packed() const { return this->elements() == this->element_space(); } +bool shape::packed() const +{ + return this->sub_shapes().empty() and this->elements() == this->element_space(); +} bool shape::transposed() const { @@ -192,38 +272,100 @@ bool shape::scalar() const { assert(this->lens().size() == this->strides().size()); // if any stride > 0, then accumulate will return false - return std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; + return this->sub_shapes().empty() and + std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; } bool shape::standard() const { return impl->m_standard; } -std::size_t shape::element_space() const { return impl->element_space(); } +shape shape::normalize_standard() const +{ + if(this->standard()) + return {this->type(), this->lens()}; + else + return *this; +} -std::string shape::type_string() const +shape shape::with_lens(type_t t, const std::vector& l) const { - switch(this->type()) - { -#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \ - case x: return #x; - MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE) -#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE - } - MIGRAPHX_THROW("Invalid type"); + assert(l.size() == this->lens().size()); + auto perm = find_permutation(*this); + return shape::from_permutation(t, l, perm); } +shape shape::with_lens(const std::vector& l) const +{ + return this->with_lens(this->type(), l); +} + +shape shape::with_type(type_t t) const +{ + auto c = impl->copy(); + c->m_type = t; + return {c}; +} + +std::size_t shape::element_space() const { return impl->element_space(); } + +std::string shape::type_string() const { return name(this->type()); } + bool operator==(const shape& x, const shape& y) { - return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides(); + return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and + x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes()); } bool operator!=(const shape& x, const shape& y) { return !(x == y); } std::ostream& operator<<(std::ostream& os, const shape& x) { - os << x.type_string() << ", "; - os << "{" << to_string_range(x.lens()) << "}, "; - os << "{" << to_string_range(x.strides()) << "}"; + if(x.sub_shapes().empty()) + { + os << x.type_string() << ", "; + os << "{" << to_string_range(x.lens()) << "}, "; + os << "{" << to_string_range(x.strides()) << "}"; + } + else + { + os << "[" << to_string_range(x.sub_shapes()) << "]"; + } return os; } +shape::type_t shape::parse_type(const std::string& s) +{ + static const std::unordered_map m = { +#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x}, + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type", + tuple_type}, + {"tuple", tuple_type}}; + return m.at(s); +} + +const std::vector& shape::sub_shapes() const { return impl->m_shapes; } + +void migraphx_to_value(value& v, const shape& s) +{ + value result; + result["type"] = migraphx::to_value(s.type_string()); + result["lens"] = migraphx::to_value(s.lens()); + result["strides"] = migraphx::to_value(s.strides()); + result["sub_shapes"] = migraphx::to_value(s.sub_shapes()); + v = result; +} +void migraphx_from_value(const value& v, shape& s) +{ + auto t = v.at("type").get_string(); + if(t == "tuple_type") + { + s = shape{migraphx::from_value>(v.at("sub_shapes"))}; + } + else + { + s = shape{shape::parse_type(t), + v.at("lens").to_vector(), + v.at("strides").to_vector()}; + } +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 3fbd5d71b893e5b58d09c4a26b313dd75254e8d8..8db6afe915a3a287d6d3ba8ceb56ffde055d8ddc 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1,14 +1,19 @@ #include #include #include -#include -#include #include +#include #include -#include #include +#include +#include #include #include +#include +#include + +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -27,6 +32,8 @@ auto conv_const_weights() match::args(match::any(), match::is_constant().bind("w"))); } +auto reduction() { return match::name_contains("reduce"); } + struct find_mul_conv { auto matcher() const @@ -35,7 +42,7 @@ struct find_mul_conv match::name("broadcast").bind("a"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto conv_ins = r.instructions["conv"]; @@ -46,12 +53,106 @@ struct find_mul_conv if(broadcast_op.axis != 1) return; - auto new_a = p.insert_instruction( - ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front()); - auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins); - auto new_conv = p.insert_instruction( + auto new_a = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), + a_ins->inputs().front()); + auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins); + auto new_conv = m.insert_instruction( ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); - p.replace_instruction(ins, new_conv); + m.replace_instruction(ins, new_conv); + } +}; + +struct find_mul_slice_conv +{ + static auto conv() + { + return match::name("convolution")( + match::all_of[match::outputs()](match::name("slice")), + match::args(match::any(), match::is_constant().bind("w"))); + } + auto matcher() const + { + return match::name("mul")(match::either_arg(0, 1)( + match::name("slice")(match::used_once(), match::arg(0)(conv().bind("conv"))) + .bind("slice"), + match::name("broadcast")(match::is_constant()).bind("a"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto slice_ins = r.instructions["slice"]; + auto conv_ins = r.instructions["conv"]; + auto a_ins = r.instructions["a"]; + auto w_ins = r.instructions["w"]; + + auto broadcast_op = any_cast(a_ins->get_operator()); + if(broadcast_op.axis != 1) + return; + + auto slice_op = any_cast(slice_ins->get_operator()); + if(slice_op.axes.size() != 1) + return; + if(slice_op.axes.front() != 1) + return; + + auto slice_idx = std::distance(conv_ins, slice_ins); + if(std::any_of(conv_ins->outputs().begin(), conv_ins->outputs().end(), [&](auto i) { + if(i == slice_ins) + return false; + if(std::distance(conv_ins, i) < slice_idx) + return true; + auto sop = any_cast(i->get_operator()); + if(sop.axes != slice_op.axes) + return true; + if(std::max(sop.starts.front(), slice_op.starts.front()) < + std::min(sop.ends.front(), slice_op.ends.front())) + return true; + return false; + })) + return; + + auto w_slice_op = slice_op; + w_slice_op.axes = {0}; + auto slice_w_ins = m.insert_instruction(ins, w_slice_op, w_ins); + + auto new_a = m.insert_instruction( + ins, + make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}), + a_ins->inputs().front()); + auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins); + + std::vector sliced_weights; + if(slice_op.starts.front() != 0) + sliced_weights.push_back(m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}), + w_ins)); + sliced_weights.push_back(new_mul); + int64_t end_axis = w_ins->get_shape().lens().at(0); + if(slice_op.ends.front() != end_axis) + sliced_weights.push_back(m.insert_instruction( + ins, + make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}), + w_ins)); + + auto new_weights = + m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights); + + auto new_conv = m.insert_instruction( + ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights); + assert(conv_ins->get_shape() == new_conv->get_shape()); + + auto slice1 = m.insert_instruction(ins, slice_op, new_conv); + assert(ins->get_shape().lens() == slice1->get_shape().lens()); + m.replace_instruction(ins, slice1); + // TODO: Check each slice doesn't overlap and that it occurs after slice_ins + auto outputs = conv_ins->outputs(); + for(auto output : outputs) + if(output != slice_ins) + instruction::replace_argument(output, conv_ins, new_conv); } }; @@ -70,7 +171,7 @@ struct find_mul_add match::is_constant().bind("a"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto a_ins = r.instructions["a"]; @@ -78,9 +179,9 @@ struct find_mul_add auto x_ins = r.instructions["x"]; assert(x_ins != b_ins); - auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins); - auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins); - p.replace_instruction(ins, op::add{}, ax_ins, ab_ins); + auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); + auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); + m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); } }; @@ -92,15 +193,15 @@ struct find_add_lit_broadcast match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; auto a_ins = r.instructions["a"]; auto b_ins = r.instructions["b"]; - auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins); - p.replace_instruction(ins, op::add{}, x_ins, sumab); + auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins); + m.replace_instruction(ins, make_op("add"), x_ins, sumab); } }; @@ -112,7 +213,7 @@ struct find_double_add_lit_broadcast match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; @@ -126,18 +227,18 @@ struct find_double_add_lit_broadcast { if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape()) return; - auto op = a_ins->get_operator(); - auto presum = - p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0)); - sumab = p.insert_instruction(ins, op, presum); + auto op = a_ins->get_operator(); + auto presum = m.insert_instruction( + ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0)); + sumab = m.insert_instruction(ins, op, presum); } else { - sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins); + sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins); } - auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins); - p.replace_instruction(ins, op::add{}, sumxy, sumab); + auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins); + m.replace_instruction(ins, make_op("add"), sumxy, sumab); } }; @@ -145,11 +246,12 @@ struct find_inner_broadcast { auto matcher() const { - return match::name("mul", "add")( + return pointwise( + match::nargs(2), match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; @@ -161,9 +263,365 @@ struct find_inner_broadcast if(xbroadcast.axis != ybroadcast.axis) return; - auto op = p.insert_instruction( + auto op = m.insert_instruction( ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front()); - p.replace_instruction(ins, xbroadcast, op); + m.replace_instruction(ins, xbroadcast, op); + } +}; + +struct find_concat_op +{ + auto matcher() const + { + return match::name("concat")(match::any_of[match::inputs()]( + match::any_of(match::pointwise(), match::name("broadcast")), match::used_once())); + } + + template + static std::vector get_output_lens(Iterator start, Iterator last, std::size_t axis) + { + assert(start != last); + std::size_t dim = 0; + for(auto ins : range(start, last)) + { + dim += ins->get_shape().lens().at(axis); + } + auto lens = (*start)->get_shape().lens(); + lens[axis] = dim; + return lens; + } + + static bool is_valid_op(const operation& op) + { + return op.name() == "broadcast" or op.attributes().contains("pointwise"); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto axis = any_cast(ins->get_operator()).axis; + + auto each = [&](auto start, auto last) -> std::vector { + if(std::distance(start, last) < 2) + return {start, last}; + auto x = *start; + if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1) + return {start, last}; + auto op = x->get_operator(); + if(not is_valid_op(op)) + return {start, last}; + auto iaxis = axis; + // Adjust broadcast lens + if(op.name() == "broadcast") + { + auto b = any_cast(op); + if(b.axis != iaxis) + return {start, last}; + b.broadcast_lens = get_output_lens(start, last, iaxis); + op = b; + iaxis = 0; + } + + std::vector concats; + for(std::size_t i = 0; i < x->inputs().size(); i++) + { + std::vector inputs; + std::transform(start, last, std::back_inserter(inputs), [&](auto j) { + return j->inputs().at(i); + }); + auto concat = + m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs); + concats.push_back(concat); + } + auto y = m.insert_instruction(ins, op, concats); + return {y}; + }; + + std::vector args; + auto update_args = [&](auto start, auto last) { + auto x = each(start, last); + args.insert(args.end(), x.begin(), x.end()); + }; + auto pred = [](auto i, auto j) { + return i->get_operator() == j->get_operator() and + i->inputs().size() == i->inputs().size() and + i->outputs().size() == i->outputs().size(); + }; + group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred); + if(args.size() == 1) + m.replace_instruction(ins, args.front()); + else + m.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args); + } +}; + +std::vector get_splits(instruction_ref ins) +{ + std::vector result; + std::copy_if(ins->outputs().begin(), + ins->outputs().end(), + std::back_inserter(result), + [&](auto i) { return i->name() == "slice"; }); + if(result.size() < 2) + return {}; + auto get_slice = [](auto& i) -> auto& { return any_cast(i->get_operator()); }; + auto&& axes = get_slice(result.front()).axes; + if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; })) + return {}; + auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; }; + auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; }; + std::sort( + result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); }); + if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) { + return i != 0; + })) + return {}; + auto it = std::adjacent_find( + result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); + if(it != result.end()) + return {}; + for(std::size_t i = 0; i < axes.size(); i++) + { + auto axis = axes[i]; + if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i]) + return {}; + } + return result; +} + +struct find_splits +{ + auto matcher() const + { + return match::any(match::any_of[match::outputs()](match::name("slice")( + match::any_of[match::outputs()](match::pointwise(), reduction())))); + } + + static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2) + { + + std::unordered_set traversed; + return fix([&](auto self, auto ins) -> bool { + if(ins == ins2) + return true; + + if(contains(traversed, ins)) + return false; + + traversed.insert(ins); + const auto& inputs = ins->inputs(); + return std::any_of(inputs.begin(), inputs.end(), [&](auto in) { + return m.has_instruction(in) and self(in); + }); + })(ins1); + } + + static std::vector> + get_split_groups(const module& m, const std::vector& splits) + { + std::vector> groups; + for(auto out : splits.front()->outputs()) + { + if(out->name() == "slice") + continue; + std::vector group; + for(auto split : splits) + { + auto it = + std::find_if(split->outputs().begin(), split->outputs().end(), [&](auto i) { + return i->get_operator() == out->get_operator(); + }); + if(it == split->outputs().end()) + break; + assert((*it)->name() != "slice"); + + // If there is a duplicate bail + // there are should be no dependency between instructions in the group + if(std::any_of(group.begin(), group.end(), [&](auto i) { + return is_dependent(m, *it, i) or is_dependent(m, i, *it); + })) + { + return {}; + } + + group.push_back(*it); + } + if(group.size() != splits.size()) + continue; + groups.push_back(group); + } + return groups; + } + + bool is_fusable(instruction_ref start, instruction_ref split_front) const + { + auto op = start->get_operator(); + if(contains(op.name(), "reduce")) + { + auto slc = any_cast(split_front->get_operator()); + auto slc_axes = slc.axes; + auto reduce_axes = start->get_operator().to_value()["axes"].to_vector(); + // axes of slice and reduce op cannot have overlap + if(std::any_of(slc_axes.begin(), slc_axes.end(), [&](auto axis) { + return (std::find(reduce_axes.begin(), reduce_axes.end(), axis) != + reduce_axes.end()); + })) + { + return false; + } + } + else if(not op.attributes().contains("pointwise")) + { + return false; + } + + return true; + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto splits = get_splits(ins); + if(splits.empty()) + return; + + for(const auto& group : get_split_groups(m, splits)) + { + auto start = group.front(); + auto split_front = splits.front(); + auto op = start->get_operator(); + if(not is_fusable(start, split_front)) + { + continue; + } + + // Make sure there is no duplicates + assert(std::none_of( + std::next(group.begin()), group.end(), [&](auto i) { return i == start; })); + + auto split_idx = 0; + instruction_ref c = m.end(); + if(start->inputs().size() == 1) + { + c = m.insert_instruction(std::next(ins), op, ins); + } + else if(start->inputs().size() == 2) + { + assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) { + return i->name() == "slice"; + }) && "one argument must be a split"); + auto data_idx = 1; + if(start->inputs().back()->name() == "slice") + { + split_idx = 1; + data_idx = 0; + } + + std::vector data_args; + std::transform(group.begin(), + group.end(), + std::back_inserter(data_args), + [&](auto i) { return i->inputs()[data_idx]; }); + + // Data arguments must be a constant + if(std::any_of(data_args.begin(), data_args.end(), [](auto i) { + return not i->can_eval(); + })) + return; + + for(auto data : data_args) + m.move_instructions(data, ins); + + auto slice_op = any_cast(splits.front()->get_operator()); + assert(not slice_op.axes.empty()); + if(slice_op.axes.size() > 1) + return; + auto concat_axis = slice_op.axes.front(); + // TODO: Check if axises match + auto concat = m.insert_instruction( + ins, make_op("concat", {{"axis", concat_axis}}), data_args); + + std::vector args; + args.resize(2); + args[split_idx] = ins; + args[data_idx] = concat; + c = m.insert_instruction(std::next(ins), op, args); + } + if(c != m.end()) + { + for(auto i : group) + { + auto split = i->inputs()[split_idx]; + assert(split->name() == "slice"); + // Insert contiguous for reshapes + auto outputs = i->outputs(); + for(auto output : outputs) + { + if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) + continue; + auto x = + m.insert_instruction(output, make_op("contiguous"), output->inputs()); + m.replace_instruction(output, output->get_operator(), x); + } + + m.replace_instruction(i, split->get_operator(), c); + } + } + } + } +}; + +struct find_split_concat +{ + auto matcher() const + { + return match::any(match::any_of[match::outputs()]( + match::name("slice")(match::all_of[match::outputs()](match::name("concat"))))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + + auto splits = get_splits(ins); + if(splits.empty()) + return; + if(std::any_of( + splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; })) + return; + // Check for concat operator + auto concat = splits.front()->outputs().front(); + if(std::any_of(splits.begin(), splits.end(), [&](auto i) { + return i->outputs().front() != concat; + })) + return; + // Check axis match + auto concat_op = any_cast(concat->get_operator()); + auto split_op = any_cast(splits.front()->get_operator()); + if(split_op.axes.size() != 1) + return; + if(split_op.axes.front() != concat_op.axis) + return; + // Replace args + auto args = concat->inputs(); + auto it = + std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); }); + if(std::distance(it, args.end()) < splits.size()) + return; + // If the slices are not in order then stop + if(not std::is_sorted(it, it + splits.size(), [](instruction_ref x, instruction_ref y) { + auto xop = any_cast(x->get_operator()); + auto yop = any_cast(y->get_operator()); + return std::tie(xop.starts, xop.ends) < std::tie(yop.starts, yop.ends); + })) + return; + *it = splits.front()->inputs().front(); + args.erase(std::next(it), it + splits.size()); + + if(args.size() == 1) + m.replace_instruction(concat, args.front()); + else + m.replace_instruction(concat, concat->get_operator(), args); } }; @@ -206,17 +664,7 @@ struct find_add_convs return x.stride[0] / y.stride[0]; } - static shape compute_stride_shape(const shape& input, std::size_t n) - { - return {input.type(), - {input.lens()[0], input.lens()[1], input.lens()[2] / n, input.lens()[3] / n}, - {input.strides()[0], - input.strides()[1], - input.strides()[2] * n, - input.strides()[3] * n}}; - } - - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; auto a_conv = r.instructions["a"]; @@ -245,8 +693,8 @@ struct find_add_convs if(n == 0) return; new_op = a_op; - b_input = p.insert_instruction( - ins, op::as_shape{compute_stride_shape(b_input->get_shape(), n)}, b_input); + b_input = m.insert_instruction( + ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input); } else if(b_op.stride < a_op.stride) { @@ -254,8 +702,8 @@ struct find_add_convs if(n == 0) return; new_op = b_op; - a_input = p.insert_instruction( - ins, op::as_shape{compute_stride_shape(a_input->get_shape(), n)}, a_input); + a_input = m.insert_instruction( + ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input); } else return; @@ -264,25 +712,328 @@ struct find_add_convs return; } - auto concat_input = p.insert_instruction(ins, op::concat{1}, a_input, b_input); - auto concat_weights = p.insert_instruction(ins, op::concat{1}, a_weights, b_weights); - p.replace_instruction(ins, new_op, concat_input, concat_weights); + auto concat_input = + m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input); + auto concat_weights = + m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights); + m.replace_instruction(ins, new_op, concat_input, concat_weights); + } +}; + +MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) +{ + auto pred = [&](auto name) { + return [=](auto i) { + return i->name() == name and i->inputs().front() == ins and + i->inputs().at(1)->can_eval(); + }; + }; + auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); + auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); + return !(dots < 2 and convs < 2); +} + +struct find_conv_dot_horiz_fusion +{ + auto matcher() const { return horiz_conv_dot(); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + + auto pred = [](auto i, auto j) { + if(i->get_operator() != j->get_operator()) + return false; + if(not contains({"dot", "convolution"}, i->name())) + return true; + auto x = i->inputs()[1]->get_shape().lens(); + auto y = j->inputs()[1]->get_shape().lens(); + if(x.size() != y.size()) + return false; + // Check that non-axises match + int axis = 1; + if(i->name() == "dot") + { + axis = x.size() - 1; + } + return axis_equal(x, y, axis); + }; + + auto each = [&](auto start, auto last) { + if(std::distance(start, last) < 2) + return; + auto&& name = (*start)->name(); + if(not contains({"dot", "convolution"}, name)) + return; + auto op = (*start)->get_operator(); + int group = 1; + if(name == "convolution") + group = any_cast(op).group; + // Skip group convolution + if(group != 1) + return; + auto input = (*start)->inputs().front(); + std::vector args; + std::transform( + start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); + int axis = 1; + int concat_axis = 0; + if(name == "dot") + { + axis = int(args.front()->get_shape().lens().size() - 1); + concat_axis = axis; + } + + for(auto arg : args) + m.move_instructions(arg, input); + // TODO: Check if axises match + auto concat = + m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); + auto fused = m.insert_instruction(std::next(input), op, input, concat); + int64_t offset = 0; + for(auto arg : range(start, last)) + { + int64_t len = arg->get_shape().lens()[axis]; + m.replace_instruction( + arg, + make_op("slice", + {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), + fused); + offset += len; + } + }; + + auto outputs = ins->outputs(); + group_by(outputs.begin(), outputs.end(), each, pred); + } +}; + +struct find_div_const +{ + auto matcher() const + { + return match::name("div")(match::arg(1)(match::is_constant().bind("c"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto c_ins = r.instructions["c"]; + + auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins); + + auto args = ins->inputs(); + + m.replace_instruction(ins, make_op("mul"), args.front(), recip); + } +}; + +struct find_sub_const +{ + auto matcher() const + { + return match::name("sub")(match::arg(1)(match::is_constant().bind("c"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto c_ins = r.instructions["c"]; + + auto neg = m.insert_instruction(std::next(c_ins), make_op("neg"), c_ins); + + auto args = ins->inputs(); + + m.replace_instruction(ins, make_op("add"), args.front(), neg); + } +}; + +struct find_rsqrt +{ + auto matcher() const + { + return match::name("recip")(match::args( + match::name("sqrt")(match::used_once(), match::args(match::any().bind("x"))))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + + m.replace_instruction(ins, make_op("rsqrt"), x_ins); + } +}; + +static bool same_ops(const std::vector& vec_ins) +{ + return std::all_of(vec_ins.begin(), vec_ins.end(), [&](auto i) { + return i->get_operator() == vec_ins.front()->get_operator(); + }); +} + +struct find_split_reshape +{ + auto matcher() const + { + return match::name("reshape")(match::arg(0)(match::name("contiguous")( + match::arg(0)(match::name("slice").bind("slice"))))) + .bind("reshape"); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto slc = r.instructions["slice"]; + auto rsp = r.instructions["reshape"]; + + auto input = slc->inputs().front(); + auto split_outputs = get_splits(input); + if(split_outputs.empty()) + { + return; + } + + std::vector vec_rsp(split_outputs.size()); + std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { + assert(i->outputs().size() == 1); + auto cont = i->outputs().front(); + assert(cont->outputs().size() == 1); + return cont->outputs().front(); + }); + + // all outputs are reshape and of the same shape + auto dims = any_cast(rsp->get_operator()).dims; + if(!same_ops(vec_rsp)) + { + return; + } + + // ensure reshape happens after the axis dimension + auto axis = any_cast(slc->get_operator()).axes[0]; + auto slc_lens = slc->get_shape().lens(); + auto slc_dim_size = std::accumulate( + slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies()); + + // search the reshape output (standard shape) to decide which axis are + // in its output corresponding to the slc_dim_size + auto rsp_lens = rsp->get_shape().lens(); + auto rsp_strides = rsp->get_shape().strides(); + rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]); + auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size); + if(ait == rsp_strides.end()) + { + return; + } + int rsp_axis = std::distance(rsp_strides.begin(), ait); + + // calculate reshape output shape + std::vector vec_dims(vec_rsp.size()); + std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { + return is->get_shape().lens()[rsp_axis]; + }); + + std::vector rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); + rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); + + // insert the reshape instruction + auto rsp_ins = m.insert_instruction( + std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); + + // replace the original reshape with slice + int64_t start = 0; + for(std::size_t i = 0; i < vec_rsp.size(); ++i) + { + m.replace_instruction( + vec_rsp[i], + make_op( + "slice", + {{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), + rsp_ins); + start += vec_dims[i]; + } + } +}; + +struct find_split_transpose +{ + auto matcher() const + { + return match::name("transpose")(match::arg(0)(match::name("slice").bind("slice"))) + .bind("trans"); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto slc = r.instructions["slice"]; + auto trans = r.instructions["trans"]; + + auto input = slc->inputs().front(); + auto split_outputs = get_splits(input); + if(split_outputs.empty()) + { + return; + } + + std::vector vec_trans(split_outputs.size()); + std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { + assert(i->outputs().size() == 1); + return i->outputs().front(); + }); + + // all transpose are the same + auto perm = any_cast(trans->get_operator()).dims; + if(!same_ops(vec_trans)) + { + return; + } + + // insert an transpose instruction + auto tr = m.insert_instruction( + std::next(input), make_op("transpose", {{"permutation", perm}}), input); + + // compute the axis in the slice + auto axis = any_cast(slc->get_operator()).axes.front(); + auto it = std::find(perm.begin(), perm.end(), axis); + assert(it != perm.end()); + int64_t axis_new = std::distance(perm.begin(), it); + + for(auto in : split_outputs) + { + auto oper = any_cast(in->get_operator()); + auto starts = oper.starts; + auto ends = oper.ends; + auto tr_orig = in->outputs().front(); + m.replace_instruction( + tr_orig, + make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}), + tr); + } } }; -void simplify_algebra::apply(program& p) const +void simplify_algebra::apply(module& m) const { // Run simplifications multiple times - for(int i = 0; i < 4; i++) + for(int i = 0; i < 8; i++) { - match::find_matches(p, + match::find_matches(m, find_inner_broadcast{}, find_double_add_lit_broadcast{}, find_add_lit_broadcast{}, find_add_convs{}, + find_conv_dot_horiz_fusion{}, find_mul_conv{}, - find_mul_add{}); - dead_code_elimination{}.apply(p); + find_mul_slice_conv{}, + find_mul_add{}, + find_div_const{}, + find_sub_const{}, + find_rsqrt{}, + find_concat_op{}, + find_split_concat{}, + find_splits{}, + find_split_reshape{}, + find_split_transpose{}); + dead_code_elimination{}.apply(m); } } diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a34adb47ef8bdcb5ebdfb76ddf5689a7a23d615a --- /dev/null +++ b/src/simplify_qdq.cpp @@ -0,0 +1,156 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + +MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) +{ + if(ins->name() != "@literal") + return false; + bool all_same = false; + ins->get_literal().visit([&](auto s) { + all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { + return float_equal(scale, s.front()); + }); + }); + return all_same; +} + +struct match_find_quantizable_ops +{ + + static auto dequantizelinear_op(const std::string& name, const std::string& scale) + { + return match::name("dequantizelinear")( + match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))), + match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0))))); + } + + auto matcher() const + { + return match::name(get_quantizable_op_names())( + match::arg(0)(dequantizelinear_op("x1", "scale1")), + match::arg(1)(dequantizelinear_op("x2", "scale2"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto qop = r.result; + auto q1 = r.instructions["x1"]; + auto q2 = r.instructions["x2"]; + auto scale1 = r.instructions["scale1"]; + auto scale2 = r.instructions["scale2"]; + + // Only INT8 type currently supported + if(q1->get_shape().type() != migraphx::shape::int8_type or + q2->get_shape().type() != migraphx::shape::int8_type) + return; + + double scale; + visit_all(scale1->get_literal(), scale2->get_literal())( + [&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); }); + + auto qop_args = qop->inputs(); + qop_args.at(0) = q1; + qop_args.at(1) = q2; + instruction_ref dq; + instruction_ref dq_scale; + instruction_ref zero_point; + if(qop->name() == "convolution") + { + auto conv_val = qop->get_operator().to_value(); + dq = m.insert_instruction( + qop, migraphx::make_op("quant_convolution", conv_val), qop_args); + } + else if(qop->name() == "dot") + { + dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); + } + auto ins_type = qop->get_shape().type(); + dq_scale = m.add_literal(literal({ins_type}, {scale})); + + auto lens = dq->get_shape().lens(); + auto scale_mb = + m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale); + dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb); + m.replace_instruction(qop, dq); + } +}; + +bool compare_literals(instruction_ref ins1, instruction_ref ins2) +{ + if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast") + ins1 = ins1->inputs().front(); + auto x = ins1->eval(); + if(x.empty()) + return false; + auto literal1 = ins1->get_literal(); + if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast") + ins2 = ins2->inputs().front(); + auto y = ins2->eval(); + if(y.empty()) + return false; + auto literal2 = ins2->get_literal(); + + bool diff_shapes_equal_vals = false; + visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) { + diff_shapes_equal_vals = + std::all_of( + l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and + std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); }); + }); + + return (x == y) or diff_shapes_equal_vals; +} + +void remove_qdq_pairs(module& m) +{ + for(auto ins : iterator_for(m)) + { + auto args = ins->inputs(); + for(auto&& arg : args) + { + if(arg->name() == "dequantizelinear") + { + auto q = arg->inputs().front(); + if((q->name() == "quantizelinear") and + compare_literals(arg->inputs().at(1), q->inputs().at(1)) and + compare_literals(arg->inputs().at(2), q->inputs().at(2))) + { + instruction::replace_argument(ins, arg, q->inputs().front()); + } + } + } + } +} + +void simplify_qdq::apply(module& m) const +{ + match::find_matches(m, match_find_quantizable_ops{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + remove_qdq_pairs(m); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 4e03ec4ed11a6df7b498cd0d3f8c573aa45f96ba..c4ad4260383fc8e6e5d8501b3ce12f8339a03a70 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1,14 +1,21 @@ +#include #include #include #include #include #include #include +#include #include #include #include #include +#include #include +#include +#include + +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -63,19 +70,19 @@ struct find_reshaper match::any_of[match::outputs()](match::name(reshaper_names()))); } - void apply(program& p, const match::matcher_result& mr) const + void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; std::vector reshapes{ins}; while(is_reshaper(reshapes.back())) { assert(!reshapes.back()->inputs().empty()); - assert(p.has_instruction(reshapes.back()->inputs().front())); + assert(m.has_instruction(reshapes.back()->inputs().front())); auto input = reshapes.back()->inputs().front(); reshapes.push_back(input); } - std::pair r{p.end(), p.end()}; + std::pair r{m.end(), m.end()}; for(auto start : iterator_for(reshapes)) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { @@ -89,7 +96,7 @@ struct find_reshaper } if(r.first != r.second) { - p.replace_instruction(r.first, r.second); + m.replace_instruction(r.first, r.second); } } }; @@ -102,6 +109,7 @@ struct find_nop_reshapes reshapes.insert("as_shape"); reshapes.insert("broadcast"); reshapes.insert("concat"); + reshapes.insert("convert"); reshapes.insert("multibroadcast"); reshapes.insert("pad"); reshapes.insert("slice"); @@ -109,10 +117,10 @@ struct find_nop_reshapes return match::name(reshapes)(match::same_shape(match::arg(0))); } - void apply(program& p, const match::matcher_result& mr) const + void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; - p.replace_instruction(ins, ins->inputs().front()); + m.replace_instruction(ins, ins->inputs().front()); } }; @@ -124,7 +132,7 @@ struct find_transpose match::skip_output(match::name("contiguous"))(match::name("transpose")))); } - void apply(program& p, const match::matcher_result& mr) const + void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto x = ins; @@ -141,12 +149,99 @@ struct find_transpose return; if(is_no_transpose(dims)) { - p.replace_instruction(ins, t->inputs().front()); + m.replace_instruction(ins, t->inputs().front()); } else { - p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front()); + m.replace_instruction( + ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front()); + } + } +}; + +struct find_nested_convert +{ + auto matcher() const { return match::name("convert")(match::arg(0)(match::name("convert"))); } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto x = ins->inputs().front(); + auto input = x->inputs().front(); + + if(ins->get_shape() != input->get_shape()) + return; + + m.replace_instruction(ins, input); + } +}; + +struct find_nested_slice +{ + auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); } + + using axes_map = std::map>; + + static axes_map get_axes(instruction_ref ins) + { + axes_map result; + auto op = any_cast(ins->get_operator()); + for(std::size_t i = 0; i < op.axes.size(); i++) + { + result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]); + } + return result; + } + + static axes_map merge(const axes_map& m1, const axes_map& m2) + { + axes_map result; + // Non overlapping + for(auto&& p : m1) + { + if(contains(m2, p.first)) + continue; + result[p.first] = p.second; + } + for(auto&& p : m2) + { + if(contains(m1, p.first)) + continue; + result[p.first] = p.second; + } + // Overlapping + for(auto&& p1 : m1) + { + if(not contains(m2, p1.first)) + continue; + auto&& v1 = p1.second; + auto&& v2 = m2.at(p1.first); + auto start = v1.first + v2.first; + auto end = start + (v2.second - v2.first); + result[p1.first] = std::make_pair(start, end); + } + return result; + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto slice = ins->inputs().front(); + auto input = slice->inputs().front(); + + auto a1 = get_axes(ins); + auto a2 = get_axes(slice); + + auto axes = merge(a2, a1); + + auto op = op::slice{}; + for(auto&& pp : axes) + { + op.axes.push_back(pp.first); + op.starts.push_back(pp.second.first); + op.ends.push_back(pp.second.second); } + m.replace_instruction(ins, op, input); } }; @@ -157,25 +252,41 @@ struct find_concat_transpose return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); } - void apply(program& p, const match::matcher_result& mr) const + void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto s = ins->inputs().front()->get_shape(); + auto ins = mr.result; + auto trans_inputs = ins->inputs(); + auto s = trans_inputs.front()->get_shape(); assert(s.transposed()); - auto op = any_cast(ins->get_operator()); - auto permutation = find_permutation(s); + auto op = any_cast(ins->get_operator()); + auto permutation = find_permutation(s); + + // permutation should be the same for all inputs + if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) { + return (find_permutation(in->get_shape()) == permutation); + })) + { + return; + } + + // axis could be a negative value + int64_t n_dim = static_cast(s.lens().size()); + op.axis = tune_axis(n_dim, op.axis, op.name()); + auto ipermutation = invert_permutation(permutation); op.axis = ipermutation[op.axis]; std::vector inputs; std::transform( ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { - return p.insert_instruction(ins, op::transpose{permutation}, i); + return m.insert_instruction( + ins, make_op("transpose", {{"permutation", permutation}}), i); }); - auto concat = p.insert_instruction(ins, op, inputs); - auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat); + auto concat = m.insert_instruction(ins, op, inputs); + auto t = m.insert_instruction( + ins, make_op("transpose", {{"permutation", ipermutation}}), concat); assert(ins->get_shape().lens() == t->get_shape().lens()); - p.replace_instruction(ins, t); + m.replace_instruction(ins, t); } }; @@ -192,7 +303,7 @@ struct find_nested_concat return op.axis; } - void apply(program& p, const match::matcher_result& mr) const + void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto axis = get_axis(ins); @@ -205,33 +316,286 @@ struct find_nested_concat else args.push_back(i); } - })(ins->inputs()); - p.replace_instruction(ins, ins->get_operator(), args); + m.replace_instruction(ins, ins->get_operator(), args); } }; -void simplify_reshapes::apply(program& p) const +struct find_resize { - for(int i = 0; i < 2; i++) + auto matcher() const + { + return match::name("gather")( + match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); + } + + void apply(module& m, const match::matcher_result& r) const { - auto end = std::prev(p.end()); - for(auto ins : iterator_for(p)) + auto ins = r.result; + auto ins_rsp = r.instructions["data"]; + auto ins_ind = r.instructions["ind"]; + + // resize input shape + if(ins_rsp->get_shape().lens().size() != 1) { - if(ins == end and ins->name() == "contiguous") - continue; - // Skip possible dead instructions - if(ins->outputs().empty() and ins != end) + return; + } + + // resize output shape + const auto& in_shape = ins_rsp->inputs().front()->get_shape(); + const auto& out_shape = ins->get_shape(); + // check if output shape is multiple of input shape + const auto& in_lens = in_shape.lens(); + const auto& out_lens = out_shape.lens(); + if(in_lens.size() != out_lens.size()) + { + return; + } + + // output shape must be multiple of input shape + std::vector is_multi(in_lens.size()); + std::transform( + in_lens.begin(), in_lens.end(), out_lens.begin(), is_multi.begin(), [](auto x, auto y) { + return (y % x == 0); + }); + if(not std::all_of(is_multi.begin(), is_multi.end(), [](auto b) { return b; })) + { + return; + } + + // output must be multiple of inputs + std::vector scales(in_lens.size()); + std::transform( + in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) { + return y / x; + }); + + // if ind is not constant, cannot optimize + std::vector vec_ind; + auto arg_ind = ins_ind->eval(); + if(arg_ind.empty()) + { + return; + } + arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); + if(not all_of(range(out_shape.elements()), [&](auto i) { + auto out_idx = out_shape.multi(i); + auto in_idx = out_idx; + std::transform(out_idx.begin(), + out_idx.end(), + scales.begin(), + in_idx.begin(), + [&](auto io, auto scale) { return io - (io % scale); }); + return vec_ind[i] == vec_ind[out_shape.index(in_idx)]; + })) + { + return; + } + + // wrap up shapes for multibroadcast + std::vector> dim_scales; + std::transform(in_lens.begin(), + in_lens.end(), + out_lens.begin(), + std::back_inserter(dim_scales), + [](auto x, auto y) { return std::make_pair(x, y / x); }); + + std::vector in_dims; + std::vector out_dims; + for(auto& isp : dim_scales) + { + in_dims.push_back(isp.first); + out_dims.push_back(isp.first * isp.second); + if(isp.first == 1 or isp.second == 1) + { continue; - match::find_matches(p, - ins, - find_nop_reshapes{}, - find_reshaper{}, - find_transpose{}, - find_concat_transpose{}, - find_nested_concat{}); + } + + out_dims.back() = isp.first; + in_dims.push_back(1); + out_dims.push_back(isp.second); + } + + auto in_rsp = ins_rsp->inputs().front(); + auto rsp_data = m.insert_instruction( + ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); + auto mb_rsp = m.insert_instruction( + ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); + auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); + std::vector rsp_dims(out_lens.begin(), out_lens.end()); + m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); + } +}; + +struct find_where_op +{ + auto matcher() const + { + return match::name("gather")( + match::args(match::name("reshape")(match::arg(0)(match::name("concat").bind("data"))), + match::is_constant().bind("ind"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto concat = r.instructions["data"]; + auto ins_ind = r.instructions["ind"]; + std::vector vec_ind; + auto arg_ind = ins_ind->eval(); + arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); + // ind has to be the same value + auto val = vec_ind.front(); + if(not std::all_of(vec_ind.begin(), vec_ind.end(), [&](auto v) { return (v == val); })) + { + return; + } + + // concat axis must be 0 + auto op = any_cast(concat->get_operator()); + if(op.axis != 0) + { + return; + } + + // check concat inputs, it has to be 2 and have the same shape + const auto& inputs = concat->inputs(); + if(inputs.size() != 2) + { + return; + } + if(inputs.at(0)->get_shape() != inputs.at(1)->get_shape()) + { + return; + } + if(inputs.at(0)->get_shape().lens() != ins_ind->get_shape().lens()) + { + return; + } + + if(val) + { + m.replace_instruction(ins, inputs.at(0)); + } + else + { + m.replace_instruction(ins, inputs.at(1)); } } +}; + +struct find_reshape_cont +{ + auto matcher() const + { + return match::pointwise( + match::nargs(2), + match::either_arg(0, 1)( + match::name("reshape")(match::args(match::name("contiguous").bind("cont"))) + .bind("rsp"), + match::any())); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto ins_cont = r.instructions["cont"]; + auto in_ins = r.instructions["rsp"]; + + auto cont_input = ins_cont->inputs().front(); + auto lens = cont_input->get_shape().lens(); + std::vector dims(lens.begin(), lens.end()); + + if(in_ins->get_shape() != ins->get_shape()) + { + return; + } + + if(not std::all_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) { + return i->get_shape().standard(); + })) + { + return; + } + + auto out_lens = ins->get_shape().lens(); + std::vector out_dims(out_lens.begin(), out_lens.end()); + std::vector inputs; + for(const auto& in : ins->inputs()) + { + if(in == in_ins) + { + inputs.push_back(cont_input); + } + else + { + inputs.push_back( + m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in)); + } + } + auto out = m.insert_instruction(ins, ins->get_operator(), inputs); + m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out); + } +}; + +// match sequence of transpose --> contiguous --> reshaper_op +auto match_transpose_contiguous_reshaper() +{ + return match::name({"reshape", "squeeze", "unsqueeze"})( + match::used_once(), + match::args( + match::name("contiguous")( + match::used_once(), match::args(match::transpose_shape().bind("trans_ins"))) + .bind("cont_ins"))) + .bind("reshaper_ins"); +}; + +// finds the pattern of transpose --> contiguous --> reshaper_op --> unary +// application of this matcher moves the unary operation before the contiguous so it becomes +// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out +// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth +// operator. +struct find_transpose_contiguous_reshaper_unary +{ + auto matcher() const + { + return pointwise(match::used_once(), + match::nargs(1), + match::args(match_transpose_contiguous_reshaper())); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto reshaper_ins = r.instructions["reshaper_ins"]; + auto trans_ins = r.instructions["trans_ins"]; + auto cont_ins = r.instructions["cont_ins"]; + auto unary_op_name = ins->get_operator().name(); + auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); + auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins); + // older cont and reshape are removed by deadcode elimination + m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); + } +}; + +void simplify_reshapes::apply(module& m) const +{ + for(int i = 0; i < 2; i++) + { + match::find_matches(m, + find_where_op{}, + find_resize{}, + find_reshape_cont{}, + find_nop_reshapes{}, + find_reshaper{}, + find_transpose{}, + find_concat_transpose{}, + find_nested_convert{}, + find_nested_slice{}, + find_nested_concat{}, + find_transpose_contiguous_reshaper_unary{}); + dead_code_elimination{}.apply(m); + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/cpu/CMakeLists.txt b/src/targets/cpu/CMakeLists.txt old mode 100644 new mode 100755 index 9ad580324e42991aad4f98493df1fcc6af02d98f..87570257e5ea9c8d5c7608af5644f4ef520b86a0 --- a/src/targets/cpu/CMakeLists.txt +++ b/src/targets/cpu/CMakeLists.txt @@ -1,19 +1,69 @@ +include(CheckCXXCompilerFlag) + add_library(migraphx_cpu - target.cpp - lowering.cpp + allocate.cpp + allocation_model.cpp + binary.cpp + concat.cpp + convolution.cpp + copy.cpp + deconvolution.cpp + dnnl.cpp + eltwise.cpp + erf.cpp + fuse_ops.cpp + gather.cpp gemm.cpp + layernorm.cpp + logsoftmax.cpp + lowering.cpp + lrn.cpp + preallocate.cpp + pooling.cpp + reduction.cpp + reorder.cpp + softmax.cpp + sub.cpp + target.cpp + write_literals.cpp ) set_target_properties(migraphx_cpu PROPERTIES EXPORT_NAME cpu) rocm_set_soversion(migraphx_cpu ${MIGRAPHX_SO_VERSION}) -find_path(BLAZE_INCLUDE blaze/Blaze.h) -find_package(Threads) +set(MIGRAPHX_ENABLE_ZENDNN Off CACHE BOOL "") + +if(MIGRAPHX_ENABLE_ZENDNN) + find_path(ZENDNN_INC_PATH zendnn.hpp) + find_library(ZENDNN_LIB amdZenDNN) + find_library(BLIS_LIB blis) +else() + find_package(dnnl REQUIRED) +endif() rocm_clang_tidy_check(migraphx_cpu) -target_link_libraries(migraphx_cpu migraphx Threads::Threads) -target_include_directories(migraphx_cpu PRIVATE ${BLAZE_INCLUDE}) -target_compile_definitions(migraphx_cpu PRIVATE -DBLAZE_USE_CPP_THREADS) +if(MIGRAPHX_ENABLE_ZENDNN) + target_compile_definitions(migraphx_cpu PRIVATE -DMIGRAPHX_ENABLE_ZENDNN) + target_include_directories(migraphx_cpu PRIVATE ${ZENDNN_INC_PATH}) + message(STATUS "ZENDNN_LIB: ${ZENDNN_LIB}") + target_link_libraries(migraphx_cpu PRIVATE ${BLIS_LIB}) + target_link_libraries(migraphx_cpu PRIVATE ${ZENDNN_LIB}) +else() + target_link_libraries(migraphx_cpu PRIVATE DNNL::dnnl) +endif() +target_link_libraries(migraphx_cpu PRIVATE migraphx) + +find_package(OpenMP) +target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX) +# Add library path to rpath to workaround issues with our broken packages +foreach(LIBRARY ${OpenMP_CXX_LIBRARIES}) + if(LIBRARY MATCHES "libomp") + get_filename_component(LIBRARY_PATH "${LIBRARY}" PATH) + target_link_libraries(migraphx_cpu PUBLIC -Wl,-rpath=${LIBRARY_PATH} -Wl,-rpath-link=${LIBRARY_PATH}) + endif() +endforeach() + +target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu) rocm_install_targets( TARGETS migraphx_cpu diff --git a/src/targets/cpu/allocate.cpp b/src/targets/cpu/allocate.cpp new file mode 100755 index 0000000000000000000000000000000000000000..90d74b8069cbad2cbfb7a4a1a724700e6ac995e6 --- /dev/null +++ b/src/targets/cpu/allocate.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_allocate : auto_register_op +{ + shape s; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.s, "shape")); + } + + std::string name() const { return "cpu::allocate"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(0); + return s; + } + argument compute(context&, const shape& output_shape, const std::vector&) const + { + argument result{output_shape}; + return result; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/allocation_model.cpp b/src/targets/cpu/allocation_model.cpp new file mode 100755 index 0000000000000000000000000000000000000000..8ae3441695ab04068fda97c800027b0cf3e87f67 --- /dev/null +++ b/src/targets/cpu/allocation_model.cpp @@ -0,0 +1,23 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +std::string cpu_allocation_model::name() const { return "cpu::allocate"; } +operation cpu_allocation_model::allocate(const shape& s) const +{ + return make_op(name(), {{"shape", to_value(s)}}); +} + +operation cpu_allocation_model::preallocate(const shape& s, const std::string& id) const +{ + return make_op("cpu::preallocate", {{"shape", to_value(s)}, {"id", id}}); +} + +std::string cpu_allocation_model::copy() const { return "cpu::copy"; } + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/binary.cpp b/src/targets/cpu/binary.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abefd31563f957a9724be2d09009b18fa8a8813a --- /dev/null +++ b/src/targets/cpu/binary.cpp @@ -0,0 +1,49 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_binary : dnnl_op +{ + std::string algo; + template + static auto reflect(Self& self, F f) + { + return pack_join(self.reflect_base(self, f), pack(f(self.algo, "algo"))); + } + + std::string group() const { return this->name() + "::" + algo; } + + std::string name() const { return "dnnl::binary"; } + + shape compute_shape(std::vector inputs) const + { + // Compensate for allocation + inputs.pop_back(); + check_shapes{this->trim_post_op_inputs(inputs), *this}.has(2); + auto s0 = inputs.at(0); + auto s1 = inputs.at(1); + auto r = s0; + if(s0 != s1 or !s0.packed()) + { + r = shape{s0.type(), s0.lens()}; + } + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(r, inputs)); + return r; + } + + dnnl::binary::desc get_desc(const std::unordered_map& m) const + { + return {to_dnnl_algo(algo), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_1)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/concat.cpp b/src/targets/cpu/concat.cpp new file mode 100755 index 0000000000000000000000000000000000000000..aa70963704fb670844322ac84b3d76d063e9c8de --- /dev/null +++ b/src/targets/cpu/concat.cpp @@ -0,0 +1,44 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_concat : dnnl_extend_op +{ + std::vector arg_map(int size) const + { + std::vector result(size); + std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC)); + return result; + } + // Custom desc class since its missing in dnnl + struct desc + { + dnnl::memory::desc dst; + std::size_t axis = 1; + std::vector srcs; + }; + desc get_desc(const std::unordered_map& m) const + { + std::vector srcs; + srcs.reserve(m.size() - 1); + + for(auto i = 0; i < m.size() - 1; i++) + { + srcs.push_back(m.at(MIGRAPHX_DNNL_PREFIX(ARG_MULTIPLE_SRC) + i)); + } + return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), std::size_t(op.axis), srcs}; + } + + auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const + { + return dnnl::concat::primitive_desc(d.dst, d.axis, d.srcs, get_dnnl_context().engine, attr); + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/convolution.cpp b/src/targets/cpu/convolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a0bc1d8b5753e20edda0a65c8874843d2251e1c --- /dev/null +++ b/src/targets/cpu/convolution.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_convolution + : dnnl_extend_op +{ + std::vector arg_map(int) const + { + return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)}; + } + + shape adjust_shape(const shape& x, int i) const + { + auto s = base_adjust_shape(x); + if(i == 1 and op.group > 1) + { + // TODO: Add support for transposed weights + if(not s.standard()) + MIGRAPHX_THROW("Weights for grouped convolution must be standard"); + auto lens = s.lens(); + lens.insert(lens.begin(), op.group); + lens.at(1) /= op.group; + return shape{s.type(), lens}; + } + return s; + } + + dnnl::convolution_forward::desc + get_desc(const std::unordered_map& m) const + { + // In DNNL dilation is zero-based + auto dilation = op.dilation; + std::transform( + dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; }); + auto kdims = op.kdims(); + std::vector padding_l(op.padding.begin(), op.padding.begin() + kdims); + std::vector padding_r(op.padding.begin() + kdims, op.padding.end()); + return {dnnl::prop_kind::forward_inference, + dnnl::algorithm::convolution_auto, + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), + to_dnnl_dims(op.stride), + to_dnnl_dims(dilation), + to_dnnl_dims(padding_l), + to_dnnl_dims(padding_r)}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/copy.cpp b/src/targets/cpu/copy.cpp new file mode 100755 index 0000000000000000000000000000000000000000..bcb1aaeb62a58401c9407d7d0981564ba34a4c68 --- /dev/null +++ b/src/targets/cpu/copy.cpp @@ -0,0 +1,42 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_copy : reduce_dims_base, auto_register_op +{ + template + static auto reflect(Self&, F) + { + return pack(); + } + + std::string name() const { return "cpu::copy"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(2); + return inputs.at(1); + } + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const + { + argument result = get_arg(args, args.size() - 1); + + visit_all(result, get_arg(args, 0))([&](auto output, auto input) { + pointwise(output, input)(ctx, output.get_shape(), 1024, [](auto& y, auto x) { y = x; }); + }); + + return result.reshape(output_shape); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/deconvolution.cpp b/src/targets/cpu/deconvolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fda440943292353a0b70c3975bce57c80dc28ce --- /dev/null +++ b/src/targets/cpu/deconvolution.cpp @@ -0,0 +1,53 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_deconvolution + : dnnl_extend_op +{ + std::vector arg_map(int) const + { + return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)}; + } + + shape adjust_shape(const shape& x, int i) const + { + auto s = base_adjust_shape(x); + if(i == 1) + { + // The input and output channels are flipped for dnnl + auto lens = s.lens(); + std::swap(lens[0], lens[1]); + auto strides = s.strides(); + std::swap(strides[0], strides[1]); + return {s.type(), lens, strides}; + } + return s; + } + + dnnl::deconvolution_forward::desc + get_desc(const std::unordered_map& m) const + { + // In DNNL dilation is zero-based + auto dilation = op.dilation; + std::transform( + dilation.begin(), dilation.end(), dilation.begin(), [](auto x) { return x - 1; }); + return {dnnl::prop_kind::forward_inference, + dnnl::algorithm::deconvolution_direct, + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), + to_dnnl_dims(op.stride), + to_dnnl_dims(dilation), + to_dnnl_dims(op.padding), + to_dnnl_dims(op.padding)}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/dnnl.cpp b/src/targets/cpu/dnnl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..995c920db143578cd922b77efaa8a008e52e6480 --- /dev/null +++ b/src/targets/cpu/dnnl.cpp @@ -0,0 +1,181 @@ +#include + +#if defined(__GNUC__) && __GNUC__ <= 5 +namespace std { +#ifdef MIGRAPHX_ENABLE_ZENDNN +namespace dnnl = zendnn; +#endif +template <> +struct hash +{ + using argument_type = dnnl::algorithm; + using result_type = std::size_t; + result_type operator()(const argument_type& x) const noexcept + { + return std::hash>{}( + static_cast>(x)); + } +}; + +} // namespace std +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +dnnl_context& get_dnnl_context() +{ + static dnnl_context ctx{}; // NOLINT + return ctx; +} + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wswitch-enum" +#endif +dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) +{ + using dt = dnnl::memory::data_type; + using st = shape::type_t; + switch(t) + { + case st::half_type: return dt::f16; + case st::float_type: return dt::f32; + case st::int32_type: return dt::s32; + case st::int8_type: return dt::s8; + case st::uint8_type: return dt::u8; + default: MIGRAPHX_THROW("Unsupported data type"); + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +dnnl::memory::format_tag to_dnnl_memory_format_tag(std::size_t n) +{ + switch(n) + { + case 1: return dnnl::memory::format_tag::a; + case 2: return dnnl::memory::format_tag::ab; + case 3: return dnnl::memory::format_tag::abc; + case 4: return dnnl::memory::format_tag::abcd; + case 5: return dnnl::memory::format_tag::abcde; + case 6: return dnnl::memory::format_tag::abcdef; + default: MIGRAPHX_THROW("Unsupported tensor size: " + std::to_string(n)); + } +} + +dnnl::memory::desc to_dnnl_memory_desc(const shape& s) +{ + return {to_dnnl_dims(s.lens()), to_dnnl_memory_data_type(s.type()), to_dnnl_dims(s.strides())}; +} + +dnnl::memory to_dnnl_memory(const dnnl::memory::desc& desc, const argument& a) +{ + return {desc, get_dnnl_context().engine, a.data()}; +} + +dnnl::memory to_dnnl_memory(const argument& a) +{ + return to_dnnl_memory(to_dnnl_memory_desc(a.get_shape()), a); +} + +// clang-format off +#define MIGRAPHX_VISIT_DNNL_ALGO(m) \ + m(undef) \ + m(convolution_auto) \ + m(convolution_direct) \ + m(convolution_winograd) \ + m(deconvolution_direct) \ + m(deconvolution_winograd) \ + m(eltwise_relu) \ + m(eltwise_tanh) \ + m(eltwise_elu) \ + m(eltwise_square) \ + m(eltwise_abs) \ + m(eltwise_sqrt) \ + m(eltwise_swish) \ + m(eltwise_linear) \ + m(eltwise_bounded_relu) \ + m(eltwise_soft_relu) \ + m(eltwise_logistic) \ + m(eltwise_exp) \ + m(eltwise_gelu) \ + m(eltwise_gelu_tanh) \ + m(eltwise_gelu_erf) \ + m(eltwise_log) \ + m(eltwise_clip) \ + m(eltwise_pow) \ + m(eltwise_round) \ + m(eltwise_relu_use_dst_for_bwd) \ + m(eltwise_tanh_use_dst_for_bwd) \ + m(eltwise_elu_use_dst_for_bwd) \ + m(eltwise_sqrt_use_dst_for_bwd) \ + m(eltwise_logistic_use_dst_for_bwd) \ + m(eltwise_exp_use_dst_for_bwd) \ + m(lrn_across_channels) \ + m(lrn_within_channel) \ + m(pooling_max) \ + m(pooling_avg) \ + m(pooling_avg_include_padding) \ + m(pooling_avg_exclude_padding) \ + m(vanilla_rnn) \ + m(vanilla_lstm) \ + m(vanilla_gru) \ + m(lbr_gru) \ + m(binary_add) \ + m(binary_mul) \ + m(binary_max) \ + m(binary_min) \ + m(binary_div) \ + m(resampling_nearest) \ + m(resampling_linear) \ + m(reduction_max) \ + m(reduction_min) \ + m(reduction_sum) \ + m(reduction_mul) \ + m(reduction_mean) \ + m(reduction_norm_lp_max) \ + m(reduction_norm_lp_sum) \ + m(reduction_norm_lp_power_p_max) \ + m(reduction_norm_lp_power_p_sum) +// clang-format on + +const std::unordered_map& dnnl_algo_map() +{ + static const std::unordered_map m = { +#define MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR(x) {#x, dnnl::algorithm::x}, + MIGRAPHX_VISIT_DNNL_ALGO(MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR) +#undef MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR + }; + return m; +} + +dnnl::algorithm to_dnnl_algo(const std::string& name) +{ + if(dnnl_algo_map().count(name) == 0) + MIGRAPHX_THROW("Missing dnnl algo: " + name); + return dnnl_algo_map().at(name); +} + +const std::unordered_map& dnnl_algo_string_map() +{ + static const std::unordered_map m = { +#define MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR(x) {dnnl::algorithm::x, #x}, + MIGRAPHX_VISIT_DNNL_ALGO(MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR) +#undef MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR + }; + return m; +} + +std::string to_string(const dnnl::algorithm& algo) +{ + if(dnnl_algo_string_map().count(algo) == 0) + return "unknown_" + std::to_string(static_cast(algo)); + return dnnl_algo_string_map().at(algo); +} + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/eltwise.cpp b/src/targets/cpu/eltwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48af8335cad20e57de3527b81d2b8a8cbd389216 --- /dev/null +++ b/src/targets/cpu/eltwise.cpp @@ -0,0 +1,50 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_eltwise : dnnl_op +{ + std::string algo; + float alpha = 0; + float beta = 0; + template + static auto reflect(Self& self, F f) + { + return pack_join(self.reflect_base(self, f), + pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta"))); + } + + std::string group() const { return this->name() + "::" + algo; } + + std::string name() const { return "dnnl::eltwise"; } + + shape compute_shape(std::vector inputs) const + { + // Compensate for allocation + inputs.pop_back(); + check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).packed(); + auto s = inputs.at(0); + auto r = s; + if(not s.packed()) + r = shape{s.type(), s.lens()}; + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(r, inputs)); + return r; + } + + dnnl::eltwise_forward::desc get_desc(const std::unordered_map& m) const + { + return {dnnl::prop_kind::forward_inference, + to_dnnl_algo(algo), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), + alpha, + beta}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/erf.cpp b/src/targets/cpu/erf.cpp new file mode 100755 index 0000000000000000000000000000000000000000..5216bba4c649e8f22015322584e6c1538f801d54 --- /dev/null +++ b/src/targets/cpu/erf.cpp @@ -0,0 +1,13 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +template struct cpu_unary; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/fuse_ops.cpp b/src/targets/cpu/fuse_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a54c5153fc539715cdf83a00f2f0f10abd25e4cb --- /dev/null +++ b/src/targets/cpu/fuse_ops.cpp @@ -0,0 +1,110 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND); + +MIGRAPHX_PRED_MATCHER(has_post_ops, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + return v.contains("post_ops"); +} + +MIGRAPHX_PRED_MATCHER(without_post_ops, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + return v.contains("post_ops") and v["post_ops"].empty(); +} + +bool workaround_dnnl_broken_post_ops(const operation& op, const operation& post_op) +{ + if(contains({"dnnl::dot", "dnnl::convolution"}, op.name())) + return true; + auto pv = post_op.to_value(); + if(not pv.at("post_ops").empty()) + return true; + auto v = op.to_value(); + auto last_op = v.at("post_ops").empty() ? v : v.at("post_ops").back(); + auto algo = last_op.contains("algo") ? last_op.at("algo").to() : op.name(); + auto post_algo = pv["algo"].to(); + if(starts_with(algo, "eltwise") and starts_with(post_algo, "eltwise")) + return true; + if(algo == post_algo) + return true; + return false; +} + +operation merge_post_ops(const operation& op, const operation& post_op) +{ + auto pv = post_op.to_value(); + auto v = op.to_value(); + v["post_ops"].push_back({{"algo", pv["algo"]}, + {"alpha", pv["alpha"].value_or(0.0f)}, + {"beta", pv["beta"].value_or(0.0f)}}); + auto post_ops = pv.at("post_ops"); + for(const auto& po : post_ops) + v["post_ops"].push_back(po); + return make_op(op.name(), v); +} + +struct find_post_ops +{ + context* ctx = nullptr; + match::any_matcher matcher() const + { + if(enabled(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND{})) + return match::name("dnnl::eltwise", + "dnnl::binary")(match::arg(0)(has_post_ops(), match::used_once())); + else + return match::name("dnnl::eltwise")( + without_post_ops(), + match::arg(0)(match::name("dnnl::binary")(without_post_ops(), match::used_once()))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = ins->inputs().front(); + auto x = x_ins->get_operator(); + + if(workaround_dnnl_broken_post_ops(x, ins->get_operator())) + return; + + auto op = merge_post_ops(x, ins->get_operator()); + auto inputs = x_ins->inputs(); + inputs.back() = ins->inputs().back(); + if(ins->name() == "dnnl::binary") + inputs.insert(std::prev(inputs.end()), ins->inputs().at(1)); + auto input_shapes = to_shapes(inputs); + auto new_shape = try_compute_shape(op, input_shapes); + if(new_shape.empty() or new_shape.front() != ins->get_shape()) + return; + auto info = compile(op, *ctx, new_shape.front(), input_shapes); + if(info.contains("impl") and starts_with(info.at("impl").to(), "ref:")) + return; + m.replace_instruction(ins, op, inputs); + } +}; + +void fuse_ops::apply(module& m) const +{ + for(std::size_t i = 0; i < 4; i++) + { + match::find_matches(m, find_post_ops{ctx}); + dead_code_elimination{}.apply(m); + } +} + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/gather.cpp b/src/targets/cpu/gather.cpp new file mode 100755 index 0000000000000000000000000000000000000000..94f8f85f97f066b36b98c5ab32370d5b4ae339f2 --- /dev/null +++ b/src/targets/cpu/gather.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_gather : auto_register_op +{ + op::gather op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + std::string name() const { return "cpu::" + op.name(); } + shape compute_shape(std::vector inputs) const + { + // Compensate for allocation + inputs.pop_back(); + check_shapes(inputs, *this).standard(); + return migraphx::compute_shape(op, inputs); + } + + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const + { + std::size_t nelements = output_shape.elements(); + auto lens = args[0].get_shape().lens(); + auto axis_dim_size = lens[op.axis]; + lens[op.axis] = args[1].get_shape().elements(); + shape out_comp{output_shape.type(), lens}; + + visit_all(args.back(), args[0])([&](auto output, auto input) { + args[1].visit([&](auto indices) { + const auto* indices_ptr = indices.data(); + auto* output_ptr = output.data(); + ctx.bulk_execute(nelements, 1024, [=](auto start, auto end) { + for(auto i = start; i < end; i++) + { + auto idx = out_comp.multi(i); + auto in_index = indices_ptr[idx[op.axis]]; + in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; + idx[op.axis] = in_index; + output_ptr[i] = input(idx.begin(), idx.end()); + } + }); + }); + }); + + return args.back(); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/gemm.cpp b/src/targets/cpu/gemm.cpp index 82f1aa42227557bf507591ec702dfd089bcb5947..97475ffb894502716b0b63890653d41986b1b353 100644 --- a/src/targets/cpu/gemm.cpp +++ b/src/targets/cpu/gemm.cpp @@ -1,131 +1,34 @@ -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace cpu { -template -using matrix = blaze::CustomMatrix; // NOLINT - -template -static auto make_mat(tensor_view x) -{ - const auto& s = x.get_shape(); - // assert(s.lens().size() == 2); - std::size_t n_dims = s.lens().size(); - std::size_t dim_0 = n_dims - 2; - std::size_t dim_1 = n_dims - 1; - if(s.transposed()) - return matrix{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; - return matrix{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; -} - -template -static void visit_mat(tensor_view x, F f) -{ - auto mat = make_mat(x); - if(x.get_shape().transposed()) - f(blaze::trans(mat)); - else - f(mat); -} - -template -struct is_fast_gemm_type : std::false_type -{ -}; - -template <> -struct is_fast_gemm_type : std::true_type -{ -}; - -template -void migemm_impl( - tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::true_type) -{ - visit_mat(amat, [&](const auto& a) { - visit_mat(bmat, [&](const auto& b) { - auto c = make_mat(cmat); - c = beta * c; - // This is a simple optimization to avoid - // compute A * B if alpha is 0.0 - if(alpha != 0.0) - { - c = c + alpha * a * b; - } - }); - }); -} - -template -void migemm_impl( - tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::false_type) +struct dnnl_gemm : dnnl_extend_op { - std::size_t n_dims = cmat.get_shape().lens().size(); - std::size_t dim_0 = n_dims - 2; - std::size_t dim_1 = n_dims - 1; - auto k = amat.get_shape().lens()[dim_1]; - - assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); - assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); - assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); - - shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { - auto a_idx = c_idx; - auto b_idx = c_idx; - double s = 0.0; - dfor(k)([&](auto kk) { - a_idx[dim_1] = b_idx[dim_0] = kk; - s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); - }); - cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; - }); -} - -template -void migemm_impl(tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta) -{ - auto lens = amat.get_shape().lens(); - bool batch_mul = - std::accumulate( - lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies()) == 1; - if(batch_mul) + std::vector arg_map(int) const { - migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type{}); + return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), + MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS), + MIGRAPHX_DNNL_PREFIX(ARG_BIAS)}; } - else - { - migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{}); - } -} -template -void migemm_tpl( - const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta) -{ - visit_all(c_arg, a_arg, b_arg)( - [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); }); -} + void required(const check_shapes& cs) const { cs.not_broadcasted(); } -void migemm( - const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) -{ - migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); -} - -void migemm(const argument& c_arg, - const argument& a_arg, - const argument& b_arg, - int32_t alpha, - int32_t beta) -{ - migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); -} + dnnl::matmul::desc get_desc(const std::unordered_map& m) const + { + return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))}; + } +}; } // namespace cpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/cpu/include/migraphx/cpu/allocation_model.hpp b/src/targets/cpu/include/migraphx/cpu/allocation_model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4a302d002e958b602165ccddd466d9732bc77e73 --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/allocation_model.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_ALLOCATION_MODEL_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_ALLOCATION_MODEL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_allocation_model +{ + std::string name() const; + std::string copy() const; + operation allocate(const shape& s) const; + operation preallocate(const shape& s, const std::string& id) const; + bool needs_out_params() const { return false; } +}; + +} // namespace cpu + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/include/migraphx/cpu/context.hpp b/src/targets/cpu/include/migraphx/cpu/context.hpp old mode 100644 new mode 100755 index c2a83db08e74dc0b7fca04745d48ff6fa5703cb1..2df1b62e98ddcc94c0e51dfb3bec86873d329336 --- a/src/targets/cpu/include/migraphx/cpu/context.hpp +++ b/src/targets/cpu/include/migraphx/cpu/context.hpp @@ -2,6 +2,9 @@ #define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP #include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -10,6 +13,18 @@ namespace cpu { struct context { void finish() const {} + + template + void bulk_execute(std::size_t n, std::size_t min_grain, F f) + { + cpu::parallel_for(n, min_grain, f); + } + + template + void bulk_execute(std::size_t n, F f) + { + this->bulk_execute(n, 256, f); + } }; } // namespace cpu diff --git a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b14770741656a27abc93277a906da482d7b6bdb3 --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp @@ -0,0 +1,396 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_DNNL_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_DNNL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef MIGRAPHX_ENABLE_ZENDNN +#include +#else +#include +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +#ifdef MIGRAPHX_ENABLE_ZENDNN +namespace dnnl = zendnn; +#define MIGRAPHX_CONCAT_PREFIX(b) ZENDNN_##b // NOLINT +#else +#define MIGRAPHX_CONCAT_PREFIX(b) DNNL_##b // NOLINT +#endif +#define MIGRAPHX_DNNL_PREFIX(b) MIGRAPHX_CONCAT_PREFIX(b) // NOLINT + +struct dnnl_context +{ + dnnl::engine engine; + dnnl::stream stream; + dnnl_context() : engine(dnnl::engine::kind::cpu, 0), stream(engine) {} +}; + +dnnl_context& get_dnnl_context(); + +dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t); + +dnnl::memory::format_tag to_dnnl_memory_format_tag(std::size_t n); + +template +inline dnnl::memory::dims to_dnnl_dims(R&& r) +{ + return {r.begin(), r.end()}; +} + +dnnl::memory::desc to_dnnl_memory_desc(const shape& s); + +dnnl::memory to_dnnl_memory(const dnnl::memory::desc& desc, const argument& a); + +dnnl::memory to_dnnl_memory(const argument& a); + +dnnl::algorithm to_dnnl_algo(const std::string& name); + +std::string to_string(const dnnl::algorithm& algo); + +struct post_op : reflect_equality, reflect_stream +{ + std::string algo; + float alpha = 0; + float beta = 0; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta")); + } +}; + +template +struct dnnl_op : auto_register_op +{ + std::vector post_ops; + std::function& args)> execute; + + template + static auto reflect_base(Self& self, F f) + { + return pack(f(self.post_ops, "post_ops")); + } + + template + static auto reflect(Self& self, F f) + { + return reflect_base(self, f); + } + + std::string group() const + { + const auto& self = static_cast(*this); + return self.name(); + } + + value attributes() const + { + std::vector names; + std::transform(post_ops.begin(), post_ops.end(), std::back_inserter(names), [](auto&& op) { + return op.algo; + }); + const auto& self = static_cast(*this); + auto g = self.group(); + if(not names.empty()) + g += "<" + join_strings(names, ",") + ">"; + return {{"group", g}}; + } + + std::size_t get_extra_post_op_args() const + { + return std::count_if(post_ops.begin(), post_ops.end(), [](const auto& po) { + return contains(po.algo, "binary"); + }); + } + + static std::size_t get_binary_post_op_arg(std::size_t pos) + { + return MIGRAPHX_DNNL_PREFIX(ARG_ATTR_MULTIPLE_POST_OP)(pos) | // NOLINT + MIGRAPHX_DNNL_PREFIX(ARG_SRC_1); // NOLINT + } + + static std::vector to_shapes(const std::vector& args) + { + std::vector shapes(args.size()); + std::transform(args.begin(), args.end(), shapes.begin(), [](const argument& a) { + return a.get_shape(); + }); + return shapes; + } + static std::string impl(const Primitive& prim) + { + auto desc = prim.get_primitive_desc(); + const char* str = nullptr; +#ifdef MIGRAPHX_ENABLE_ZENDNN + zendnn_primitive_desc_query(desc, zendnn_query_impl_info_str, 0, &str); +#else + dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str); +#endif + return str == nullptr ? "" : str; + } + // Map arg index to arg in dnnl + std::vector arg_map(int size) const + { + std::vector result(size); + std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)); + return result; + } + shape base_adjust_shape(const shape& s) const + { + if(s.broadcasted()) + { + auto lens = s.lens(); + auto strides = s.strides(); + std::transform(strides.begin(), + strides.end(), + lens.begin(), + lens.begin(), + [](auto stride, auto len) -> std::size_t { + if(stride == 0) + return 1; + else + return len; + }); + return shape{s.type(), lens}; + } + return s; + } + template + void for_each_post_op(F f) const + { + int i = 0; + for(auto&& op : post_ops) + { + if(contains(op.algo, "binary")) + { + f(op, get_binary_post_op_arg(i)); + } + else + { + f(op, -1); + } + i++; + } + } + shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); } + std::vector create_arg_map(std::size_t input_size) const + { + const auto& self = static_cast(*this); + auto npost_ops = get_extra_post_op_args(); + auto prim_input_size = input_size - npost_ops; + auto m = self.arg_map(prim_input_size); + for_each_post_op([&](auto&&, auto arg) { + if(arg < 0) + return; + m.push_back(arg); + }); + return m; + } + std::unordered_map + to_memory_desc(const shape& output_shape, const std::vector& inputs) const + { + const auto& self = static_cast(*this); + std::unordered_map result; + result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] = + to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size())); + auto m = create_arg_map(inputs.size()); + assert(m.size() >= inputs.size()); + for(int i = 0; i < inputs.size(); i++) + { + result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i)); + } + return result; + } + dnnl::primitive_attr + get_primitive_attr(const std::unordered_map& m) const + { + dnnl::primitive_attr result; + dnnl::post_ops po; + for_each_post_op([&](auto&& op, auto arg) { + if(contains(op.algo, "binary_add")) + { + auto desc = m.at(arg); + if(desc == m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))) + po.append_sum(1.0f); + else + po.append_binary(to_dnnl_algo(op.algo), m.at(arg)); + } + else if(contains(op.algo, "binary")) + { + po.append_binary(to_dnnl_algo(op.algo), m.at(arg)); + } + else if(contains(op.algo, "eltwise")) + po.append_eltwise(1.0f, to_dnnl_algo(op.algo), op.alpha, op.beta); + else + MIGRAPHX_THROW("Unknown post op algo: " + op.algo); + }); + result.set_post_ops(po); + return result; + } + template + auto get_primitive_desc(const T& desc, const dnnl::primitive_attr& attr) const + -> decltype(typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine)) + { + return typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine); + } + Primitive get_primitive(const std::unordered_map& m) const + { + const auto& self = static_cast(*this); + auto desc = self.get_desc(m); + auto attr = MIGRAPHX_ASSERT_NO_THROW(this->get_primitive_attr(m)); + auto pd = self.get_primitive_desc(desc, attr); + return Primitive(pd); + } + argument compute(context& ctx, const shape&, const std::vector& args) const + { + return execute(ctx, args); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + value compile(context&, const shape& output_shape, std::vector inputs) + { + // Compensate for allocation + inputs.pop_back(); + auto md = to_memory_desc(output_shape, inputs); + auto prim = get_primitive(md); + auto impl_name = impl(prim); + return {{"impl", impl_name}}; + } + + void finalize(context&, const shape& output_shape, std::vector inputs) + { + // Compensate for allocation + inputs.pop_back(); + const auto& self = static_cast(*this); + auto name = self.name(); + auto md = to_memory_desc(output_shape, inputs); + auto prim = get_primitive(md); + auto arg_lookup = create_arg_map(inputs.size()); +#ifndef NDEBUG + auto prim_attr = get_primitive_attr(md); +#endif + execute = [=](context&, const std::vector& args) { +#ifndef NDEBUG + // Check that the memory descriptors have not changed + auto debug_args = args; + debug_args.pop_back(); + auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args)); + for(auto&& p : debug_md) + { + if(md.count(p.first) == 0) + MIGRAPHX_THROW(name + + ": Missing memory descriptor for: " + std::to_string(p.first)); + if(p.second == md.at(p.first)) + continue; + MIGRAPHX_THROW(name + + ": Memory descriptor has changed for: " + std::to_string(p.first)); + } + // Check post_ops args are correct + auto pos = prim_attr.get_post_ops(); + auto prim_input_size = inputs.size() - this->get_extra_post_op_args(); + int j = 0; + for(int i = 0; i < pos.len(); i++) + { + auto arg = j + prim_input_size; + auto kind = pos.kind(i); + std::string mesg = + "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": "; + try + { + dnnl::algorithm algo; + dnnl::memory::desc mdesc; + float scale = 0; + float alpha = 0; + float beta = 0; + if(kind == dnnl::primitive::kind::binary) + { + pos.get_params_binary(i, algo, mdesc); + if(mdesc != md.at(arg_lookup.at(arg))) + MIGRAPHX_THROW(mesg + + "Memory descriptor doesn't match for binary post op"); + j++; + } + else if(kind == dnnl::primitive::kind::eltwise) + { + pos.get_params_eltwise(i, scale, algo, alpha, beta); + } + else if(kind == dnnl::primitive::kind::sum) + { + pos.get_params_sum(i, scale); + algo = dnnl::algorithm::binary_add; + } + else + { + MIGRAPHX_THROW("Unknown kind"); + } + if(to_dnnl_algo(post_ops[i].algo) != algo) + MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " + + post_ops[i].algo + " != " + to_string(algo)); + } + catch(const dnnl::error& e) + { + MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what()); + } + } +#endif + std::unordered_map m; + m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] = + to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back()); + for(int i = 0; i < args.size() - 1; i++) + m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]); + prim.execute(get_dnnl_context().stream, m); + return args.back(); + }; + } + std::vector trim_post_op_inputs(const std::vector& inputs) const + { + auto prim_input_size = inputs.size() - this->get_extra_post_op_args(); + return {inputs.begin(), inputs.begin() + prim_input_size}; + } +}; + +template +struct dnnl_extend_op : dnnl_op +{ + Op op; + + template + static auto reflect(Self& self, F f) + { + return pack_join(self.reflect_base(self, f), migraphx::reflect(self.op, f)); + } + + // dnnl has some issues with non-packed inputs + void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); } + + std::string name() const { return "dnnl::" + op.name(); } + shape compute_shape(std::vector inputs) const + { + const auto& self = static_cast(*this); + // Compensate for allocation + inputs.pop_back(); + self.required(check_shapes(inputs, self)); + auto r = migraphx::compute_shape(op, this->trim_post_op_inputs(inputs)); + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(r, inputs)); + return r; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp b/src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp new file mode 100755 index 0000000000000000000000000000000000000000..80c5e0bec4960764ec0eda64b27fd664fe4fd54d --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/fuse_ops.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP +#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace cpu { + +struct context; + +struct fuse_ops +{ + context* ctx = nullptr; + std::string name() const { return "cpu::fuse_ops"; } + void apply(module& m) const; +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP diff --git a/src/targets/cpu/include/migraphx/cpu/lowering.hpp b/src/targets/cpu/include/migraphx/cpu/lowering.hpp index b02535f17168ff258e0d3223023c41eb93e398d6..fbccd4c15b9a1c8c8fbe7e511594365964439e3f 100644 --- a/src/targets/cpu/include/migraphx/cpu/lowering.hpp +++ b/src/targets/cpu/include/migraphx/cpu/lowering.hpp @@ -1,17 +1,20 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP -#include +#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { + +struct module; + namespace cpu { struct lowering { std::string name() const { return "cpu::lowering"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace cpu diff --git a/src/targets/cpu/include/migraphx/cpu/parallel.hpp b/src/targets/cpu/include/migraphx/cpu/parallel.hpp new file mode 100755 index 0000000000000000000000000000000000000000..2cf3a89ca5d6b320130d1825631754554cd098ca --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/parallel.hpp @@ -0,0 +1,98 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP + +// #define MIGRAPHX_DISABLE_OMP + +#include +#ifdef MIGRAPHX_DISABLE_OMP +#include +#else + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif +#include +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +#ifdef MIGRAPHX_DISABLE_OMP + +inline std::size_t max_threads() { return std::thread::hardware_concurrency(); } + +template +void parallel_for_impl(std::size_t n, std::size_t threadsize, F f) +{ + if(threadsize <= 1) + { + f(std::size_t{0}, n); + } + else + { + std::vector threads(threadsize); +// Using const here causes gcc 5 to ICE +#if(!defined(__GNUC__) || __GNUC__ != 5) + const +#endif + std::size_t grainsize = std::ceil(static_cast(n) / threads.size()); + + std::size_t work = 0; + std::generate(threads.begin(), threads.end(), [=, &work] { + auto result = + joinable_thread([=]() mutable { f(work, std::min(n, work + grainsize)); }); + work += grainsize; + return result; + }); + // cppcheck-suppress unsignedLessThanZero + assert(work >= n); + } +} +#else + +inline std::size_t max_threads() { return omp_get_max_threads(); } + +template +void parallel_for_impl(std::size_t n, std::size_t threadsize, F f) +{ + if(threadsize <= 1) + { + f(std::size_t{0}, n); + } + else + { + std::size_t grainsize = std::ceil(static_cast(n) / threadsize); +#pragma omp parallel for num_threads(threadsize) schedule(static, 1) private(grainsize, n) + for(std::size_t tid = 0; tid < threadsize; tid++) + { + std::size_t work = tid * grainsize; + f(work, std::min(n, work + grainsize)); + } + } +} +#endif +template +void parallel_for(std::size_t n, std::size_t min_grain, F f) +{ + const auto threadsize = std::min(max_threads(), n / min_grain); + parallel_for_impl(n, threadsize, f); +} + +template +void parallel_for(std::size_t n, F f) +{ + const int min_grain = 8; + parallel_for(n, min_grain, f); +} + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/include/migraphx/cpu/pointwise.hpp b/src/targets/cpu/include/migraphx/cpu/pointwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..050605f88d99b34070125fc793f037c78a138c7a --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/pointwise.hpp @@ -0,0 +1,390 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct multi_index +{ + constexpr multi_index() = default; + + multi_index(const shape& s, std::size_t i) : n(s.lens().size()) + { + assert(n < max_size); + std::copy(s.lens().begin(), s.lens().end(), dims); + s.multi_copy(i, index, index + max_size); + } + + constexpr std::size_t size() const { return n; } + + constexpr std::size_t* begin() { return index; } + constexpr const std::size_t* begin() const { return index; } + + constexpr std::size_t* end() { return index + size(); } + constexpr const std::size_t* end() const { return index + size(); } + + std::size_t offset(const shape& s) const { return s.index(begin(), end()); } + + constexpr void carry() + { + std::size_t overflow = 0; + for(std::ptrdiff_t i = size() - 1; i > 0; i--) + { + auto z = index[i] + overflow; + // Reset overflow + overflow = 0; + // Compute overflow using while loop instead of mod + // overflow = z / dims[i]; + // z = z % dims[i]; + while(z >= dims[i]) + { + z -= dims[i]; + overflow += 1; + } + index[i] = z; + // Exit if there is no overflow + if(overflow == 0) + return; + } + index[0] += overflow; + } + + constexpr void increment(std::size_t i) + { + index[size() - 1] += i; + carry(); + } + + constexpr multi_index& operator+=(std::size_t i) + { + increment(i); + return *this; + } + + constexpr multi_index& operator++() + { + increment(1); + return *this; + } + multi_index operator++(int) // NOLINT + { + multi_index result = *this; + increment(1); + return result; + } + + private: + static const std::size_t max_size = 5; + std::size_t index[max_size] = {}; + std::size_t dims[max_size] = {}; + std::size_t n = 0; +}; + +struct reduce_dims_base +{ + std::vector reduce_shapes; + + void finalize(context&, const shape&, const std::vector& inputs) + { + reduce_shapes = reduce_dims(inputs); + } + + argument get_arg(const std::vector& args, std::size_t i) const + { + if(reduce_shapes.empty()) + return args[i]; + return args.at(i).reshape(reduce_shapes.at(i)); + } + + argument get_output() const + { + argument a{reduce_shapes[0]}; + return a; + } +}; + +template +struct vec +{ + using array_type = std::array; + using vector_type __attribute__((vector_size(N * sizeof(T)))) = T; + union + { + array_type array; + vector_type vector; + }; + + static_assert(sizeof(array_type) == sizeof(vector_type), "Not the same size"); +}; + +template +constexpr std::integral_constant vec_size(const T&) +{ + return {}; +} + +template +constexpr std::integral_constant vec_size(const vec&) +{ + return {}; +} + +template +constexpr std::size_t vec_size() +{ + return decltype(vec_size(std::declval())){}; +} + +template () > 0))> +void vec_apply(F f, V& v, Vs... vs) +{ + assert(all_of({vec_size()...}, [&](auto n) { return n == vec_size(); })); + assert(vec_size() == v.array.size()); + for(std::size_t i = 0; i < vec_size(); i++) + f(v.array[i], vs.vector[i]...); +} + +template () == 0))> +void vec_apply(F f, V& v, Vs&... vs) +{ + f(v, vs...); +} + +inline std::size_t find_packed_len(const shape& s) +{ + for(std::size_t i = 0; i < s.lens().size(); i++) + { + if(s.lens()[i] > 1 and s.strides()[i] == 1) + { + return i; + } + } + return -1; +} + +template +shape vectorize(const shape& s) +{ + assert(s.standard() or s.broadcasted()); + auto lens = s.lens(); + if(s.broadcasted()) + { + auto n = find_packed_len(s); + assert(n != -1); + assert((lens[n] % N) == 0); + lens[n] /= N; + return {s.type(), lens, s.strides()}; + } + assert((lens.back() % N) == 0); + lens.back() /= N; + return {s.type(), lens}; +} + +template +tensor_view> vectorize(tensor_view tv) +{ + return {vectorize(tv.get_shape()), reinterpret_cast*>(tv.data())}; +} + +template +struct is_vector_type : std::false_type +{ +}; + +template <> +struct is_vector_type : std::true_type +{ +}; + +template +struct is_vector_tensor_view : and_{}...> +{ +}; + +template +bool is_vectorizable(const Xs&... xs) +{ + return all_of({xs...}, [](const auto& s) { + if(s.standard() and (s.lens().back() % N) == 0) + return true; + if(s.broadcasted()) + { + auto n = std::inner_product(s.lens().begin(), + s.lens().end(), + s.strides().begin(), + 0, + std::plus<>{}, + [&](auto len, auto stride) -> std::size_t { + if(stride > 0 and len == 1) + return 0; + return stride; + }); + if(n == 1) + { + auto i = find_packed_len(s); + assert(i != -1); + return (s.lens()[i] % N) == 0; + } + } + return false; + }); +} + +template {})> +auto auto_vectorize(const shape& base_shape, Ts... xs) +{ + return [=](auto f) { + if(is_vectorizable<32>(base_shape, xs.get_shape()...)) + f(vectorize<32>(base_shape), vectorize<32>(xs)...); + else if(is_vectorizable<8>(base_shape, xs.get_shape()...)) + f(vectorize<8>(base_shape), vectorize<8>(xs)...); + else + f(base_shape, xs...); + }; +} + +template {})> +auto auto_vectorize(const shape& base_shape, Ts... xs) +{ + return [=](auto f) { f(base_shape, xs...); }; +} + +template +bool is_standard_offset(const X& x, const Xs&... xs) +{ + if(all_of({x, xs...}, [](const auto& s) { return s.standard(); })) + return true; + if(all_of({x, xs...}, [](const auto& s) { return s.packed(); }) and + all_of({xs...}, [&](const auto& s) { return s == x; })) + return true; + return false; +} + +template +auto pointwise_apply(Ts... ts) +{ + return [=](context& ctx, const shape& base_shape, std::size_t min_grain, auto f) mutable { + if(is_standard_offset(ts.get_shape()...)) + { + ctx.bulk_execute(base_shape.elements(), min_grain, [=](auto start, auto end) mutable { + for(auto i = start; i < end; i++) + { + vec_apply(f, ts.data()[i]...); + } + }); + } + else + { + assert(base_shape.lens().size() <= 6); + ctx.bulk_execute(base_shape.elements(), min_grain, [=](auto start, auto end) mutable { + multi_index mi(base_shape, start); + for(auto i = start; i < end; i++) + { + vec_apply(f, ts.data()[mi.offset(ts.get_shape())]...); + ++mi; + } + }); + } + }; +} + +template +auto pointwise(Ts... ts) +{ + return [=](context& ctx, const shape& base_shape, std::size_t min_grain, auto f) mutable { + auto_vectorize(base_shape, ts...)( + [&](auto bs, auto... xs) { pointwise_apply(xs...)(ctx, bs, min_grain, f); }); + }; +} + +template +struct cpu_unary : reduce_dims_base, auto_register_op> +{ + Op op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + std::string name() const { return "cpu::" + op.name(); } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(2); + const auto& s = inputs.at(0); + return {s.type(), s.lens()}; + } + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const + { + argument result = get_arg(args, args.size() - 1); + + visit_all(result, get_arg(args, 0))([&](auto output, auto input) { + auto op2 = op; + pointwise(output, input)( + ctx, output.get_shape(), 1024, [op2](auto& y, auto x) { y = op2.apply()(x); }); + }); + + return result.reshape(output_shape); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +template +struct cpu_binary : reduce_dims_base, auto_register_op> +{ + Op op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + std::string name() const { return "cpu::" + op.name(); } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(3); + const auto& s = inputs.at(0); + return {s.type(), s.lens()}; + } + + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const + { + argument result = get_arg(args, args.size() - 1); + + visit_all(result, get_arg(args, 0), get_arg(args, 1))( + [&](auto output, auto input1, auto input2) { + auto op2 = op; + pointwise(output, input1, input2)( + ctx, output.get_shape(), 1024, [op2](auto& z, auto x, auto y) { + z = op2.apply()(x, y); + }); + }); + + return result.reshape(output_shape); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/include/migraphx/cpu/target.hpp b/src/targets/cpu/include/migraphx/cpu/target.hpp old mode 100644 new mode 100755 index 32210e879a32c60831d76abdb814144397fe49a9..9d2655d29b1a2de5324d0042cf165cef98438587 --- a/src/targets/cpu/include/migraphx/cpu/target.hpp +++ b/src/targets/cpu/include/migraphx/cpu/target.hpp @@ -2,6 +2,7 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_CPU_TARGET_HPP #include +#include #include #include #include @@ -14,7 +15,7 @@ namespace cpu { struct target { std::string name() const; - std::vector get_passes(migraphx::context& ctx, const compile_options&) const; + std::vector get_passes(migraphx::context& gctx, const compile_options&) const; migraphx::context get_context() const { return context{}; } argument copy_to(const argument& arg) const { return arg; } @@ -22,6 +23,8 @@ struct target argument allocate(const shape& s) const; }; +MIGRAPHX_REGISTER_TARGET(target); + } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/cpu/include/migraphx/cpu/write_literals.hpp b/src/targets/cpu/include/migraphx/cpu/write_literals.hpp new file mode 100755 index 0000000000000000000000000000000000000000..0e4e9ce3086979628c4071887e006828a0849486 --- /dev/null +++ b/src/targets/cpu/include/migraphx/cpu/write_literals.hpp @@ -0,0 +1,22 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_WRITE_LITERALS_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_WRITE_LITERALS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +struct module; +namespace cpu { + +struct write_literals +{ + std::string name() const { return "cpu::write_literals"; } + void apply(module& m) const; +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/layernorm.cpp b/src/targets/cpu/layernorm.cpp new file mode 100755 index 0000000000000000000000000000000000000000..572a97e8b1219ea7a304fcc2cbd42de2cd97b4ce --- /dev/null +++ b/src/targets/cpu/layernorm.cpp @@ -0,0 +1,42 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_layernorm : dnnl_op +{ + float epsilon = 1e-12f; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.epsilon, "epsilon")); + } + + std::string name() const { return "dnnl::layernorm"; } + + shape compute_shape(std::vector inputs) const + { + // Compensate for allocation + inputs.pop_back(); + check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1); + auto s = inputs.at(0); + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(s, inputs)); + return s; + } + + dnnl::layer_normalization_forward::desc + get_desc(const std::unordered_map& m) const + { + return {dnnl::prop_kind::forward_inference, + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + 1e-12f, + dnnl::normalization_flags::none}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/logsoftmax.cpp b/src/targets/cpu/logsoftmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c5af02b4d315173baf43046fce1d4464b27a559 --- /dev/null +++ b/src/targets/cpu/logsoftmax.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_logsoftmax : dnnl_extend_op +{ + dnnl::logsoftmax_forward::desc + get_desc(const std::unordered_map& m) const + { + int axis = this->op.axis; + return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/lowering.cpp b/src/targets/cpu/lowering.cpp old mode 100644 new mode 100755 index c359350610341f882f217b47c85b2c9e4c317221..6fd2d28b7ac70dc832b6d3f8d5a66c8978d59e5d --- a/src/targets/cpu/lowering.cpp +++ b/src/targets/cpu/lowering.cpp @@ -2,8 +2,10 @@ #include #include #include -#include +#include +#include #include +#include #include #include #include @@ -17,12 +19,23 @@ #include #include #include +#include #include #include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -42,184 +55,6 @@ typename std::conditional_t{}, std::make_signed, std::ena return x; } -// -// cpu implemenataion of batch norm for inference -// -// inputs are: -// args[0] -> input data buffer -// args[1] -> mini batch mean -// args[2] -> mini batch variance -// args[3] -> gamma -// args[4] -> bias -// -// The equation to compute batch norm for inference is: -// -// output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon) -// -// the input data format should be nchw -// -struct cpu_batch_norm_inference -{ - op::batch_norm_inference op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - - std::string name() const { return "cpu::batch_norm_inference"; } - - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } - - argument compute(context&, const shape& output_shape, std::vector args) const - { - argument output{output_shape}; - - double epsilon = op.epsilon; - auto input = args[0]; - auto arg_gamma = args[1]; - auto arg_bias = args[2]; - auto mini_batch_mean = args[3]; - auto mini_batch_variance = args[4]; - - auto num_batch = output_shape.lens()[0]; - auto num_channels = output_shape.lens()[1]; - auto image_height = output_shape.lens()[2]; - auto image_width = output_shape.lens()[3]; - - if(op.bn_mode == op::batch_norm_inference::spatial) - { - visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( - [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { - - par_dfor(num_batch, num_channels, image_height, image_width)( - [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { - assert((variance[c] + epsilon) > 0); - result(n, c, h, w) = gamma[c] * (buffer(n, c, h, w) - mean[c]) / - std::sqrt(variance[c] + epsilon) + - bias[c]; - }); - }); - } - - if(op.bn_mode == op::batch_norm_inference::per_activation) - { - visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)( - [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { - - par_dfor(num_batch, num_channels, image_height, image_width)( - [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { - assert((variance(c, h, w) + epsilon) > 0); - result(n, c, h, w) = gamma(c, h, w) * - (buffer(n, c, h, w) - mean(c, h, w)) / - std::sqrt(variance(c, h, w) + epsilon) + - bias(c, h, w); - }); - }); - } - - return output; - } -}; - -struct cpu_lrn -{ - op::lrn op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - - std::string name() const { return "cpu::lrn"; } - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } - argument compute(context&, shape output_shape, std::vector args) const - { - argument result{output_shape}; - visit_all(result, args[0])([&](auto output, auto input) { - int n_batch = output_shape.lens()[0]; - int channels = output_shape.lens()[1]; - int height = output_shape.lens()[2]; - int width = output_shape.lens()[3]; - float alphaoverarea = op.alpha / float(op.size); - int radius = (op.size - 1) / 2; - - par_dfor(n_batch, height, width)([&](int b, int h, int w) { - float scale = 0; - dfor(channels)([&](int c) { - auto start = (c - radius) < 0 ? 0 : (c - radius); - auto end = (c + radius) > channels ? channels : (c + radius); - for(auto k = start; k < end; ++k) - { - scale += std::pow(input(b, k, h, w), 2); - } - scale *= alphaoverarea; - scale += op.bias; - scale = std::pow(scale, -op.beta); - output(b, c, h, w) = input(b, c, h, w) * scale; - }); - }); - }); - return result; - } -}; - -template -struct cpu_convolution -{ - Op op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - - std::string name() const { return "cpu::" + op.name(); } - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } - argument compute(context&, shape output_shape, std::vector args) const - { - argument result{output_shape}; - result.visit([&](auto output) { - using type = typename decltype(output)::value_type; - visit_all(args[0], args[1])([&](auto input, auto weights) { - auto in = input.get_shape().lens(); - auto in_h = in[2]; - auto in_w = in[3]; - - auto wei = weights.get_shape().lens(); - auto wei_n = wei[0]; - auto wei_c = wei[1]; - auto wei_h = wei[2]; - auto wei_w = wei[3]; - - par_dfor(output_shape.lens()[0], - output_shape.lens()[1], - output_shape.lens()[2], - output_shape.lens()[3])( - [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { - const auto start_x = i * op.stride[0] - op.padding[0]; - const auto start_y = j * op.stride[1] - op.padding[1]; - const auto group_id = w / (wei_n / op.group); - - type acc = type{0}; - dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) { - const auto in_x = start_x + x; - const auto in_y = start_y + y; - const auto in_ch = group_id * wei_c + k; - if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) - acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y); - }); - output(o, w, i, j) = acc; - }); - }); - }); - return result; - } -}; - struct cpu_im2col { op::im2col op; @@ -231,7 +66,10 @@ struct cpu_im2col } static std::string name() { return "cpu::im2col"; } - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + shape compute_shape(const std::vector& inputs) const + { + return op.normalize_compute_shape(inputs); + } argument compute(context&, const shape& output_shape, std::vector args) const { @@ -281,104 +119,40 @@ struct cpu_im2col return result; } }; +MIGRAPHX_REGISTER_OP(cpu_im2col) -struct max_pool -{ - static std::string name() { return "max"; } - static double start() { return std::numeric_limits::lowest(); } - - static double apply(double x, double y) - { - double m = std::max(x, y); - return (m); - } - - static double final(double x, double) { return (x); } -}; - -struct avg_pool -{ - static std::string name() { return "average"; } - static double start() { return 0.0; } - - static double apply(double x, double y) { return x + y; } - - static double final(double x, double y) { return x / y; } -}; - -template -struct cpu_pooling +struct cpu_op { - op::pooling op; - + operation op = op::identity{}; template static auto reflect(Self& self, F f) { return migraphx::reflect(self.op, f); } - - std::string name() const { return "cpu::pooling_" + Op::name(); } - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } - argument compute(context&, const shape& output_shape, std::vector args) const - { - argument result{output_shape}; - visit_all(result, args[0])([&](auto output, auto input) { - using type = typename decltype(output)::value_type; - auto in_h = input.get_shape().lens()[2]; - auto in_w = input.get_shape().lens()[3]; - - par_dfor(output_shape.lens()[0], - output_shape.lens()[1], - output_shape.lens()[2], - output_shape.lens()[3])( - [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { - const int start_x0 = i * op.stride[0] - op.padding[0]; - const int start_y0 = j * op.stride[1] - op.padding[1]; - - const int hend = std::min(start_x0 + op.lengths[0], in_h); - const int wend = std::min(start_y0 + op.lengths[1], in_w); - - const int start_x = std::max(start_x0, 0); - const int start_y = std::max(start_y0, 0); - - const int w_h = (hend - start_x); - const int w_w = (wend - start_y); - const int pool_size = std::max(w_h * w_w, 1); - - double acc = Op::start(); - dfor(w_h, w_w)([&](int x, int y) { - const int in_x = start_x + x; - const int in_y = start_y + y; - if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) - { - acc = Op::apply(acc, input(o, w, in_x, in_y)); - } - }); - output(o, w, i, j) = type(Op::final(acc, pool_size)); - }); - }); - return result; - } -}; - -struct cpu_op -{ - operation op; - std::string name() const { return "cpu::" + op.name(); } + std::string name() const { return "cpu::op"; } shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } argument compute(context&, const shape& output_shape, const std::vector& args) const { return op.compute(output_shape, args); } - friend bool operator==(const cpu_op& x, const cpu_op& y) { return x.op == y.op; } - friend bool operator==(const cpu_op& x, const operation& y) + value to_value() const + { + value v; + v["name"] = op.name(); + v["operator"] = op.to_value(); + return v; + } + void from_value(const value& v) { - if(x.name() != y.name()) - return false; - return x == any_cast(y); + op = make_op(v.at("name").to(), v.at("operator")); + } + friend std::ostream& operator<<(std::ostream& os, const cpu_op& x) + { + os << "cpu::" << x.op; + return os; } - friend bool operator==(const operation& x, const cpu_op& y) { return y == x; } }; +MIGRAPHX_REGISTER_OP(cpu_op) struct cpu_pad { @@ -390,13 +164,16 @@ struct cpu_pad return migraphx::reflect(self.op, f); } - std::string name() const { return "cpu::contiguous"; } + std::string name() const { return "cpu::pad"; } shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } argument compute(context&, const shape& output_shape, std::vector args) const { assert(output_shape.standard()); argument result{output_shape}; - result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); }); + result.visit([&](auto output) { + using type = typename decltype(output)::value_type; + std::fill(output.begin(), output.end(), pad_clamp(op.value)); + }); visit_all(result, args[0])([&](auto output, auto input) { shape_for_each(input.get_shape(), [&](const auto& idx) { @@ -412,123 +189,7 @@ struct cpu_pad return result; } }; - -struct cpu_gemm -{ - op::dot op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - std::string name() const { return "cpu::dot"; } - shape compute_shape(const std::vector& inputs) const - { - if(inputs.size() == 3) - { - auto c_shape = inputs.at(2); - check_shapes{{c_shape}}.not_broadcasted(); - } - return op.compute_shape(inputs); - } - - argument compute(context&, const shape& output_shape, std::vector args) const - { - argument result{output_shape}; - // 3 inputs, it is alpha * A * B + beta * C, then - // A and B are matrices, and C is of the same shape as A * B - if(args.size() == 3) - { - // no need to consider the value of args[2] - if(op.beta == 0.0f) - { - result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); }); - } - else - { - visit_all(result, args[2])([&](auto output, auto input) { - std::copy(input.begin(), input.end(), output.begin()); - }); - } - - migemm(result, args[0], args[1], op.alpha, op.beta); - - return result; - } - - // 2 input arguments - migemm(result, args[0], args[1], op.alpha, 0.0f); - - return result; - } -}; - -struct cpu_quant_gemm -{ - op::quant_dot op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - - std::string name() const { return "cpu::quant_dot"; } - shape compute_shape(const std::vector& inputs) const - { - if(inputs.size() == 3) - { - auto c_shape = inputs.at(2); - check_shapes{{c_shape}}.not_broadcasted(); - } - return op.compute_shape(inputs); - } - - argument compute(context&, const shape& output_shape, std::vector args) const - { - argument result{output_shape}; - // 3 inputs, it is alpha * A * B + beta * C, then - // A and B are matrices, and C is of the same shape to A * B - - // first, convert the args[0] and args[1] from int8_t to int32_t - argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; - argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; - arg_0.visit([&](auto output) { - args.at(0).visit( - [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); - }); - - arg_1.visit([&](auto output) { - args.at(1).visit( - [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); - }); - - if(args.size() == 3) - { - // no need to consider the value of args[2] - if(op.beta == 0) - { - result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); }); - } - else - { - visit_all(result, args[2])([&](auto output, auto input) { - std::copy(input.begin(), input.end(), output.begin()); - }); - } - - migemm(result, arg_0, arg_1, op.alpha, op.beta); - - return result; - } - - // 2 input arguments - migemm(result, arg_0, arg_1, op.alpha, int32_t{0}); - - return result; - } -}; +MIGRAPHX_REGISTER_OP(cpu_pad) struct leaky_relu_op { @@ -541,20 +202,16 @@ struct leaky_relu_op } }; -struct elu_op +template +struct cpu_unary2 : auto_register_op> { - op::elu op; - std::string name() const { return "cpu::elu"; } - auto fcn() const + cpu_unary2() = default; + + template + cpu_unary2(T pop) : op(Op{std::move(pop)}) { - auto a = op.alpha; - return [a](auto x) { return x > 0 ? x : a * std::expm1(x); }; } -}; -template -struct cpu_unary -{ Op op; template @@ -565,8 +222,8 @@ struct cpu_unary std::string name() const { return op.name(); } shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs}.has(1); - auto s = inputs.at(0); + check_shapes{inputs, *this}.has(1); + const auto& s = inputs.at(0); return {s.type(), s.lens()}; } @@ -581,11 +238,11 @@ struct cpu_unary return result; } }; +template struct cpu_unary2; -template -struct cpu_softmax +struct cpu_rnn_var_sl_last_output { - Op op; + op::rnn_var_sl_last_output op; template static auto reflect(Self& self, F f) @@ -593,94 +250,163 @@ struct cpu_softmax return migraphx::reflect(self.op, f); } - std::string name() const { return "cpu::" + op.name(); } - shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } - argument compute(context&, const shape& output_shape, std::vector args) const + std::string name() const { return "cpu::rnn_var_sl_last_output"; } + + shape compute_shape(std::vector inputs) const + { + return op.compute_shape(std::move(inputs)); + } + + argument compute(const shape& output_shape, std::vector args) const { argument result{output_shape}; - auto batch_lens = output_shape.lens(); - std::size_t n_dims = batch_lens[op.axis]; - batch_lens[op.axis] = 1; - shape batch_shape{shape::int32_type, batch_lens}; + auto out_comp_lens = args[0].get_shape().lens(); + out_comp_lens[0] = 1; + shape out_comp_s{output_shape.type(), out_comp_lens}; visit_all(result, args[0])([&](auto output, auto input) { - using value_type = typename decltype(input)::value_type; - std::vector batch_max(batch_shape.elements(), - std::numeric_limits::lowest()); - std::vector batch_sum(batch_shape.elements(), value_type(0)); - par_for(batch_shape.elements(), [&](auto i) { - auto idx = batch_shape.multi(i); - for(std::size_t j = 0; j < n_dims; ++j) - { - idx[op.axis] = j; - batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end())); - } - - for(std::size_t j = 0; j < n_dims; ++j) - { - idx[op.axis] = j; - std::size_t index = output_shape.index(idx); - output[index] = std::exp(input[index] - batch_max[i]); - } - - for(std::size_t j = 0; j < n_dims; ++j) - { - idx[op.axis] = j; - batch_sum[i] += output(idx.begin(), idx.end()); - } - - for(std::size_t j = 0; j < n_dims; ++j) - { - idx[op.axis] = j; - output(idx.begin(), idx.end()) = - op.output()(output(idx.begin(), idx.end()), batch_sum[i]); - } + args[1].visit([&](auto seq_lens) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = out_comp_s.multi(i); + auto b = idx[2]; + if(op.direction == op::rnn_direction::reverse or idx[1] == 1) + { + idx[0] = 0; + } + else + { + idx[0] = seq_lens[b] - 1; + } + output[i] = input(idx.begin(), idx.end()); + }); }); }); return result; } }; +MIGRAPHX_REGISTER_OP(cpu_rnn_var_sl_last_output) struct cpu_apply { - program* prog; - std::unordered_map> apply_map{}; + module* modl; + std::unordered_map> apply_map{}; + instruction_ref last{}; - template - auto simple_op() + void extend_op(const std::string& op_name, const std::string& cpu_name, bool allocate = true) { - return [this](instruction_ref ins) { apply_simple_op(ins); }; + apply_map.emplace(op_name, [=](instruction_ref ins) { + auto&& op = ins->get_operator(); + if(allocate) + return replace(ins, make_op(cpu_name, op.to_value())); + return modl->replace_instruction(ins, make_op(cpu_name, op.to_value()), ins->inputs()); + }); + } + + void extend_dnnl_algos(const std::string& dnnl_name, + const std::vector>& algos) + { + for(auto&& pp : algos) + { + std::string op_name = pp.first; + std::string algo = pp.second; + apply_map.emplace(op_name, [=](instruction_ref ins) { + auto v = ins->get_operator().to_value(); + if(not v.is_object()) + return ins; + v["algo"] = algo; + auto op = make_op(dnnl_name, v); + return replace(ins, op); + }); + } } - template - auto extend_op() + template + auto fuse_match(M matcher, const operation& op, const std::vector& bind_inputs) { - return [this](instruction_ref ins) { apply_extend_op(ins); }; + return match::make_match_finder(matcher, [=](auto&, const auto& r) { + auto ins = r.result; + std::vector inputs; + std::transform(bind_inputs.begin(), + bind_inputs.end(), + std::back_inserter(inputs), + [&](const auto& s) { return r.instructions[s]; }); + inputs.push_back(this->insert_allocation(ins, ins->get_shape())); + modl->replace_instruction(ins, op, inputs); + }); } void init() { - apply_map["batch_norm_inference"] = - extend_op(); - apply_map["convolution"] = extend_op, op::convolution>(); - apply_map["dot"] = extend_op(); - apply_map["quant_dot"] = extend_op(); - apply_map["quant_convolution"] = - extend_op, op::quant_convolution>(); - apply_map["elu"] = extend_op, op::elu>(); - apply_map["im2col"] = extend_op(); - apply_map["leaky_relu"] = extend_op, op::leaky_relu>(); - apply_map["logsoftmax"] = extend_op, op::logsoftmax>(); - apply_map["lrn"] = extend_op(); - apply_map["pad"] = extend_op(); - apply_map["softmax"] = extend_op, op::softmax>(); + extend_dnnl_algos("dnnl::binary", + { + {"add", "binary_add"}, + {"div", "binary_div"}, + {"max", "binary_max"}, + {"min", "binary_min"}, + {"mul", "binary_mul"}, + }); + + extend_dnnl_algos("dnnl::eltwise", + { + {"abs", "eltwise_abs"}, + {"elu", "eltwise_elu"}, + {"exp", "eltwise_exp"}, + {"log", "eltwise_log"}, + {"relu", "eltwise_relu"}, + {"sqrt", "eltwise_sqrt"}, + {"tanh", "eltwise_tanh"}, + }); + + extend_dnnl_algos("dnnl::reduction", + { + {"reduce_max", "reduction_max"}, + {"reduce_mean", "reduction_mean"}, + {"reduce_min", "reduction_min"}, + {"reduce_sum", "reduction_sum"}, + }); + + extend_op("concat", "dnnl::concat"); + extend_op("contiguous", "dnnl::reorder"); + extend_op("convolution", "dnnl::convolution"); +#ifndef MIGRAPHX_ENABLE_ZENDNN + extend_op("deconvolution", "dnnl::deconvolution"); + extend_op("dot", "dnnl::dot"); +#endif + extend_op("erf", "cpu::erf"); + extend_op("gather", "cpu::gather"); + extend_op("logsoftmax", "dnnl::logsoftmax"); + extend_op("lrn", "dnnl::lrn"); + extend_op("softmax", "dnnl::softmax"); + extend_op("sub", "cpu::sub"); + + extend_op("im2col", "cpu::im2col", false); + extend_op("leaky_relu", "cpu::leaky_relu", false); + extend_op("pad", "cpu::pad", false); + extend_op("rnn_var_sl_last_output", "cpu::rnn_var_sl_last_output", false); } void apply() { init(); - for(auto it : iterator_for(*prog)) + // Apply fusion matchers first + match::find_matches(*modl, + fuse_match(match::gelu_erf(), + make_op("dnnl::eltwise", {{"algo", "eltwise_gelu_erf"}}), + {"x"}), + fuse_match(match::gelu_tanh(), + make_op("dnnl::eltwise", {{"algo", "eltwise_gelu_tanh"}}), + {"x"}), + fuse_match(match::layernorm(), make_op("dnnl::layernorm"), {"x"})); + // Apply these operators first so the inputs can be const folded + for(auto it : iterator_for(*modl)) + { + if(it->name() == "pow") + { + apply_pow(it); + } + } + for(auto it : iterator_for(*modl)) { if(it->name() == "pooling") { @@ -690,42 +416,62 @@ struct cpu_apply { apply_map.at(it->name())(it); } - else if(is_context_free(it->get_operator())) - { - apply_cpu_op(it); - } } } - void apply_cpu_op(instruction_ref ins) + instruction_ref apply_pow(instruction_ref ins) const + { + auto beta = read_scalar(ins->inputs()[1]); + if(beta.empty()) + return ins; + return replace(ins, + make_op("dnnl::eltwise", + {{"algo", "eltwise_pow"}, {"alpha", 1.0}, {"beta", beta.front()}}), + {ins->inputs().front()}); + } + + instruction_ref apply_pooling(instruction_ref ins) const { - prog->replace_instruction(ins, cpu_op{ins->get_operator()}, ins->inputs()); + auto&& op = ins->get_operator(); + auto v = op.to_value(); + if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and + not v["ceil_mode"].to()) + return replace(ins, make_op("dnnl::pooling", op.to_value())); + return ins; } template - void apply_simple_op(instruction_ref ins) + static std::vector read_scalar(instruction_ref ins) + { + if(ins->name() == "contiguous") + return read_scalar(ins->inputs().front()); + if(ins->get_shape().elements() != 1 and not ins->get_shape().scalar()) + return {}; + auto r = ins->eval(); + if(r.empty()) + return {}; + return {r.at()}; + } + + instruction_ref replace(instruction_ref ins, const operation& op) const { - prog->replace_instruction(ins, T{}, ins->inputs()); + return replace(ins, op, ins->inputs()); } - template - void apply_extend_op(instruction_ref ins) + instruction_ref + replace(instruction_ref ins, const operation& op, std::vector inputs) const { - auto&& op = any_cast(ins->get_operator()); - prog->replace_instruction(ins, T{op}, ins->inputs()); + inputs.push_back(insert_allocation(ins, ins->get_shape())); + return modl->replace_instruction(ins, op, inputs); } - void apply_pooling(instruction_ref ins) + instruction_ref insert_allocation(instruction_ref ins, const shape& s) const { - auto&& op = any_cast(ins->get_operator()); - if(op.mode == "max") - prog->replace_instruction(ins, cpu_pooling{op}, ins->inputs()); - else if(op.mode == "average") - prog->replace_instruction(ins, cpu_pooling{op}, ins->inputs()); + return modl->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); } }; -void lowering::apply(program& p) const { cpu_apply{&p}.apply(); } +void lowering::apply(module& m) const { cpu_apply{&m}.apply(); } } // namespace cpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/cpu/lrn.cpp b/src/targets/cpu/lrn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93dc4f606075a128c0f2a3f861610102eb61fc58 --- /dev/null +++ b/src/targets/cpu/lrn.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_lrn : dnnl_extend_op +{ + dnnl::lrn_forward::desc get_desc(const std::unordered_map& m) const + { + return {dnnl::prop_kind::forward_inference, + dnnl::algorithm::lrn_across_channels, + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), + this->op.size, + this->op.alpha, + this->op.beta, + this->op.bias}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/pooling.cpp b/src/targets/cpu/pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..540c6e50b8885410855bd939b1e8a8e048616a13 --- /dev/null +++ b/src/targets/cpu/pooling.cpp @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_pooling : dnnl_extend_op +{ + std::vector arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; } + + dnnl::pooling_forward::desc get_desc(const std::unordered_map& m) const + { + auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max + : dnnl::algorithm::pooling_avg; + auto kdims = op.kdims(); + std::vector padding_l(op.padding.begin(), op.padding.begin() + kdims); + std::vector padding_r(op.padding.begin() + kdims, op.padding.end()); + return {dnnl::prop_kind::forward_inference, + algo, + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), + to_dnnl_dims(op.stride), + to_dnnl_dims(op.lengths), + to_dnnl_dims(padding_l), + to_dnnl_dims(padding_r)}; + } +}; + +} // namespace cpu + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/preallocate.cpp b/src/targets/cpu/preallocate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7566a51d780688007768c35b23bc6d035443a816 --- /dev/null +++ b/src/targets/cpu/preallocate.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_preallocate : auto_register_op +{ + shape s; + std::string id = ""; + argument data; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.s, "shape"), f(self.id, "id")); + } + + std::string name() const { return "cpu::preallocate"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(0); + return s; + } + argument compute(context&, const shape&, const std::vector&) const { return data; } + void finalize(context&, const shape&, const std::vector&) { data = argument(s); } + lifetime get_lifetime() const { return lifetime::global; } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/reduction.cpp b/src/targets/cpu/reduction.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab4b039d0d90a82e34d5a310e9eaa90bd7978c9c --- /dev/null +++ b/src/targets/cpu/reduction.cpp @@ -0,0 +1,50 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_reduction : dnnl_op +{ + std::string algo; + std::vector axes{}; + template + static auto reflect(Self& self, F f) + { + return pack_join(self.reflect_base(self, f), + pack(f(self.algo, "algo"), f(self.axes, "axes"))); + } + + std::string name() const { return "dnnl::reduction"; } + + shape compute_shape(std::vector inputs) const + { + // Compensate for allocation + inputs.pop_back(); + check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).standard(); + auto s = inputs.at(0); + auto lens = s.lens(); + for(auto axis : axes) + { + lens[axis] = 1; + } + auto r = shape{s.type(), lens}; + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(r, inputs)); + return r; + } + + dnnl::reduction::desc get_desc(const std::unordered_map& m) const + { + return {to_dnnl_algo(algo), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), + m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), + 0, + 0}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/reorder.cpp b/src/targets/cpu/reorder.cpp new file mode 100755 index 0000000000000000000000000000000000000000..f92308004232629e00581fb35cb07821bc46b88a --- /dev/null +++ b/src/targets/cpu/reorder.cpp @@ -0,0 +1,42 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_reorder : dnnl_op +{ + std::string name() const { return "dnnl::reorder"; } + + shape adjust_shape(const shape& x, int) const { return x; } + + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(2); + auto r = inputs.back(); + // Call to get_primitive to make sure an algo is available + this->get_primitive(this->to_memory_desc(r, inputs)); + return r; + } + // Custom desc class since its missing in dnnl + struct desc + { + dnnl::memory::desc src; + dnnl::memory::desc dst; + }; + desc get_desc(const std::unordered_map& m) const + { + return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))}; + } + + auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const + { + auto& engine = get_dnnl_context().engine; + return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst, attr); + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/softmax.cpp b/src/targets/cpu/softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..568a6ca39b4d86f5696d765f0c91283cde60ef5e --- /dev/null +++ b/src/targets/cpu/softmax.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct dnnl_softmax : dnnl_extend_op +{ + dnnl::softmax_forward::desc get_desc(const std::unordered_map& m) const + { + int axis = this->op.axis; + return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis}; + } +}; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/sub.cpp b/src/targets/cpu/sub.cpp new file mode 100755 index 0000000000000000000000000000000000000000..c42dc5bdf874bdb47042327bf3d994af8769452c --- /dev/null +++ b/src/targets/cpu/sub.cpp @@ -0,0 +1,13 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +template struct cpu_binary; + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp old mode 100644 new mode 100755 index 92969e2ed48eedf93362c3ab70712ccca6503786..36d19df7698d7068fb2471732b664e14e30572b1 --- a/src/targets/cpu/target.cpp +++ b/src/targets/cpu/target.cpp @@ -1,11 +1,38 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include #include #include -#include -#include -#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -13,18 +40,55 @@ namespace cpu { std::string target::name() const { return "cpu"; } -std::vector target::get_passes(migraphx::context&, const compile_options&) const +// cppcheck-suppress constParameter +std::vector target::get_passes(migraphx::context& gctx, const compile_options&) const { - return {rewrite_rnn{}, + auto& ctx = any_cast(gctx); + std::set unsupported_types(shape::types().begin(), shape::types().end()); + unsupported_types.erase(shape::type_t::float_type); + return {normalize_ops{}, + rewrite_quantization{}, + dead_code_elimination{}, + eliminate_data_type{unsupported_types, shape::type_t::float_type}, + dead_code_elimination{}, + simplify_reshapes{}, + eliminate_identity{}, + eliminate_pad{}, + dead_code_elimination{}, + rewrite_batchnorm{}, dead_code_elimination{}, + rewrite_rnn{}, + dead_code_elimination{}, + eliminate_common_subexpression{}, + dead_code_elimination{}, + simplify_algebra{}, + simplify_reshapes{}, + simplify_algebra{}, auto_contiguous{}, + simplify_reshapes{}, + propagate_constant{}, dead_code_elimination{}, lowering{}, + eliminate_contiguous{"dnnl::reorder"}, + dead_code_elimination{}, + replace_allocate{cpu_allocation_model{}}, + dead_code_elimination{}, + adjust_allocation{cpu_allocation_model{}}, + dead_code_elimination{}, + fuse_ops{&ctx}, + dead_code_elimination{}, + write_literals{}, + dead_code_elimination{}, + memory_coloring{"cpu::allocate"}, + dead_code_elimination{}, + preallocate_param{"scratch", cpu_allocation_model{}}, dead_code_elimination{}}; } argument target::allocate(const shape& s) const { return fill_argument(s, 0); } +MIGRAPHX_REGISTER_TARGET(target); + } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/cpu/write_literals.cpp b/src/targets/cpu/write_literals.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c2838ab961e2433c0b0f4d8ec8ae5ec316f9ac7 --- /dev/null +++ b/src/targets/cpu/write_literals.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace cpu { + +struct cpu_literal +{ + argument data; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.data, "data")); + } + + std::string name() const { return "cpu::literal"; } + + shape compute_shape(const std::vector&) const { return data.get_shape(); } + + argument compute(const shape&, const std::vector&) const { return data; } + + friend std::ostream& operator<<(std::ostream& os, const cpu_literal& x) + { + os << x.name(); + return os; + } +}; + +void write_literals::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + if(ins->name() != "@literal") + continue; + m.replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()}); + } +} + +} // namespace cpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt old mode 100644 new mode 100755 index 575c8c3a6191a2fcbec20f2a9da453f6b1fbd9df..822782a661c49124cd4cf715307b069dca477ab3 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -10,8 +10,15 @@ if(NOT TARGET MIOpen) message(SEND_ERROR "Cant find miopen") endif() +include(Embed) +file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) +message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +add_embed_library(migraphx_kernels ${KERNEL_FILES}) + add_library(migraphx_device device/acos.cpp + device/acosh.cpp device/add.cpp device/add_clip.cpp device/add_relu.cpp @@ -20,7 +27,9 @@ add_library(migraphx_device device/argmax.cpp device/argmin.cpp device/asin.cpp + device/asinh.cpp device/atan.cpp + device/atanh.cpp device/ceil.cpp device/clip.cpp device/concat.cpp @@ -29,27 +38,45 @@ add_library(migraphx_device device/cos.cpp device/cosh.cpp device/div.cpp + device/equal.cpp device/erf.cpp device/exp.cpp + device/fill.cpp device/floor.cpp device/gather.cpp + device/gelu.cpp + device/greater.cpp device/int8_gemm_pack.cpp + device/layernorm.cpp + device/less.cpp device/log.cpp + device/logical_and.cpp + device/logical_or.cpp + device/logical_xor.cpp device/logsoftmax.cpp device/max.cpp device/min.cpp device/mul.cpp device/mul_add.cpp device/mul_add_relu.cpp + device/multinomial.cpp + device/nonzero.cpp device/pad.cpp device/pow.cpp + device/prelu.cpp + device/prefix_scan_sum.cpp + device/recip.cpp device/reduce_max.cpp device/reduce_mean.cpp device/reduce_min.cpp device/reduce_sum.cpp + device/reduce_prod.cpp device/relu.cpp + device/reverse.cpp + device/rnn_variable_seq_lens.cpp device/round.cpp device/rsqrt.cpp + device/scatter.cpp device/sigmoid.cpp device/sign.cpp device/sin.cpp @@ -60,57 +87,281 @@ add_library(migraphx_device device/sub.cpp device/tan.cpp device/tanh.cpp + device/topk.cpp + device/unary_not.cpp + device/where.cpp ) +add_library(compile_for_gpu INTERFACE) +target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns) +target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored) +check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE) +if(HAS_HIP_LAMBDA_HOST_DEVICE) + message(STATUS "Enable -fhip-lambda-host-device") + target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device) +endif() + set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION}) rocm_clang_tidy_check(migraphx_device) -target_compile_options(migraphx_device PRIVATE -std=c++17) -target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx906) +target_link_libraries(migraphx_device PUBLIC migraphx) +target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_include_directories(migraphx_device PUBLIC $) target_include_directories(migraphx_device PRIVATE $) +add_library(kernel_file_check EXCLUDE_FROM_ALL) +foreach(KERNEL_FILE ${KERNEL_FILES}) + get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE) + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include \n") + target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp) +endforeach() +target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256) +target_include_directories(kernel_file_check PRIVATE $) +target_link_libraries(kernel_file_check compile_for_gpu) + +rocm_clang_tidy_check(kernel_file_check) + +file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) add_library(migraphx_gpu + abs.cpp + analyze_streams.cpp + allocation_model.cpp argmax.cpp argmin.cpp + batch_norm_inference.cpp + clip.cpp + code_object_op.cpp + compile_ops.cpp + compile_gen.cpp + compile_hip.cpp + compile_hip_code_object.cpp + compiler.cpp + concat.cpp + convert.cpp + convolution.cpp + deconvolution.cpp + device_name.cpp eliminate_workspace.cpp + elu.cpp fuse_ops.cpp + gather.cpp + gemm_impl.cpp hip.cpp - target.cpp + int8_conv_pack.cpp + int8_gemm_pack.cpp + kernel.cpp lowering.cpp - pooling.cpp - convolution.cpp - quant_convolution.cpp - softmax.cpp logsoftmax.cpp - contiguous.cpp - concat.cpp + loop.cpp + lrn.cpp leaky_relu.cpp - batchnorm.cpp - write_literals.cpp - rocblas.cpp - abs.cpp - elu.cpp + mlir_conv.cpp + multinomial.cpp + nonzero.cpp + pack_args.cpp + pack_int8_args.cpp + prefuse_ops.cpp pad.cpp - gather.cpp - convert.cpp - lrn.cpp + pooling.cpp + quant_convolution.cpp + reverse.cpp + rnn_variable_seq_lens.cpp + rocblas.cpp + scatter.cpp schedule_model.cpp - adjust_allocation.cpp - pack_int8_args.cpp - clip.cpp - int8_gemm_pack.cpp - int8_conv_pack.cpp - gemm_impl.cpp - preallocate_param.cpp + softmax.cpp + sync_device.cpp + target.cpp + topk.cpp + write_literals.cpp + ${JIT_GPU_SRCS} ) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) + +function(register_migraphx_gpu_ops PREFIX) + foreach(OP ${ARGN}) + register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp) + endforeach() +endfunction() +register_migraphx_gpu_ops(hip_ + acosh + acos + add + argmax + argmin + asinh + asin + atanh + atan + ceil + clip + concat + convert + cosh + cos + div + equal + erf + exp + floor + gather + greater + less + log + logsoftmax + logical_and + logical_or + logical_xor + loop + max + min + mul + multinomial + nonzero + pad + pow + prelu + prefix_scan_sum + recip + reduce_max + reduce_mean + reduce_min + reduce_prod + reduce_sum + relu + reverse + round + rsqrt + scatter + sigmoid + sign + sinh + sin + softmax + sqdiff + sqrt + sub + tanh + tan + topk + unary_not + where +) +register_migraphx_gpu_ops(miopen_ + abs + batch_norm_inference + contiguous + convolution + deconvolution + elu + int8_conv_pack + leaky_relu + lrn + pooling + quant_convolution +) +register_op(migraphx_gpu + HEADER migraphx/gpu/rnn_variable_seq_lens.hpp + OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output + INCLUDES migraphx/gpu/context.hpp) +register_op(migraphx_gpu + HEADER migraphx/gpu/int8_gemm_pack.hpp + OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b + INCLUDES migraphx/gpu/context.hpp) +register_op(migraphx_gpu + HEADER migraphx/gpu/gemm.hpp + OPERATORS gpu::rocblas_gemm gpu::rocblas_gemm + INCLUDES migraphx/gpu/context.hpp) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_clang_tidy_check(migraphx_gpu) + +# look for offload bundler +get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH) +if(CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$") + find_program(MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler + HINTS ${CMAKE_CXX_COMPILER_PATH} + PATH_SUFFIXES bin + PATHS /opt/rocm/llvm + ) +else() + find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel + PATH_SUFFIXES bin + HINTS ${CMAKE_CXX_COMPILER_PATH} + PATHS + /opt/rocm/hip + /opt/rocm/hcc + /opt/rocm + ) +endif() + +message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}") +message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") + +set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") +if(MIGRAPHX_ENABLE_MLIR) + find_library(LIBMLIRMIOPEN MLIRMIOpenThin REQUIRED) + # REQUIRED is not supported before cmake 3.18 + if(NOT LIBMLIRMIOPEN) + message(FATAL_ERROR "libMLIRMIOpenThin not found") + else() + message(STATUS "Build with libMLIRMIOpenThin: " ${LIBMLIRMIOPEN}) + endif() + + target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR_MIOPEN_SUPPORT") + target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN}) +endif() + +set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "") +if(MIGRAPHX_USE_HIPRTC) +target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1) +else() +# Get flags needed to compile hip +include(TargetFlags) +target_flags(HIP_COMPILER_FLAGS hip::device) +# Remove cuda arch flags +string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") +string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") +# Skip library paths since hip will incorrectly treat it as a source file +string(APPEND HIP_COMPILER_FLAGS " ") +foreach(_unused RANGE 2) + string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") +endforeach() + +message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") +target_compile_definitions(migraphx_gpu PRIVATE + "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}" + "-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}" + "-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}" + "-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}" + "-DMIGRAPHX_USE_HIPRTC=0" +) +if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER) +execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER) +string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER) +target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}") +endif() + +endif() + +# Check miopen find mode api +include(CheckLibraryExists) +get_target_property(MIOPEN_LOCATION MIOpen LOCATION) +check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) +if(HAS_FIND_MODE_API) + target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_MODE_API) + message(STATUS "MIOpen has find mode api") +else() + message(STATUS "MIOpen does not have find mode api") +endif() + +# Workaround broken rocblas headers +target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) -target_link_libraries(migraphx_gpu PRIVATE migraphx_device) +target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) + +add_subdirectory(driver) rocm_install_targets( - TARGETS migraphx_gpu migraphx_device + TARGETS migraphx_gpu migraphx_device compile_for_gpu INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include ) diff --git a/src/targets/gpu/abs.cpp b/src/targets/gpu/abs.cpp index bab281eb957c3e32d3d0823cba405b6ffbdeefc8..9759925d7b75c2e5381c93d7e625b8c90454b569 100644 --- a/src/targets/gpu/abs.cpp +++ b/src/targets/gpu/abs.cpp @@ -31,6 +31,8 @@ argument miopen_abs::compute(context& ctx, return args[1]; } +void miopen_abs::finalize(context&, const shape&, const std::vector&) { ad = make_abs(); } + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/adjust_allocation.cpp b/src/targets/gpu/adjust_allocation.cpp deleted file mode 100644 index aa2fa96fee30156289adfeec237c7c60bd777663..0000000000000000000000000000000000000000 --- a/src/targets/gpu/adjust_allocation.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { - -void adjust_allocation::apply(program& p) const -{ - for(auto ins : iterator_for(p)) - { - // skip instruction with no input - if(ins->inputs().empty()) - continue; - - if(ins->name() == "load") - continue; - - auto alias_ins = instruction::get_output_alias(ins, true); - if(alias_ins->name() == "hip::allocate") - { - // shape allocated is different from actual shape - // of the instruction, reallocate and replace the previous one - if(alias_ins->get_shape() != ins->get_shape()) - { - auto alloc_ins = p.insert_instruction(ins, hip_allocate{ins->get_shape()}); - p.replace_instruction(alias_ins, alloc_ins); - } - } - } -} - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/targets/gpu/allocation_model.cpp b/src/targets/gpu/allocation_model.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e01f09881ce8f82fd7e76ccffa688610664eda21 --- /dev/null +++ b/src/targets/gpu/allocation_model.cpp @@ -0,0 +1,25 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +std::string gpu_allocation_model::name() const { return "hip::allocate"; } +operation gpu_allocation_model::allocate(const shape& s) const +{ + return make_op(name(), {{"shape", to_value(s)}}); +} + +operation gpu_allocation_model::preallocate(const shape& s, const std::string& id) const +{ + return make_op("hip::hip_allocate_memory", {{"shape", to_value(s)}, {"id", id}}); +} + +std::string gpu_allocation_model::copy() const { return "hip::copy"; } + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/analyze_streams.cpp b/src/targets/gpu/analyze_streams.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b52771f6cee8c9fa62cfc67d1b07d30ae7d44c2f --- /dev/null +++ b/src/targets/gpu/analyze_streams.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_stream_model +{ + std::size_t max_stream = 0; + std::unordered_map ins2stream{}; + std::size_t get_nstream() const { return max_stream + 1; } + std::size_t get_stream(migraphx::instruction_ref ins) const { return ins2stream.at(ins); } + std::size_t get_event_id(migraphx::instruction_ref ins) const + { + auto v = ins->get_operator().to_value(); + return v["event"].to(); + } + bool has_stream(migraphx::instruction_ref ins) const { return ins2stream.count(ins) > 0; } + bool is_record(migraphx::instruction_ref ins) const + { + return ins->name() == "gpu::record_event"; + } + bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; } +}; + +stream_model make_stream_model(const module& m) +{ + hip_stream_model hsm; + std::size_t stream = 0; + for(auto ins : iterator_for(m)) + { + if(ins->name() == "gpu::set_stream") + { + auto v = ins->get_operator().to_value(); + stream = v["stream"].to(); + hsm.max_stream = std::max(stream, hsm.max_stream); + } + if(ins->get_operator().is_context_free()) + continue; + if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name())) + continue; + hsm.ins2stream[ins] = stream; + } + return hsm; +} + +std::vector analyze_streams(const module& m) +{ + return migraphx::analyze_streams(m, make_stream_model(m)); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/argmax.cpp b/src/targets/gpu/argmax.cpp index 23165b8ec3dbd7088891c733f2c2b434ff65a40d..0c38a9d5bcce495e05965885c37bb19711513b22 100644 --- a/src/targets/gpu/argmax.cpp +++ b/src/targets/gpu/argmax.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -8,13 +9,15 @@ namespace gpu { shape hip_argmax::compute_shape(const std::vector& inputs) const { - check_shapes{inputs, *this}.has(2).standard(); - return op.compute_shape({inputs.at(0)}); + check_shapes{inputs, *this}.has(2); + return op.normalize_compute_shape({inputs.at(0)}); } argument hip_argmax::compute(context& ctx, const shape&, const std::vector& args) const { - device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis); + auto n_dim = args.front().get_shape().lens().size(); + int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); + device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); return args.back(); } diff --git a/src/targets/gpu/argmin.cpp b/src/targets/gpu/argmin.cpp index 68986d0ba3a00c0e8d650089aa3ce44562ec061d..1a2bf5b65284ef1809a83846e5d2c25bba745c25 100644 --- a/src/targets/gpu/argmin.cpp +++ b/src/targets/gpu/argmin.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -8,13 +9,15 @@ namespace gpu { shape hip_argmin::compute_shape(const std::vector& inputs) const { - check_shapes{inputs, *this}.has(2).standard(); - return op.compute_shape({inputs.at(0)}); + check_shapes{inputs, *this}.has(2); + return op.normalize_compute_shape({inputs.at(0)}); } argument hip_argmin::compute(context& ctx, const shape&, const std::vector& args) const { - device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis); + auto n_dim = args.front().get_shape().lens().size(); + int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); + device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); return args.back(); } diff --git a/src/targets/gpu/batchnorm.cpp b/src/targets/gpu/batch_norm_inference.cpp similarity index 68% rename from src/targets/gpu/batchnorm.cpp rename to src/targets/gpu/batch_norm_inference.cpp index 5871e379fd233cb25be7680153ae85b9368974b1..918be99104f6b16fe4a8fbb4a8b422eddcf79abd 100644 --- a/src/targets/gpu/batchnorm.cpp +++ b/src/targets/gpu/batch_norm_inference.cpp @@ -1,4 +1,4 @@ -#include +#include #include namespace migraphx { @@ -8,16 +8,33 @@ namespace gpu { shape miopen_batch_norm_inference::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(6); + check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims().max_ndims(5); return op.compute_shape({inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)}); } +inline shape reshape_to_2d(const shape& input) +{ + auto dims = input.lens(); + if(dims.size() >= 4) + return input; + + std::vector new_dims(dims.begin(), dims.end()); + std::size_t num = 4 - dims.size(); + new_dims.insert(new_dims.end(), num, 1); + return {input.type(), new_dims}; +} + argument miopen_batch_norm_inference::compute(context& ctx, const shape& output_shape, const std::vector& args) const { - auto x_desc = make_tensor(args[0].get_shape()); - auto y_desc = make_tensor(output_shape); - auto bn_desc = make_tensor(args[3].get_shape()); + shape x_shape = args[0].get_shape(); + shape y_shape = output_shape; + shape bn_shape = args[3].get_shape(); + + auto x_desc = make_tensor(reshape_to_2d(x_shape)); + auto y_desc = make_tensor(reshape_to_2d(y_shape)); + auto bn_desc = make_tensor(reshape_to_2d(bn_shape)); float alpha = 1.0; float beta = 0.0f; diff --git a/src/targets/gpu/clip.cpp b/src/targets/gpu/clip.cpp index 41519287bb09505486fbb65e367f058f54327257..5212be55e074ff46c116be1306b3dd0cc12c6cbd 100644 --- a/src/targets/gpu/clip.cpp +++ b/src/targets/gpu/clip.cpp @@ -14,7 +14,7 @@ shape hip_clip::compute_shape(std::vector inputs) const argument hip_clip::compute(context& ctx, const shape&, const std::vector& args) const { - device::clip(ctx.get_stream().get(), args.back(), args.front(), op.max_val, op.min_val); + device::clip(ctx.get_stream().get(), args.back(), args.front(), args.at(1), args.at(2)); return args.back(); } diff --git a/src/targets/gpu/code_object_op.cpp b/src/targets/gpu/code_object_op.cpp new file mode 100755 index 0000000000000000000000000000000000000000..568d428c6103ac8f751d717bcc799836911267f3 --- /dev/null +++ b/src/targets/gpu/code_object_op.cpp @@ -0,0 +1,42 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +MIGRAPHX_REGISTER_OP(code_object_op); + +shape code_object_op::compute_shape(std::vector inputs) const +{ + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](const shape& s) { + return s.normalize_standard(); + }); + auto einputs = expected_inputs; + std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) { + return s.normalize_standard(); + }); + if(einputs != inputs) + MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" + + to_string_range(inputs) + "]"); + return output; +} +argument +code_object_op::compute(context& ctx, const shape&, const std::vector& args) const +{ + std::vector kargs(args.size()); + std::transform( + args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); }); + k.launch(ctx.get_stream().get(), global, local, std::move(kargs)); + return args.back(); +} +void code_object_op::finalize(context&, const shape&, const std::vector&) +{ + assert(not code_object.empty()); + k = kernel(code_object, symbol_name); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c990a9786462c1369af6a7dbb0d1124c370878cd --- /dev/null +++ b/src/targets/gpu/compile_gen.cpp @@ -0,0 +1,105 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace gen { + +static std::vector vector_sizes(const std::vector& inputs) +{ + // If all inputs are half then only use half2 + if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) { + return s.type() == shape::half_type; + })) + return {2}; + return {4, 2}; +} + +vectorize vectorize::elements(std::size_t axis, const std::vector& inputs) +{ + auto sizes = vector_sizes(inputs); + std::vector max_vec_size; + std::transform(inputs.begin(), + inputs.end(), + std::back_inserter(max_vec_size), + [&](const auto& input) -> std::size_t { + auto stride = input.strides()[axis]; + auto len = input.lens()[axis]; + if(stride != 0 and stride != 1) + return 1; + if(len == 1 and input.elements() > sizes.front()) + return sizes.front(); + auto it = std::find_if( + sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); + if(it != sizes.end()) + return *it; + return 1; + }); + return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis}; +} + +std::string vectorize::str() const +{ + return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()"; +} + +preload preload::broadcasts(std::size_t axis, const std::vector& inputs) +{ + const std::size_t max_lds_bytes = 4096; + std::vector result; + std::transform(inputs.begin(), + inputs.end(), + std::back_inserter(result), + [&](const shape& input) { return input.strides()[axis] == 0; }); + auto bytes = std::inner_product(inputs.begin(), + inputs.end(), + result.begin(), + std::size_t{0}, + std::plus<>{}, + [](const shape& s, bool b) -> std::size_t { + if(b) + return s.bytes(); + return 0; + }); + if(bytes < max_lds_bytes) + return {result}; + // TODO: Try to partially preload items + std::fill(result.begin(), result.end(), false); + return {result}; +} + +std::string preload::str() const +{ + std::vector bool_strs; + std::transform(args.begin(), std::prev(args.end()), std::back_inserter(bool_strs), [](bool b) { + if(b) + return "true"; + return "false"; + }); + return "auto_preload(idx)"; +} + +bool preload::is_preloading() const +{ + return std::accumulate(args.begin(), args.end(), false, std::logical_or<>{}); +} + +std::size_t find_fast_axis(const std::vector& inputs) +{ + auto permutation = find_permutation(inputs); + auto it = std::max_element(permutation.begin(), permutation.end()); + return it - permutation.begin(); +} + +std::string make_transformer_args(std::vector transformers) +{ + return join_strings(std::move(transformers), ", "); +} + +} // namespace gen +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/compile_hip.cpp b/src/targets/gpu/compile_hip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b243f344512e9d5bc42f592e46fcf8d75f71dfde --- /dev/null +++ b/src/targets/gpu/compile_hip.cpp @@ -0,0 +1,280 @@ +#include +#include +#include +#include +#include +#include +#include + +#if MIGRAPHX_USE_HIPRTC +#include +#include +#include +#else +#include +#include +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); + +#if MIGRAPHX_USE_HIPRTC + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC) + +std::string hiprtc_error(hiprtcResult err, const std::string& msg) +{ + return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); +} + +void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx) +{ + if(err != HIPRTC_SUCCESS) + throw make_exception(ctx, hiprtc_error(err, msg)); +} + +#define MIGRAPHX_HIPRTC(...) \ + hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX()) + +#define MIGRAPHX_HIPRTC_THROW(error, msg) MIGRAPHX_THROW(hiprtc_error(error, msg)) + +// Workaround hiprtc's broken API +void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } +using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); + +template +hiprtc_program_ptr hiprtc_program_create(Ts... xs) +{ + hiprtcProgram prog = nullptr; + auto result = hiprtcCreateProgram(&prog, xs...); + hiprtc_program_ptr p{prog}; + if(result != HIPRTC_SUCCESS) + MIGRAPHX_HIPRTC_THROW(result, "Create program failed."); + return p; +} + +struct hiprtc_program +{ + struct string_array + { + std::vector strings{}; + std::vector c_strs{}; + + string_array() {} + string_array(const string_array&) = delete; + + std::size_t size() const { return strings.size(); } + + const char** data() { return c_strs.data(); } + + void push_back(std::string s) + { + strings.push_back(std::move(s)); + c_strs.push_back(strings.back().c_str()); + } + }; + + hiprtc_program_ptr prog = nullptr; + string_array headers{}; + string_array include_names{}; + std::string cpp_src = ""; + std::string cpp_name = ""; + + hiprtc_program(const std::vector& srcs) + { + for(auto&& src : srcs) + { + std::string content{src.content.first, src.content.second}; + std::string path = src.path.string(); + if(src.path.extension().string() == ".cpp") + { + cpp_src = std::move(content); + cpp_name = std::move(path); + } + else + { + headers.push_back(std::move(content)); + include_names.push_back(std::move(path)); + } + } + prog = hiprtc_program_create(cpp_src.c_str(), + cpp_name.c_str(), + headers.size(), + headers.data(), + include_names.data()); + } + + void compile(const std::vector& options) + { + if(enabled(MIGRAPHX_TRACE_HIPRTC{})) + std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; + std::vector c_options; + std::transform(options.begin(), + options.end(), + std::back_inserter(c_options), + [](const std::string& s) { return s.c_str(); }); + auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); + std::cerr << log() << std::endl; + if(result != HIPRTC_SUCCESS) + MIGRAPHX_HIPRTC_THROW(result, "Compilation failed."); + } + + std::string log() + { + std::size_t n = 0; + MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); + if(n < 2) + return {}; + std::vector buffer(n); + MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); + assert(buffer.back() == 0); + return {buffer.begin(), buffer.end() - 1}; + } + + std::vector get_code_obj() + { + std::size_t n = 0; + MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); + std::vector buffer(n); + MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data())); + return buffer; + } +}; + +std::vector> +compile_hip_src(const std::vector& srcs, std::string params, const std::string& arch) +{ + hiprtc_program prog(srcs); + auto options = split_string(params, ' '); + if(enabled(MIGRAPHX_GPU_DEBUG{})) + options.push_back("-DMIGRAPHX_DEBUG"); + if(std::none_of(options.begin(), options.end(), [](const std::string& s) { + return starts_with(s, "--std=") or starts_with(s, "-std="); + })) + options.push_back("-std=c++17"); + options.push_back("-fno-gpu-rdc"); + options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3")); + options.push_back("-Wno-cuda-compat"); + options.push_back("--cuda-gpu-arch=" + arch); + prog.compile(options); + return {prog.get_code_obj()}; +} + +#else // MIGRAPHX_USE_HIPRTC + +bool is_hcc_compiler() +{ + static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc"); + return result; +} + +bool is_hip_clang_compiler() +{ + static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++"); + return result; +} + +bool has_compiler_launcher() +{ + static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER)); + return result; +} + +src_compiler assemble(src_compiler compiler) +{ + compiler.out_ext = ".S"; + compiler.flags = replace_string(compiler.flags, " -c", " -S"); + return compiler; +} + +std::vector> +compile_hip_src(const std::vector& srcs, std::string params, const std::string& arch) +{ + assert(not srcs.empty()); + if(not is_hcc_compiler() and not is_hip_clang_compiler()) + MIGRAPHX_THROW("Unknown hip compiler: " + + std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); + + if(params.find("-std=") == std::string::npos) + params += " --std=c++17"; + params += " -fno-gpu-rdc"; + params += " -c"; + if(is_hcc_compiler()) + { + params += " -amdgpu-target=" + arch; + } + else if(is_hip_clang_compiler()) + { + params += " --cuda-gpu-arch=" + arch; + params += " --cuda-device-only"; + params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " "; + } + + if(enabled(MIGRAPHX_GPU_DEBUG{})) + params += " -DMIGRAPHX_DEBUG"; + + params += " -Wno-unused-command-line-argument -Wno-cuda-compat "; + params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS); + + src_compiler compiler; + compiler.flags = params; + compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER); +#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER + if(has_compiler_launcher()) + compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER); +#endif + + if(is_hcc_compiler()) + compiler.process = [&](const fs::path& obj_path) -> fs::path { + process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} + + obj_path.string()} + .cwd(obj_path.parent_path()); + for(const auto& entry : fs::directory_iterator{obj_path.parent_path()}) + { + const auto& hsaco_path = entry.path(); + if(not fs::is_regular_file(hsaco_path)) + continue; + if(hsaco_path.extension() != ".hsaco") + continue; + return hsaco_path; + } + MIGRAPHX_THROW("Missing hsaco"); + }; + + if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) + { + for(const auto& src : srcs) + { + if(src.path.extension() != ".cpp") + continue; + std::cout << std::string(src.content.first, src.len()) << std::endl; + } + } + + if(enabled(MIGRAPHX_GPU_DUMP_ASM{})) + { + + std::cout << assemble(compiler).compile(srcs).data() << std::endl; + } + + return {compiler.compile(srcs)}; +} + +std::string enum_params(std::size_t count, std::string param) +{ + std::vector items(count); + transform(range(count), items.begin(), [&](auto i) { return param + std::to_string(i); }); + return join_strings(items, ","); +} + +#endif // MIGRAPHX_USE_HIPRTC + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e0a0775c16a66d75bf4da3e0e15bd02701ececf --- /dev/null +++ b/src/targets/gpu/compile_hip_code_object.cpp @@ -0,0 +1,171 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +template +std::string generate_index_ints(const std::vector& v) +{ + return "index_ints<" + to_string_range(v) + ">{}"; +} + +std::string generate_make_shape(const shape& s) +{ + return "make_shape(" + generate_index_ints(s.lens()) + ", " + generate_index_ints(s.strides()) + + ")"; +} + +static const char* const make_tensor_template = R"__migraphx__( +template<> +struct make_tensor<${n}> +{ + static __device__ auto apply(void* p) + { + return make_tensor_view(reinterpret_cast<${type}*>(p), make_shape(${lens}, ${strides})); + } +}; +)__migraphx__"; + +std::string generate_make_tensor(std::size_t n, const shape& s) +{ + return interpolate_string(make_tensor_template, + {{"n", std::to_string(n)}, + {"type", shape::cpp_type(s.type())}, + {"lens", generate_index_ints(s.lens())}, + {"strides", generate_index_ints(s.strides())}}); +} + +std::string generate_args_hpp(const std::vector& inputs) +{ + std::string inner; + for(std::size_t i = 0; i < inputs.size(); i++) + { + inner += generate_make_tensor(i, inputs[i]); + } + const std::string args_hpp = R"__migraphx__( +#ifndef MIGRAPHX_GUARD_AUTO_ARGS_HPP +#define MIGRAPHX_GUARD_AUTO_ARGS_HPP + +#include +#include + +namespace migraphx { + +__content__ + +} // namespace migraphx +#endif +)__migraphx__"; + return replace_string(args_hpp, "__content__", inner); +} + +const std::vector& compiler_warnings() +{ + static std::vector warnings = {"-Weverything", + "-Wno-c++98-compat", + "-Wno-c++98-compat-pedantic", + "-Wno-conversion", + "-Wno-double-promotion", + "-Wno-exit-time-destructors", + "-Wno-extra-semi", + "-Wno-extra-semi-stmt", + "-Wno-float-conversion", + "-Wno-gnu-anonymous-struct", + "-Wno-gnu-zero-variadic-macro-arguments", + "-Wno-missing-prototypes", + "-Wno-nested-anon-types", + "-Wno-padded", + "-Wno-shorten-64-to-32", + "-Wno-sign-conversion", + "-Wno-sign-compare", + "-Wno-unused-command-line-argument", + "-Wno-weak-vtables", + "-Wno-c99-extensions"}; + return warnings; +} + +void hip_compile_options::set_launch_params( + const value& v, + const std::function& compute_global, + std::size_t default_local) +{ + local = v.get("local", default_local); + if(v.contains("global")) + global = v.at("global").to(); + else + global = compute_global(local); +} + +std::function +compute_global_for(context& ctx, std::size_t n, std::size_t over) +{ + assert(over > 0); + std::size_t max_global = ctx.get_current_device().get_cu_count() * + ctx.get_current_device().get_max_workitems_per_cu(); + return [n, over, max_global](std::size_t local) { + std::size_t groups = (n + local - 1) / local; + std::size_t max_blocks = max_global / local; + std::size_t nglobal = std::min(max_blocks * over, groups) * local; + return nglobal; + }; +} + +std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) +{ + size_t block_size = 128; + while(block_size <= max_block_size and block_size <= n) + block_size *= 2; + return block_size / 2; +} + +operation compile_hip_code_object(const std::string& content, hip_compile_options options) +{ + assert(options.global > 0); + assert(options.local > 0); + assert(not options.inputs.empty()); + assert(options.inputs.size() == options.virtual_inputs.size() or + options.virtual_inputs.empty()); + std::vector srcs; + std::transform(migraphx_kernels().begin(), + migraphx_kernels().end(), + std::back_inserter(srcs), + [](auto&& p) { + auto&& name = p.first; + auto&& c = p.second; + auto path = fs::path{"migraphx"} / "kernels" / name; + return src_file{path, c}; + }); + srcs.push_back(src_file{fs::path{"main.cpp"}, + std::make_pair(content.data(), content.data() + content.size())}); + auto args_hpp = + generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs); + srcs.push_back(src_file{fs::path{"args.hpp"}, + std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())}); + options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); + options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); + options.params += " " + join_strings(compiler_warnings(), " "); + options.params += " -ftemplate-backtrace-limit=0"; + options.params += " -Werror"; + auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name()); + if(cos.size() != 1) + MIGRAPHX_THROW("No code object"); + return code_object_op{value::binary{cos.front()}, + options.kernel_name, + options.global, + options.local, + options.inputs, + options.output}; +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1ec401484391ccc8f5e037b0f61fdbb9b64ccae --- /dev/null +++ b/src/targets/gpu/compile_ops.cpp @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); + +struct precompile_op +{ + operation op = op::identity{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op")); + } + + std::string name() const { return "gpu::precompile_op"; } + + shape compute_shape(std::vector inputs, const std::vector& mods) const + { + inputs.pop_back(); + return op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +MIGRAPHX_REGISTER_OP(precompile_op); + +struct compiled_result +{ + compiler_replace replace; + instruction_ref ins; +}; + +template +void par_compile(std::size_t n, F f) +{ + if(n == 0) + return; + par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f); +} + +void compile_ops::apply(module& m) const +{ + std::vector> compiles; + + for(auto ins : iterator_for(m)) + { + if(ins->name() != "gpu::precompile_op") + continue; + operation preop = any_cast(ins->get_operator()).op; + compiles.emplace_back([=]() -> compiled_result { + return {compile(*ctx, ins, preop), ins}; + }); + } + std::vector results(compiles.size()); + par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); }); + for(const auto& cr : results) + { + cr.replace(m, cr.ins); + } +} + +} // namespace gpu + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/compiler.cpp b/src/targets/gpu/compiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d0a03d677a52d65ed94e703c77fa4486814b25d --- /dev/null +++ b/src/targets/gpu/compiler.cpp @@ -0,0 +1,39 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +auto& compiler_map() +{ + static std::unordered_map m; // NOLINT + return m; +} + +auto& compiler_op_map() +{ + static std::unordered_map m; // NOLINT + return m; +} + +void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop) +{ + compiler_map()[name] = std::move(c); + compiler_op_map()[name] = std::move(cop); +} + +bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; } +compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) +{ + return compiler_map().at(op.name())(ctx, ins, op); +} +operation +compile_op(const std::string& name, context& ctx, const std::vector& inputs, const value& v) +{ + return compiler_op_map().at(name)(ctx, inputs, v); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/concat.cpp b/src/targets/gpu/concat.cpp index a2a2276365a9e66ea1258fe3ec8228332d1b8e26..11da025b2d353d64f6d3cb5701e0d083222616d3 100644 --- a/src/targets/gpu/concat.cpp +++ b/src/targets/gpu/concat.cpp @@ -9,7 +9,7 @@ namespace gpu { shape hip_concat::compute_shape(std::vector inputs) const { inputs.pop_back(); - return op.compute_shape(inputs); + return op.normalize_compute_shape(inputs); } argument hip_concat::compute(context& ctx, diff --git a/src/targets/gpu/contiguous.cpp b/src/targets/gpu/contiguous.cpp deleted file mode 100644 index 72dfe9cf9333084d839280372a57dd129c27ce7f..0000000000000000000000000000000000000000 --- a/src/targets/gpu/contiguous.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { - -shape miopen_contiguous::compute_shape(const std::vector& inputs) const -{ - check_shapes{inputs, *this}.has(2); - return op.compute_shape({inputs.at(0)}); -} -argument miopen_contiguous::compute(context& ctx, - shape output_shape, - const std::vector& args) const -{ - assert(output_shape == args[1].get_shape()); - assert(output_shape.standard()); - (void)output_shape; - device::contiguous(ctx.get_stream().get(), args.at(1), args.at(0)); - return args.at(1); -} - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/targets/gpu/convert.cpp b/src/targets/gpu/convert.cpp index 3787d2dcde941c5f8d5c5d35a1bceec41888931c..28f8081a2a6b3d279bebe1ed022f3893fdc5ef09 100644 --- a/src/targets/gpu/convert.cpp +++ b/src/targets/gpu/convert.cpp @@ -9,7 +9,7 @@ namespace gpu { shape hip_convert::compute_shape(std::vector inputs) const { inputs.pop_back(); - check_shapes{inputs}.packed(); + check_shapes{inputs, *this}.packed(); return op.compute_shape(inputs); } diff --git a/src/targets/gpu/convolution.cpp b/src/targets/gpu/convolution.cpp index d0143053afc609a02feda00d6af8e77b3dd19c2c..d9fe8860552a72727f12b7b69cea2bab3e483f29 100644 --- a/src/targets/gpu/convolution.cpp +++ b/src/targets/gpu/convolution.cpp @@ -9,44 +9,60 @@ namespace gpu { shape miopen_convolution::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(4).standard(); - return op.compute_shape({inputs.at(0), inputs.at(1)}); + std::vector conv_inputs(inputs.begin(), inputs.begin() + 2); + check_shapes{conv_inputs, *this}.max_ndims(5); + return op.normalize_compute_shape(conv_inputs); } + +inline shape reshape_if_1d(const shape& input) +{ + shape new_shape{input}; + auto dims = new_shape.lens(); + + if(dims.size() == 3) + { + std::vector new_dims = dims; + new_dims.insert(new_dims.begin() + 2, 1); + new_shape = shape{input.type(), new_dims}; + } + return new_shape; +} + argument miopen_convolution::compute(context& ctx, const shape& output_shape, const std::vector& args) const { - auto x_desc = make_tensor(args[0].get_shape()); - auto w_desc = make_tensor(args[1].get_shape()); - auto y_desc = make_tensor(output_shape); - - float alpha = 1; - float beta = 0; - auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(), - &alpha, - x_desc.get(), - args[0].implicit(), - w_desc.get(), - args[1].implicit(), - cd.get(), - algo, - &beta, - y_desc.get(), - args[3].implicit(), - args[2].implicit(), - args[2].get_shape().bytes()); + auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape())); + auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape())); + auto y_desc = make_tensor(reshape_if_1d(output_shape)); + + if(solution_id == 0) + MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID"); + + auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(), + w_desc.get(), + args[1].implicit(), + x_desc.get(), + args[0].implicit(), + cd.get(), + y_desc.get(), + args[3].implicit(), + args[2].implicit(), + args[2].get_shape().bytes(), + solution_id); + if(status != miopenStatusSuccess) - MIGRAPHX_THROW("Running convolution failed"); + MIGRAPHX_THROW("MIOpen Convolution: running convolution failed"); return args[3]; } -shape miopen_convolution::compile(context& ctx, - const shape& output_shape, - std::vector inputs) +shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vector inputs) { shape workspace_shape{}; - auto x_desc = make_tensor(inputs[0]); - auto w_desc = make_tensor(inputs[1]); - auto y_desc = make_tensor(output_shape); + + auto x_desc = make_tensor(reshape_if_1d(inputs[0])); + auto w_desc = make_tensor(reshape_if_1d(inputs[1])); + auto y_desc = make_tensor(reshape_if_1d(output_shape)); std::size_t workspace_size = 0; miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(), @@ -79,9 +95,35 @@ shape miopen_convolution::compile(context& ctx, workspace_size, false); if(status != miopenStatusSuccess) - MIGRAPHX_THROW("Find convolution failed"); - handle = ctx.get_stream().get_miopen(); - algo = perf.fwd_algo; + MIGRAPHX_THROW("MIOpen Convolution: find convolution failed"); + algo = perf.fwd_algo; + + size_t solution_count; + + status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(), + w_desc.get(), + x_desc.get(), + cd.get(), + y_desc.get(), + &solution_count); + if(status != miopenStatusSuccess) + MIGRAPHX_THROW("MIOpen Convolution: get solution count failed"); + + std::vector solutions(solution_count); + + status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(), + w_desc.get(), + x_desc.get(), + cd.get(), + y_desc.get(), + solution_count, + &solution_count, + solutions.data()); + if(status != miopenStatusSuccess) + MIGRAPHX_THROW("MIOpen Convolution: get solution failed"); + + solution_id = solutions.front().solution_id; + return shape{shape::int8_type, {perf.memory}}; } @@ -89,13 +131,29 @@ void miopen_convolution::finalize(context& ctx, const shape& output_shape, std::vector inputs) { - if(handle == ctx.get_stream().get_miopen()) - return; - // Check that workspace hasn't changed - auto size = inputs.at(2).bytes(); - auto ws = compile(ctx, output_shape, std::move(inputs)); - if(ws.bytes() > size) - MIGRAPHX_THROW("Workspace has changed during finalization."); + if(cd == nullptr) + cd = make_conv(op); + if(solution_id == 0) + { + // Check that workspace hasn't changed + auto size = inputs.at(2).bytes(); + auto ws = find(ctx, output_shape, inputs); + if(ws.bytes() > size) + MIGRAPHX_THROW("MIOpen Convolution: workspace has changed during finalization."); + } + + auto x_desc = make_tensor(reshape_if_1d(inputs[0])); + auto w_desc = make_tensor(reshape_if_1d(inputs[1])); + auto y_desc = make_tensor(reshape_if_1d(output_shape)); + + auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(), + w_desc.get(), + x_desc.get(), + cd.get(), + y_desc.get(), + solution_id); + if(status != miopenStatusSuccess) + MIGRAPHX_THROW("MIOpen Convolution: compile solution failed"); } } // namespace gpu diff --git a/src/targets/gpu/deconvolution.cpp b/src/targets/gpu/deconvolution.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3671fe5d13abe842c9f1fe60c2d84f0bf515fd89 --- /dev/null +++ b/src/targets/gpu/deconvolution.cpp @@ -0,0 +1,120 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape miopen_deconvolution::compute_shape(const std::vector& inputs) const +{ + check_shapes{inputs, *this}.has(4).standard(); + std::vector conv_inputs(inputs.begin(), inputs.begin() + 2); + check_shapes{conv_inputs, *this}.max_ndims(5); + return op.compute_shape(conv_inputs); +} + +inline shape reshape_if_1d(const shape& input) +{ + shape new_shape{input}; + auto dims = new_shape.lens(); + + if(dims.size() == 3) + { + std::vector new_dims = dims; + new_dims.insert(new_dims.begin() + 2, 1); + new_shape = shape{input.type(), new_dims}; + } + return new_shape; +} + +argument miopen_deconvolution::compute(context& ctx, + const shape& output_shape, + const std::vector& args) const +{ + auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape())); + auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape())); + auto y_desc = make_tensor(reshape_if_1d(output_shape)); + + float alpha = 1; + float beta = 0; + auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(), + &alpha, + x_desc.get(), + args[0].implicit(), + w_desc.get(), + args[1].implicit(), + cd.get(), + algo, + &beta, + y_desc.get(), + args[3].implicit(), + args[2].implicit(), + args[2].get_shape().bytes()); + if(status != miopenStatusSuccess) + MIGRAPHX_THROW("Running deconvolution failed"); + return args[3]; +} + +shape miopen_deconvolution::compile(context& ctx, + const shape& output_shape, + std::vector inputs) +{ + shape workspace_shape{}; + auto x_desc = make_tensor(reshape_if_1d(inputs[0])); + auto w_desc = make_tensor(reshape_if_1d(inputs[1])); + auto y_desc = make_tensor(reshape_if_1d(output_shape)); + + std::size_t workspace_size = 0; + miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(), + w_desc.get(), + x_desc.get(), + cd.get(), + y_desc.get(), + &workspace_size); + workspace_shape = shape{shape::int8_type, {workspace_size}}; + + auto x = to_gpu(generate_argument(inputs[0])); + auto w = to_gpu(generate_argument(inputs[1])); + auto y = allocate_gpu(output_shape); + auto workspace = allocate_gpu(workspace_shape); + + int algo_count = 1; + miopenConvAlgoPerf_t perf; + auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(), + x_desc.get(), + x.implicit(), + w_desc.get(), + w.implicit(), + cd.get(), + y_desc.get(), + y.implicit(), + 1, + &algo_count, + &perf, + workspace.implicit(), + workspace_size, + false); + if(status != miopenStatusSuccess) + MIGRAPHX_THROW("Find deconvolution failed"); + handle = ctx.get_stream().get_miopen(); + algo = perf.fwd_algo; + return shape{shape::int8_type, {perf.memory}}; +} + +void miopen_deconvolution::finalize(context& ctx, + const shape& output_shape, + std::vector inputs) +{ + if(handle == ctx.get_stream().get_miopen()) + return; + // Check that workspace hasn't changed + auto size = inputs.at(2).bytes(); + auto ws = compile(ctx, output_shape, std::move(inputs)); + if(ws.bytes() > size) + MIGRAPHX_THROW("Workspace has changed during finalization."); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/acos.cpp b/src/targets/gpu/device/acos.cpp index edaa972470c584c640e42a91d09f26a903525f35..1b49e054afc6f7856bc90760a2e49a1d490bdc65 100644 --- a/src/targets/gpu/device/acos.cpp +++ b/src/targets/gpu/device/acos.cpp @@ -9,7 +9,7 @@ namespace device { void acos(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::acos(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::acos(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/acosh.cpp b/src/targets/gpu/device/acosh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c6be266bfd6930e04f7c4ec6501e09f31a9961f --- /dev/null +++ b/src/targets/gpu/device/acosh.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void acosh(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) { return ::acosh(to_hip_type(x)); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/add.cpp b/src/targets/gpu/device/add.cpp index 45fe71181376aad6ba729faeed6a3fc817f27e76..3848e0cc346b257860dd0786ea7dd674ee97278d 100644 --- a/src/targets/gpu/device/add.cpp +++ b/src/targets/gpu/device/add.cpp @@ -8,7 +8,7 @@ namespace device { void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; }); + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; }); } void add(hipStream_t stream, @@ -17,7 +17,8 @@ void add(hipStream_t stream, const argument& arg2, const argument& arg3) { - nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; }); + nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) + __device__ { return x + y + z; }); } } // namespace device diff --git a/src/targets/gpu/device/add_clip.cpp b/src/targets/gpu/device/add_clip.cpp index 5e3a8ea6e500a689071ea211350325e38201176a..970ee34ddf396a316d34f90bb99290c64e384695 100644 --- a/src/targets/gpu/device/add_clip.cpp +++ b/src/targets/gpu/device/add_clip.cpp @@ -10,12 +10,12 @@ void add_clip(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, - const float max, - const float min) + const argument& min_arg, + const argument& max_arg) { - nary(stream, result, arg1, arg2)([max, min](auto x, auto y) { - return std::min(std::max(min, x + y), max); - }); + nary(stream, result, arg1, arg2, min_arg, max_arg)( + [](auto x, auto y, auto min, auto max) + __device__ { return ::min(::max(min, x + y), max); }); } void add_clip(hipStream_t stream, @@ -23,12 +23,13 @@ void add_clip(hipStream_t stream, const argument& arg1, const argument& arg2, const argument& arg3, - const float max, - const float min) + const argument& min_arg, + const argument& max_arg) { - nary(stream, result, arg1, arg2, arg3)([max, min](auto x, auto y, auto z) { - return std::min(std::max(min, x + y + z), max); - }); + nary(stream, result, arg1, arg2, arg3, min_arg, max_arg)( + [](auto x, auto y, auto z, auto min, auto max) __device__ { + return ::min(::max(min, x + y + z), max); + }); } } // namespace device diff --git a/src/targets/gpu/device/add_relu.cpp b/src/targets/gpu/device/add_relu.cpp index 13b97a1a1d45a8b3a8a3ab136224c12a7214cb76..e9d60ae742b09747a980cecd3765b4f83f4f9188 100644 --- a/src/targets/gpu/device/add_relu.cpp +++ b/src/targets/gpu/device/add_relu.cpp @@ -11,8 +11,8 @@ void add_relu(hipStream_t stream, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)( - [](auto x, auto y) { return std::max(0, x + y); }); + nary(stream, result, arg1, arg2)([](auto x, auto y) + __device__ { return ::max(0, x + y); }); } void add_relu(hipStream_t stream, @@ -22,7 +22,7 @@ void add_relu(hipStream_t stream, const argument& arg3) { nary(stream, result, arg1, arg2, arg3)( - [](auto x, auto y, auto z) { return std::max(0, x + y + z); }); + [](auto x, auto y, auto z) __device__ { return ::max(0, x + y + z); }); } } // namespace device diff --git a/src/targets/gpu/device/add_sigmoid.cpp b/src/targets/gpu/device/add_sigmoid.cpp index 7dd7743761f9fae2642b5d952333f05287f047ce..180740ce1750caaf5f5445a2e4a23359f48a3e90 100644 --- a/src/targets/gpu/device/add_sigmoid.cpp +++ b/src/targets/gpu/device/add_sigmoid.cpp @@ -12,7 +12,7 @@ void add_sigmoid(hipStream_t stream, const argument& arg2) { nary(stream, result, arg1, arg2)( - [](auto x, auto y) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); }); + [](auto x, auto y) __device__ { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); }); } void add_sigmoid(hipStream_t stream, @@ -21,8 +21,9 @@ void add_sigmoid(hipStream_t stream, const argument& arg2, const argument& arg3) { - nary(stream, result, arg1, arg2, arg3)( - [](auto x, auto y, auto z) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z)))); }); + nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) __device__ { + return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z)))); + }); } } // namespace device diff --git a/src/targets/gpu/device/add_tanh.cpp b/src/targets/gpu/device/add_tanh.cpp index 47a79b303830d69382a16c5fb16725db6c4cfb6e..574badffa92ba3ef5a71f61a35f476c5cb687a47 100644 --- a/src/targets/gpu/device/add_tanh.cpp +++ b/src/targets/gpu/device/add_tanh.cpp @@ -11,7 +11,8 @@ void add_tanh(hipStream_t stream, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return ::tanh(to_hip_type(x + y)); }); + nary(stream, result, arg1, arg2)([](auto x, auto y) + __device__ { return ::tanh(to_hip_type(x + y)); }); } void add_tanh(hipStream_t stream, @@ -21,7 +22,7 @@ void add_tanh(hipStream_t stream, const argument& arg3) { nary(stream, result, arg1, arg2, arg3)( - [](auto x, auto y, auto z) { return ::tanh(to_hip_type(x + y + z)); }); + [](auto x, auto y, auto z) __device__ { return ::tanh(to_hip_type(x + y + z)); }); } } // namespace device diff --git a/src/targets/gpu/device/asin.cpp b/src/targets/gpu/device/asin.cpp index 2986f11526afa7f02b5c37f0f1ac814438c12a7b..cfc863e5834690b9d219d007350665f25557843a 100644 --- a/src/targets/gpu/device/asin.cpp +++ b/src/targets/gpu/device/asin.cpp @@ -9,7 +9,7 @@ namespace device { void asin(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::asin(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::asin(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/asinh.cpp b/src/targets/gpu/device/asinh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ddb9446f1e61dd0df91a63188fe6d887e0851ca --- /dev/null +++ b/src/targets/gpu/device/asinh.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void asinh(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) { return ::asinh(to_hip_type(x)); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/atan.cpp b/src/targets/gpu/device/atan.cpp index c0443a79629d12c77c70c2358ee316905385e0a3..f1be1bf2417f8139cafe75a82e2e1321b112e49c 100644 --- a/src/targets/gpu/device/atan.cpp +++ b/src/targets/gpu/device/atan.cpp @@ -9,7 +9,7 @@ namespace device { void atan(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::atan(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::atan(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/atanh.cpp b/src/targets/gpu/device/atanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6df334cf844cc949a525ba8e9ccac48bec228d27 --- /dev/null +++ b/src/targets/gpu/device/atanh.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void atanh(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) { return ::atanh(to_hip_type(x)); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/ceil.cpp b/src/targets/gpu/device/ceil.cpp index 66449d0abe5c5e9b98743a892c743206f7beda43..1714715d25df21d2cdfad74d57b4df7cfb2732b9 100644 --- a/src/targets/gpu/device/ceil.cpp +++ b/src/targets/gpu/device/ceil.cpp @@ -9,7 +9,7 @@ namespace device { void ceil(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::ceil(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::ceil(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/clip.cpp b/src/targets/gpu/device/clip.cpp index e972fbfc0eae01eb8c60ba39893bc137fed5e472..d20638c94c7a392badc82116c4c6769fcc4cfa95 100644 --- a/src/targets/gpu/device/clip.cpp +++ b/src/targets/gpu/device/clip.cpp @@ -9,11 +9,13 @@ namespace device { void clip(hipStream_t stream, const argument& result, const argument& arg1, - const float max, - const float min) + const argument& min_val, + const argument& max_val) { - nary(stream, result, arg1)( - [max, min](auto x) { return std::min(std::max(min, x), max); }); + + nary(stream, result, arg1, min_val, max_val)([](auto x, auto min, auto max) __device__ { + return ::min(::max(min, x), max); + }); } } // namespace device diff --git a/src/targets/gpu/device/concat.cpp b/src/targets/gpu/device/concat.cpp index a3ecbde61a0c64aaf316f8431306ab07e1409d83..8e7db2c756c69b484caee0b67ae46e7a9ddff203 100644 --- a/src/targets/gpu/device/concat.cpp +++ b/src/targets/gpu/device/concat.cpp @@ -22,7 +22,7 @@ argument concat(hipStream_t stream, auto output_shape = shape{ arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()}; auto output = argument{output_shape, args.back().data() + byte_offset}; - contiguous(stream, std::move(output), arg); + contiguous(stream, output, arg); } return args.back(); } diff --git a/src/targets/gpu/device/contiguous.cpp b/src/targets/gpu/device/contiguous.cpp index 84b46f7564d6ba1ef68560886cac9f3bae01adca..9852c70faba2fdd45a206773c965e1dcd60ad693 100644 --- a/src/targets/gpu/device/contiguous.cpp +++ b/src/targets/gpu/device/contiguous.cpp @@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void contiguous(hipStream_t stream, argument result, argument arg) +void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, std::move(result), std::move(arg))([](auto x) { return x; }); + shape s{result.get_shape().type(), result.get_shape().lens()}; + visit_all(result, arg)([&](auto output_v, auto input_v) { + hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) { + mi_gs_launch(stream, + standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; }); + }); + }); +} + +void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg) +{ + index_int nelements = result.get_shape().elements(); + visit_all(result, arg)([&](auto output_v, auto input_v) { + const auto* input = device_cast(input_v.data()); + auto* output = device_cast(output_v.data()); + gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; }); + }); +} + +void contiguous(hipStream_t stream, const argument& result, const argument& arg) +{ + if(result.get_shape() == arg.get_shape() and result.get_shape().packed()) + contiguous_packed(stream, result, arg); + else + contiguous_nonstandard(stream, result, arg); } } // namespace device diff --git a/src/targets/gpu/device/convert.cpp b/src/targets/gpu/device/convert.cpp index afa6d085e1cc725cbeaff618722078398192b6bc..74db167744b876efed308326fd376ee3997f0150 100644 --- a/src/targets/gpu/device/convert.cpp +++ b/src/targets/gpu/device/convert.cpp @@ -12,8 +12,8 @@ void convert(hipStream_t stream, const argument& result, const argument& arg) arg.visit([&](auto input) { const auto* input_ptr = device_cast(input.data()); auto* output_ptr = device_cast(output.data()); - gs_launch(stream, - result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; }); + gs_launch(stream, result.get_shape().elements())( + [=](auto i) __device__ { output_ptr[i] = input_ptr[i]; }); }); }); } diff --git a/src/targets/gpu/device/cos.cpp b/src/targets/gpu/device/cos.cpp index d0aeacfa94ae6c337825b8062aafa3735882a452..68372c9cbfdc501711ca0b270ed9399194755749 100644 --- a/src/targets/gpu/device/cos.cpp +++ b/src/targets/gpu/device/cos.cpp @@ -9,7 +9,7 @@ namespace device { void cos(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::cos(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::cos(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/cosh.cpp b/src/targets/gpu/device/cosh.cpp index aedc2f84d444372659aad394c1f0a8b441e39e09..d701f6a702239b774fe6aa175f9241b516699b19 100644 --- a/src/targets/gpu/device/cosh.cpp +++ b/src/targets/gpu/device/cosh.cpp @@ -9,7 +9,7 @@ namespace device { void cosh(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::cosh(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::cosh(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/div.cpp b/src/targets/gpu/device/div.cpp index cc02bb33386ef00b76702b0c672c8f3656efdaf3..4be5c59197a22ac6a06d87622a76af5e2a2d6358 100644 --- a/src/targets/gpu/device/div.cpp +++ b/src/targets/gpu/device/div.cpp @@ -8,7 +8,7 @@ namespace device { void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; }); + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x / y; }); } } // namespace device diff --git a/src/targets/gpu/device/equal.cpp b/src/targets/gpu/device/equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81e27af86aa31a508dc8b50ccea60a46dbbb2be1 --- /dev/null +++ b/src/targets/gpu/device/equal.cpp @@ -0,0 +1,26 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template +__device__ bool equal(T x, T y) +{ + auto eps = std::numeric_limits::epsilon(); + auto diff = x - y; + return (diff <= eps) and (diff >= -eps); +} + +void equal(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return equal(x, y); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/erf.cpp b/src/targets/gpu/device/erf.cpp index 0ce69ee2dd7c2784e6b021b26833fdf26182d2eb..c424948390980d7f439d7670f40eee36e87764f7 100644 --- a/src/targets/gpu/device/erf.cpp +++ b/src/targets/gpu/device/erf.cpp @@ -9,7 +9,7 @@ namespace device { void erf(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::erf(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::erf(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/exp.cpp b/src/targets/gpu/device/exp.cpp index c02d7e79969fd015b19d4374836eda6590cd40eb..2a67ba8fb6d2898ff15d8af4d8c59dc32e4a1aaf 100644 --- a/src/targets/gpu/device/exp.cpp +++ b/src/targets/gpu/device/exp.cpp @@ -9,7 +9,7 @@ namespace device { void exp(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::exp(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::exp(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/fill.cpp b/src/targets/gpu/device/fill.cpp new file mode 100644 index 0000000000000000000000000000000000000000..febae466fd66a53eee6d869379e6dff0f0e2994c --- /dev/null +++ b/src/targets/gpu/device/fill.cpp @@ -0,0 +1,17 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void fill(hipStream_t stream, const argument& result, unsigned long val) +{ + nary(stream, result)([=]() __device__ { return val; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/floor.cpp b/src/targets/gpu/device/floor.cpp index 499c45bb5c3f53b4ab3c3543cf964841497310b0..6854ecb257546f76f2e9fa94cccaff5da98d8926 100644 --- a/src/targets/gpu/device/floor.cpp +++ b/src/targets/gpu/device/floor.cpp @@ -9,7 +9,7 @@ namespace device { void floor(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::floor(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::floor(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/gather.cpp b/src/targets/gpu/device/gather.cpp index ed4464a77d0018bb3540adc14432d9364fb8f5d1..3687f504d14a1ae22e0a3fadf3d67af05e13b372 100644 --- a/src/targets/gpu/device/gather.cpp +++ b/src/targets/gpu/device/gather.cpp @@ -10,13 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis) +argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis) { - auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis; - auto& input_shape = arg1.get_shape(); - auto lens = input_shape.lens(); - auto axis_dim_size = lens[axis_index]; - lens[axis_index] = arg2.get_shape().elements(); + const auto& input_shape = arg1.get_shape(); + auto lens = input_shape.lens(); + auto axis_dim_size = lens[axis]; + lens[axis] = arg2.get_shape().elements(); shape out_comp_shape{result.get_shape().type(), lens}; std::size_t nelements = result.get_shape().elements(); @@ -25,12 +24,12 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg arg2.visit([&](auto indices) { const auto* indices_ptr = device_cast(indices.data()); auto* output_ptr = device_cast(output.data()); - gs_launch(stream, nelements, 256)([=](auto i) { - auto idx = out_comp.multi(i); - auto in_index = indices_ptr[idx[axis_index]]; - in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; - idx[axis_index] = in_index; - output_ptr[i] = input[idx]; + gs_launch(stream, nelements, 256)([=](auto i) __device__ { + auto idx = out_comp.multi(i); + auto in_index = indices_ptr[idx[axis]]; + in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; + idx[axis] = in_index; + output_ptr[i] = input[idx]; }); }); }); diff --git a/src/targets/gpu/device/gelu.cpp b/src/targets/gpu/device/gelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8ca2cfd4a772ccb5bbb79e7378253b5a6ee2aca --- /dev/null +++ b/src/targets/gpu/device/gelu.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +// x * 0.5 * (1.0 + erf(x / sqrt(2.0))) +template +auto gelu_fn(T x) __device__ +{ + return x * 0.5 * (1 + ::erf(x * M_SQRT1_2)); +} + +// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * pow(x, 3)))) +template +auto gelu_fn_new(T x) __device__ +{ + return 0.5 * x * (1 + tanh(sqrt(M_2_PI) * (x + 0.044715 * x * x * x))); +} + +void gelu(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) __device__ { return gelu_fn(to_hip_type(x)); }); +} + +void gelu_new(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); }); +} + +void add_gelu(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { + auto sum = to_hip_type(x + y); + return gelu_fn(sum); + }); +} + +void add_gelu_new(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { + auto sum = to_hip_type(x + y); + return gelu_fn(sum); + }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/greater.cpp b/src/targets/gpu/device/greater.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fcaf86aa5ffb9d90eb935f4e2cf26525e8d7ded --- /dev/null +++ b/src/targets/gpu/device/greater.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void greater(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x > y; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86e804b87e7046f56488c410c224d07ac703be60 --- /dev/null +++ b/src/targets/gpu/device/include/migraphx/gpu/device/float_equal.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_FLOAT_EQUAL_HPP +#define MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_FLOAT_EQUAL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template +using common_type = typename std::common_type::type; + +template {})> +__device__ bool float_equal_device(T x, T y) +{ + return std::isfinite(x) and std::isfinite(y) and + std::nextafter(x, std::numeric_limits::lowest()) <= y and + std::nextafter(x, std::numeric_limits::max()) >= y; +} + +template {})> +__device__ bool float_equal_device(T x, T y) +{ + return x == y; +} + +template +__device__ bool float_equal(T x, U y) +{ + return float_equal_device>(x, y); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp index 826b615603df06da9a71d1c13dd2aa4f2b7c2a7d..188b19f3c79ba822831b5eb901b8d3989e272b2c 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp @@ -44,7 +44,7 @@ struct index template __global__ void launcher(F f) { - index idx{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; + index idx{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT f(idx); } @@ -56,6 +56,7 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) using f_type = decltype(f); dim3 nblocks(global / local); dim3 nthreads(local); + // cppcheck-suppress UseDeviceLaunch hipLaunchKernelGGL((launcher), nblocks, nthreads, 0, stream, f); }; } @@ -74,15 +75,20 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype( inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) { - index_int groups = (n + local - 1) / local; - index_int nglobal = std::min(256, groups) * local; + index_int groups = (n + local - 1) / local; + // max possible number of blocks is set to 1B (1,073,741,824) + index_int nglobal = std::min(1073741824, groups) * local; return [=](auto f) { - launch(stream, nglobal, local)( - [=](auto idx) { idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); }); + launch(stream, nglobal, local)([=](auto idx) __device__ { + idx.global_stride(n, [&](auto i) { gs_invoke(f, i, idx); }); + }); }; } +#ifdef MIGRAPHX_USE_CLANG_TIDY +#define MIGRAPHX_DEVICE_SHARED +#else // Workaround hcc's broken tile_static macro #ifdef tile_static #undef tile_static @@ -90,6 +96,7 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) #else #define MIGRAPHX_DEVICE_SHARED __shared__ #endif +#endif } // namespace device } // namespace gpu diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp index 2819287bf8c24f24786b81fc8ea45b6693ae7d51..1b06d7da6ca8daae8891ac6c4a3d0d5f6cbbdcc4 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp @@ -31,7 +31,7 @@ struct multi_index }; template -auto deduce_for_stride(ForStride fs) -> decltype(fs(id{})); +__device__ __host__ auto deduce_for_stride(ForStride fs) -> decltype(fs(id{})); MIGRAPHX_DEVICE_CONSTEXPR multi_index<1> make_multi_index(index_int i, index_int n) { @@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape& s, index_int nlocal) { assert(s.standard); assert(s.elements() > 0); - index_int n = s.elements(); - index_int groups = (n + nlocal - 1) / nlocal; - index_int nglobal = std::min(128, groups) * nlocal; + index_int n = s.elements(); + index_int groups = (n + nlocal - 1) / nlocal; + // max possible number of blocks is set to 1B (1,073,741,824) + index_int nglobal = std::min(1073741824, groups) * nlocal; assert(groups > 0); assert(nglobal > 0); @@ -95,7 +96,7 @@ inline auto mi_launch(hipStream_t stream, const hip_shape& global, index_int auto nglobal = global.index(nglobal_multi); return [=](auto f) { - launch(stream, nglobal, nlocal)([=](auto idx) { + launch(stream, nglobal, nlocal)([=](auto idx) __device__ { auto midx = make_multi_index(global, idx.global, nglobal_multi); f(idx, midx.for_stride(global.lens)); }); diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp index f8e9740f678695fcbdcf9b2b2279080f21fc3aeb..6d55f1572da11e49175d2b0182851355f77078c6 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp @@ -36,7 +36,8 @@ auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, A MIGRAPHX_TRACE_NARY_FUNCTION shape s{result.get_shape().type(), result.get_shape().lens()}; hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) { - mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); + mi_gs_launch(stream, + standard_shape)([=](auto idx) __device__ { output[idx] = f(inputs[idx]...); }); }); } @@ -45,7 +46,7 @@ inline auto create_broadcast_index(index_int len, index_int stride) auto next_stride = stride * len; auto e_next_stride = encode_divisor(next_stride); auto e_stride = encode_divisor(stride); - return [=](auto i) { + return [=](auto i) __device__ { // ( i % next_stride) / stride return fast_div(i, e_stride) - len * fast_div(i, e_next_stride); }; @@ -61,11 +62,11 @@ auto nary_nonstandard_packed_impl(hipStream_t stream, auto arg_shape = make_array(args...).front().get_shape(); auto perm = find_permutation(arg_shape); auto s = reorder_shape(arg_shape, perm); - hip_visit_all(s, - result.reshape(reorder_shape(result.get_shape(), perm)), - args.reshape(s)...)([&](auto standard_shape, auto output, auto... inputs) { - mi_gs_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); }); - }); + hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)( + [&](auto standard_shape, auto output, auto... inputs) { + mi_gs_launch(stream, standard_shape)( + [=](auto idx) __device__ { output[idx] = f(inputs[idx]...); }); + }); } template @@ -93,7 +94,6 @@ void nary_broadcast_vec_impl( using type = typename decltype(output)::value_type; const index_int nelements = output.size() / vec_size; launch(stream, nglobal, nlocal)([=](auto idx) __device__ { - MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size]; // Load bias into LDS for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) @@ -185,7 +185,6 @@ void nary_double_broadcast_vec_impl( using type = typename decltype(output)::value_type; const index_int nelements = output.size() / vec_size; launch(stream, nglobal, nlocal)([=](auto idx) __device__ { - MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size]; // Load bias into LDS for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) @@ -274,7 +273,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments. const index_int vec_size = 4; auto data = pack_vec<4>(device_cast(inputs.data())...); auto* outp = as_vec<4>(device_cast(output.data())); - gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) { + gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) __device__ { vec out = outp[i]; data( [&](auto... xs) { @@ -295,7 +294,7 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a MIGRAPHX_TRACE_NARY_FUNCTION index_int nelements = result.get_shape().elements(); hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) { - gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); }); + gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = f(inputs[i]...); }); }); } @@ -353,7 +352,8 @@ bool broadcastable(bool& divisible_by_4, auto b_len = result.get_shape().lens()[b_idx]; auto b_stride = result.get_shape().strides()[b_idx]; assert(bshape.lens()[b_idx] == b_len); - if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero)) + if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero) and + is_divisor_encodable(b_stride * b_len)) { divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp index bf335f5127de4e69855e3001f02b8325f9a322bd..bbef1578d1a86bd26c93500ecda1da200f77cd86 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp @@ -5,84 +5,26 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -struct sum -{ - template - MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const - { - return x + y; - } -}; - -struct id -{ - template - MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const - { - return x; - } -}; - -struct mean -{ - size_t item_num = 1; - template - MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const - { - return static_cast(x / item_num); - } -}; - -struct max -{ - template - MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const - { - return x > y ? x : y; - } -}; - -struct min -{ - template - MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const - { - return x < y ? x : y; - } -}; - -struct lowest -{ - template - operator T() const - { - return device_cast(std::numeric_limits>::lowest()); - } -}; - -struct highest -{ - template - operator T() const - { - return device_cast(std::numeric_limits>::max()); - } -}; - #ifdef MIGRAPHX_NO_DPP -template -__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) +template {})> +__device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f) { - using type = decltype(f(idx.local)); + using type = decltype(f(deduce_for_stride(fs))); MIGRAPHX_DEVICE_SHARED type buffer[N]; type x = init; - idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); + fs([&](auto i) { x = op(x, f(i)); }); buffer[idx.local] = x; __syncthreads(); @@ -131,7 +73,11 @@ __device__ T dpp_mov(T& x) input.data = x; for(index_int i = 0; i < n; i++) { +#if defined(__HCC__) output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); +#else + output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); +#endif } return output.data; } @@ -148,10 +94,12 @@ __device__ void dpp_reduce(T& in, Op op) in = op(in, out); out = dpp_mov(in); in = op(in, out); +#if __AMDGCN_WAVEFRONT_SIZE == 64 out = dpp_mov(in); in = op(in, out); out = dpp_mov(in); in = op(in, out); +#endif } __device__ inline void dpp_reduce(float& x, sum) @@ -168,9 +116,11 @@ __device__ inline void dpp_reduce(float& x, sum) "s_nop 1\n" "v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc\n" "s_nop 1\n" +#if __AMDGCN_WAVEFRONT_SIZE == 64 "v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa\n" "s_nop 1\n" "v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc\n" +#endif "s_nop 1\n" : "=v"(x) : "0"(x)); @@ -185,27 +135,33 @@ template {})> __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f) { - using type = decltype(f(deduce_for_stride(fs))); - MIGRAPHX_DEVICE_SHARED type buffer[N / 64]; + +#if __AMDGCN_WAVEFRONT_SIZE == 32 + constexpr index_int nthreads = 16; +#else + constexpr index_int nthreads = 64; +#endif + using type = decltype(f(deduce_for_stride(fs))); + MIGRAPHX_DEVICE_SHARED type buffer[N / nthreads]; type x = init; fs([&](auto i) { x = op(x, f(i)); }); dpp_reduce(x, op); - const auto ldsidx = idx.local / 64; - if((idx.local % 64) == 63) + const auto ldsidx = idx.local / nthreads; + if((idx.local % nthreads) == nthreads - 1) { buffer[ldsidx] = x; } __syncthreads(); type y = init; - for(index_int i = 0; i < idx.nlocal() / 64; i++) + for(index_int i = 0; i < idx.nlocal() / nthreads; i++) { y = op(y, buffer[i]); } return y; } - +#endif template __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) { @@ -216,8 +172,6 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) return block_reduce( idx, op, init, midx.for_stride(fs), [&](auto mi) __device__ { return f(mi[0]); }); } - -#endif constexpr index_int compute_block_size(index_int n, index_int max_block_size) { size_t block_size = 64; @@ -226,6 +180,23 @@ constexpr index_int compute_block_size(index_int n, index_int max_block_size) return block_size; } +inline std::vector get_reduce_lens(const std::vector& input_lens, + const std::vector& output_lens) +{ + std::vector reduce_lens; + std::transform(output_lens.begin(), + output_lens.end(), + input_lens.begin(), + std::back_inserter(reduce_lens), + [](auto x, auto y) -> index_int { + if(x == y) + return 1; + else + return y; + }); + return reduce_lens; +} + template void reduce_multi_impl(hipStream_t stream, const argument& result, @@ -293,29 +264,19 @@ void reduce(hipStream_t stream, { auto&& output_shape = result.get_shape(); auto&& input_shape = arg.get_shape(); - assert(output_shape.lens().size() == input_shape.lens().size()); + auto input_lens = input_shape.lens(); + auto output_lens = output_shape.lens(); + assert(output_lens.size() == input_lens.size()); if(input_shape.standard() and output_shape.standard() and - output_shape.lens().back() != input_shape.lens().back() and - std::equal(output_shape.lens().begin(), - std::prev(output_shape.lens().end()), - input_shape.lens().begin())) + output_lens.back() != input_lens.back() and + std::equal(output_lens.begin(), std::prev(output_lens.end()), input_lens.begin())) { reduce_standard_impl( - stream, result, arg, op, init, read_input, read_output, input_shape.lens().back()); + stream, result, arg, op, init, read_input, read_output, input_lens.back()); } else { - std::vector reduce_lens; - std::transform(output_shape.lens().begin(), - output_shape.lens().end(), - input_shape.lens().begin(), - std::back_inserter(reduce_lens), - [](auto x, auto y) -> index_int { - if(x == y) - return 1; - else - return y; - }); + std::vector reduce_lens = get_reduce_lens(input_lens, output_lens); shape reduce_slice{output_shape.type(), reduce_lens}; reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice); } diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/reduce_ops.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/reduce_ops.hpp new file mode 100755 index 0000000000000000000000000000000000000000..8e5a0cde1cbda568cd5fcbf16a792618195e715f --- /dev/null +++ b/src/targets/gpu/device/include/migraphx/gpu/device/reduce_ops.hpp @@ -0,0 +1,88 @@ +#ifndef MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP +#define MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +struct sum +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return x + y; + } +}; + +struct product +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return x * y; + } +}; + +struct id +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const + { + return x; + } +}; + +struct mean +{ + size_t item_num = 1; + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const + { + return x / static_cast(item_num); + } +}; + +struct max +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return (x > y) ? x : y; + } +}; + +struct min +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return (x < y) ? x : y; + } +}; + +struct lowest +{ + template + __device__ __host__ operator T() const + { + return device_cast(std::numeric_limits>::lowest()); + } +}; + +struct highest +{ + template + __device__ __host__ operator T() const + { + return device_cast(std::numeric_limits>::max()); + } +}; + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b3b404d9bcc9f110f73790a3b60378ac1a573ff --- /dev/null +++ b/src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp @@ -0,0 +1,66 @@ +#ifndef MIGRAPHX_GUARD_DEVICE_SCAN_HPP +#define MIGRAPHX_GUARD_DEVICE_SCAN_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template {})> +__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output) +{ + using type = decltype(input(deduce_for_stride(fs))); + MIGRAPHX_DEVICE_SHARED type buffer[N]; + type x = init; + fs([&](auto i) { + if(idx.local == 0) + buffer[idx.local] = op(input(i), x); + else + buffer[idx.local] = input(i); + __syncthreads(); + for(index_int s = 1; s < idx.nlocal(); s *= 2) + { + if(idx.local + s < idx.nlocal()) + { + buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]); + } + __syncthreads(); + } + x = buffer[idx.nlocal() - 1]; + output(i, buffer[idx.local]); + }); +} + +template +__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output) +{ + block_scan( + idx, + op, + init, + [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); }, + input, + output); +} + +template +constexpr auto reverse_scan(index_int n, F f) +{ + return [=](auto i, auto&&... xs) { return f(n - i - 1, xs...); }; +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_DEVICE_SCAN_HPP diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp index 03d81fdc148052c091c8a59be849d403cdb944e7..6e73168336d1c4f5daf407fcd4f68d129f47f5a0 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp @@ -103,19 +103,19 @@ host_type* host_cast(T* x) } template -device_type device_cast(const T& x) +__device__ __host__ device_type device_cast(const T& x) { return reinterpret_cast&>(x); } template -device_type* device_cast(T* x) +__device__ __host__ device_type* device_cast(T* x) { return reinterpret_cast*>(x); } template -tensor_view> device_cast(tensor_view x) +__device__ __host__ tensor_view> device_cast(tensor_view x) { return {x.get_shape(), reinterpret_cast*>(x.data())}; } @@ -129,6 +129,21 @@ __device__ __host__ T to_hip_type(T x) // Hip doens't support __fp16 inline __device__ __host__ float to_hip_type(gpu_half x) { return x; } +#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ + template \ + struct trait : std::trait \ + { \ + }; \ + \ + template <> \ + struct trait : std::true_type \ + { \ + }; + +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16) + } // namespace device } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp index 1ee52fbd17a928d303476f1236fe0a48cfd488a1..4c82bd6c186c32aa4c5dfd191dd462c19b514dc5 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp @@ -14,32 +14,27 @@ constexpr void visit_tensor_size(index_int n, F f) { switch(n) { - case 1: - { + case 1: { f(std::integral_constant{}); break; } - case 2: - { + case 2: { f(std::integral_constant{}); break; } - case 3: - { + case 3: { f(std::integral_constant{}); break; } - case 4: - { + case 4: { f(std::integral_constant{}); break; } - case 5: - { + case 5: { f(std::integral_constant{}); break; } - default: throw std::runtime_error("Unknown tensor size"); + default: throw std::runtime_error("Tensor dims " + std::to_string(n) + " out of range"); } } @@ -51,6 +46,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape()) return x.get_shape(); } +template +struct is_hip_type : std::false_type +{ +}; + +template <> +struct is_hip_type : std::true_type +{ +}; +template <> +struct is_hip_type : std::true_type +{ +}; +template <> +struct is_hip_type : std::true_type +{ +}; +template <> +struct is_hip_type : std::true_type +{ +}; +template <> +struct is_hip_type : std::true_type +{ +}; + +template {})> +void hip_visitor_invoke(T as, V&& v) +{ + v(as); +} + +template {})> +void hip_visitor_invoke(T, V&&) +{ + MIGRAPHX_THROW(std::string("Unsupported data type on GPU: ") + __PRETTY_FUNCTION__); +} + +template +auto hip_visitor(V v) +{ + return [=](auto as) { hip_visitor_invoke(as, v); }; +} + template void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) { @@ -62,8 +101,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) static_cast(get_shape(xs).lens().size())...}; if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); })) MIGRAPHX_THROW("Ranks must be the same"); - visit_tensor_size(s.lens().size(), - [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); }); + visit_tensor_size(s.lens().size(), [&](auto ndim) { + s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); })); + }); } template @@ -136,7 +176,13 @@ template auto hip_vec_visit_all(T&& x, Ts&&... xs) { return [&](auto f) { - hip_visit_all_impl(get_shape(x), + auto sx = get_shape(x); + auto lens = sx.lens(); + assert(lens.back() % N == 0); + assert(sx.strides().back() == 1); + lens.back() /= N; + shape vec_sx{sx.type(), lens}; + hip_visit_all_impl(vec_sx, make_hip_convert([](auto* p) { return as_vec(device_cast(p)); }), f, x, diff --git a/src/targets/gpu/device/int8_gemm_pack.cpp b/src/targets/gpu/device/int8_gemm_pack.cpp index 090d9d7f05325c2ae1b129ab0330cceae5645e1b..b9a182986619666a473075414053b82bd62782fc 100644 --- a/src/targets/gpu/device/int8_gemm_pack.cpp +++ b/src/targets/gpu/device/int8_gemm_pack.cpp @@ -24,7 +24,7 @@ void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument auto* in_ptr = device_cast(input.data()); visit_tensor_size(out_lens.size(), [&](auto out_dim) { hip_tensor_descriptor desc(comp_shape); - gs_launch(stream, nelements, 256)([=](auto ii) { + gs_launch(stream, nelements, 256)([=](auto ii) __device__ { const size_t nb = 4; auto idx = desc.multi(ii); std::size_t i_m = idx[dim_1]; @@ -55,7 +55,7 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument auto* in_ptr = device_cast(input.data()); visit_tensor_size(out_lens.size(), [&](auto out_dim) { hip_tensor_descriptor desc(comp_shape); - gs_launch(stream, nelements, 256)([=](auto ii) { + gs_launch(stream, nelements, 256)([=](auto ii) __device__ { const size_t nb = 4; auto idx = desc.multi(ii); std::size_t i_n = idx[dim_1]; diff --git a/src/targets/gpu/device/layernorm.cpp b/src/targets/gpu/device/layernorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c7f9d0a19e0cd569aa61ffdcadc90f0382e881f --- /dev/null +++ b/src/targets/gpu/device/layernorm.cpp @@ -0,0 +1,222 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +#ifndef MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC +#if __AMDGCN_WAVEFRONT_SIZE == 32 +#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 1 +#else +#define MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC 0 +#endif +#endif + +template +struct vector_type +{ +}; + +template +struct vector_type> +{ + using type = T; +}; + +template +using vector_type_t = typename vector_type::type; + +template +struct vector_size : std::integral_constant +{ +}; + +template +struct vector_size> : std::integral_constant +{ +}; + +template +__device__ auto vec_transform(T x, F f) +{ + return f(x); +} + +template +__device__ auto vec_transform(vec x, F f) +{ + vec y = x; + // cppcheck-suppress useStlAlgorithm + for(index_int k = 0; k < N; k++) + y[k] = f(x[k]); + return y; +} + +template +__device__ auto vec_reduce(T x, U, Op) +{ + return x; +} + +template +__device__ auto vec_reduce(vec x, U init, Op op) +{ + T r = init; + for(index_int k = 0; k < N; k++) + r = op(r, x[k]); + return r; +} + +template +__device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f) +{ + auto r = block_reduce(idx, op, init, n, f); + return vec_reduce(r, 0, op); +} + +template +__device__ void layernorm(index_int i, + index idx, + std::size_t block_size_div, + index_int relements, + Input input, + Output output) +{ + using value_type = decltype(input(idx.local)); + const auto relements_v = relements / vector_size{}; + const auto out_idx = fast_div(i, block_size_div); + const auto base_idx = out_idx * relements_v; + const auto input_idx = base_idx + idx.local; + const bool in_range = idx.local < relements_v; + + auto mean = [&](auto z) { + auto m = auto_block_reduce( + idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) / + value_type(relements); +#if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC + __builtin_amdgcn_s_barrier(); +#endif + return m; + }; + + // m = x - mean(x) + value_type x = in_range ? input(input_idx) : 0; + value_type m = x - mean(x); + + // mean(m ^ 2) + 1e-12 + value_type r = mean(m * m) + value_type(1e-12); + + // m * rsqrt(mean(m ^ 2) + 1e-12) + if(in_range) + output(input_idx, m * vec_transform(r, &rsqrt)); +} + +// m = x - mean(x) +// m / sqrt(mean(m ^ 2) + 1e-12) + +template +void layernorm_vec_impl(hipStream_t stream, + index_int nelements, + index_int relements, + Input in, + Output out, + const argument& result, + const Arguments&... args) +{ + hip_vec_visit_all(result, args...)([&](auto output, auto... inputs) { + const auto relements_v = relements / N; + const std::size_t max_block_size = 256; + const std::size_t block_size = compute_block_size(relements_v, max_block_size); + const std::size_t block_size_div = encode_divisor(block_size); + assert(relements_v <= block_size); + + gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { + layernorm( + i, + idx, + block_size_div, + relements, + [&](auto input_idx) { return in(inputs.data()[input_idx]...); }, + [&](auto input_idx, auto x) { + out(x, output.data()[input_idx], inputs.data()[input_idx]...); + }); + }); + }); +} + +template +void layernorm_impl(hipStream_t stream, + index_int nelements, + index_int relements, + Input in, + Output out, + const argument& result, + const Arguments&... args) +{ + hip_visit_all(result, args...)([&](auto output, auto... inputs) { + const std::size_t max_block_size = 256; + const std::size_t block_size = compute_block_size(relements, max_block_size); + const std::size_t block_size_div = encode_divisor(block_size); + assert(relements <= block_size); + + gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { + layernorm( + i, + idx, + block_size_div, + relements, + [&](auto input_idx) { return in(inputs.data()[input_idx]...); }, + [&](auto input_idx, auto x) { + out(x, output.data()[input_idx], inputs.data()[input_idx]...); + }); + }); + }); +} + +template +auto layernorm_fusion(hipStream_t stream, + const argument& result, + const argument& arg1, + const Arguments&... args) +{ + return [=](auto input, auto output) { + auto relements = arg1.get_shape().lens().back(); + auto nelements = result.get_shape().elements() / relements; + auto output_shape = result.get_shape(); + auto reduce_output_lens(output_shape.lens()); + reduce_output_lens.back() = 1; + + if((relements % 4) == 0) + layernorm_vec_impl<4>( + stream, nelements, relements, input, output, result, arg1, args...); + else if(relements < 256) + layernorm_impl(stream, nelements, relements, input, output, result, arg1, args...); + else + MIGRAPHX_THROW("No kernel for layernorm"); + }; +} + +void triadd_layernorm(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2, + const argument& arg3) +{ + layernorm_fusion(stream, result, arg1, arg2, arg3)( + [](auto x, auto y, auto z) { return x + y + z; }, [](auto x, auto& y, auto...) { y = x; }); +} + +void layernorm(hipStream_t stream, const argument& result, const argument& arg1) +{ + layernorm_fusion(stream, result, arg1)([](auto x) { return x; }, + [](auto x, auto& y, auto) { y = x; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/less.cpp b/src/targets/gpu/device/less.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b16a89bea34211f9b1d2b76f43a96165f970895 --- /dev/null +++ b/src/targets/gpu/device/less.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void less(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x < y; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/log.cpp b/src/targets/gpu/device/log.cpp index 8732f17894fb3039f99e8d2cd4a76c9f8e80fbf0..cc65d66a1187b90c1586fae49cbbf83a83d11a41 100644 --- a/src/targets/gpu/device/log.cpp +++ b/src/targets/gpu/device/log.cpp @@ -9,7 +9,7 @@ namespace device { void log(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::log(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::log(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/logical_and.cpp b/src/targets/gpu/device/logical_and.cpp new file mode 100644 index 0000000000000000000000000000000000000000..243d6b237338d75928316a28cbadae22e92dca64 --- /dev/null +++ b/src/targets/gpu/device/logical_and.cpp @@ -0,0 +1,22 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_and(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2) +{ + nary(stream, result, arg1, arg2)( + [](auto x, auto y) __device__ { return static_cast(x) and static_cast(y); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/logical_or.cpp b/src/targets/gpu/device/logical_or.cpp new file mode 100644 index 0000000000000000000000000000000000000000..af58c71703f1dcb547acf45c3b98a8c39cf82dfd --- /dev/null +++ b/src/targets/gpu/device/logical_or.cpp @@ -0,0 +1,22 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_or(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2) +{ + nary(stream, result, arg1, arg2)( + [](auto x, auto y) __device__ { return static_cast(x) or static_cast(y); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/logical_xor.cpp b/src/targets/gpu/device/logical_xor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a251135d62dae36010c56e69284d6f083aac9c06 --- /dev/null +++ b/src/targets/gpu/device/logical_xor.cpp @@ -0,0 +1,22 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_xor(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2) +{ + nary(stream, result, arg1, arg2)( + [](auto x, auto y) __device__ { return static_cast(x) xor static_cast(y); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/logsoftmax.cpp b/src/targets/gpu/device/logsoftmax.cpp index 596d3e6089acc8b9f3866d70688e0bf6424ad716..1e6e85a8fc8834230f4f9d4a62b761d6fadc15fd 100644 --- a/src/targets/gpu/device/logsoftmax.cpp +++ b/src/targets/gpu/device/logsoftmax.cpp @@ -11,11 +11,10 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis) +void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) { - auto lens = result.get_shape().lens(); - auto batch_lens = lens; - index_int batch_item_num = lens[axis]; + auto batch_lens = result.get_shape().lens(); + index_int batch_item_num = batch_lens[axis]; batch_lens[axis] = 1; migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; @@ -44,7 +43,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max; - idx.local_stride(batch_item_num, [&](auto j) { + idx.local_stride(batch_item_num, [&](auto j) __device__ { data_idx[axis] = j; output[data_idx] = input[data_idx] - log_batch_sum; }); diff --git a/src/targets/gpu/device/max.cpp b/src/targets/gpu/device/max.cpp index 3a21cd165c14e54d8555848f54ed8f86e9dfbbbf..cd47e16cade1e99c439a82c517f1103dfbdd5608 100644 --- a/src/targets/gpu/device/max.cpp +++ b/src/targets/gpu/device/max.cpp @@ -10,7 +10,7 @@ namespace device { void max(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { nary(stream, result, arg1, arg2)( - [](auto x, auto y) { return std::max(to_hip_type(x), to_hip_type(y)); }); + [](auto x, auto y) __device__ { return ::max(to_hip_type(x), to_hip_type(y)); }); } } // namespace device diff --git a/src/targets/gpu/device/min.cpp b/src/targets/gpu/device/min.cpp index 4fb00f452f3ec4368f2e9be87bdb69eb48ede66d..dee081ed8b4bdb3159fb1c69f9208d94cfa7676e 100644 --- a/src/targets/gpu/device/min.cpp +++ b/src/targets/gpu/device/min.cpp @@ -10,7 +10,7 @@ namespace device { void min(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { nary(stream, result, arg1, arg2)( - [](auto x, auto y) { return std::min(to_hip_type(x), to_hip_type(y)); }); + [](auto x, auto y) __device__ { return ::min(to_hip_type(x), to_hip_type(y)); }); } } // namespace device diff --git a/src/targets/gpu/device/mul.cpp b/src/targets/gpu/device/mul.cpp index 66610022d84a80a0ef2e26e237130245bed0a0e2..38ea36d87743264020d3df5589399fab70e7e4d6 100644 --- a/src/targets/gpu/device/mul.cpp +++ b/src/targets/gpu/device/mul.cpp @@ -8,7 +8,7 @@ namespace device { void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return x * y; }); + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; }); } void mul(hipStream_t stream, @@ -17,7 +17,8 @@ void mul(hipStream_t stream, const argument& arg2, const argument& arg3) { - nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x * y * z; }); + nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) + __device__ { return x * y * z; }); } } // namespace device diff --git a/src/targets/gpu/device/mul_add.cpp b/src/targets/gpu/device/mul_add.cpp index d8deaefc3dddaff632c612b165826e544163276c..0f7acb688d73a16eb9922272467a35d8cfc1a6a7 100644 --- a/src/targets/gpu/device/mul_add.cpp +++ b/src/targets/gpu/device/mul_add.cpp @@ -12,7 +12,8 @@ void mul_add(hipStream_t stream, const argument& arg2, const argument& arg3) { - nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; }); + nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) + __device__ { return a * x + b; }); } } // namespace device diff --git a/src/targets/gpu/device/mul_add_relu.cpp b/src/targets/gpu/device/mul_add_relu.cpp index d673f8f07db88f7aa0b690a4de9be798ee4a478e..8edcf5575500d25564753921863901dfb8d3e9bc 100644 --- a/src/targets/gpu/device/mul_add_relu.cpp +++ b/src/targets/gpu/device/mul_add_relu.cpp @@ -13,7 +13,7 @@ void mul_add_relu(hipStream_t stream, const argument& arg3) { nary(stream, result, arg1, arg2, arg3)( - [](auto x, auto a, auto b) { return std::max(0, a * x + b); }); + [](auto x, auto a, auto b) __device__ { return ::max(0, a * x + b); }); } } // namespace device diff --git a/src/targets/gpu/device/multinomial.cpp b/src/targets/gpu/device/multinomial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc4a2cd805315dd7bb1d6f5b48aac7797b821a68 --- /dev/null +++ b/src/targets/gpu/device/multinomial.cpp @@ -0,0 +1,66 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template +constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value) +{ + Iterator it; + typename std::iterator_traits::difference_type count; + typename std::iterator_traits::difference_type step; + count = std::distance(first, last); + + while(count > 0) + { + it = first; + step = count / 2; + std::advance(it, step); + if(!(value < *it)) + { + first = ++it; + count -= step + 1; + } + else + count = step; + } + return first; +} + +void multinomial(hipStream_t stream, + const argument& result, + const argument& arg0, + const argument& arg1) +{ + size_t batch_size = arg0.get_shape().lens().front(); + size_t class_size = arg0.get_shape().lens().back(); + size_t sample_size = result.get_shape().lens().back(); + + hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) { + result.visit([&](auto out) { + hip_visit_views(out)([&](auto output) { + gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ { + auto idx = output.get_shape().multi(i); + auto cdf_begin = cdf.begin() + (idx.front() * class_size); + auto cdf_end = cdf_begin + class_size; + auto sample_iter = + upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end))); + output[i] = std::distance(cdf_begin, sample_iter); + }); + }); + }); + }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/nonzero.cpp b/src/targets/gpu/device/nonzero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..60f72fc2fe7ade4a57b4a3501fcf25daaa8d38e1 --- /dev/null +++ b/src/targets/gpu/device/nonzero.cpp @@ -0,0 +1,54 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument nonzero(hipStream_t stream, const argument& result, const argument& arg_data) +{ + auto s = arg_data.get_shape(); + auto elem_num = s.elements(); + auto out_elem_num = result.get_shape().elements(); + + // call the prefix_sum function to do a prefix_sum to compute + // index in the output. Only 1 block can be used since we have + // only one prefix sum + const index_int block_size = 256; + hip_visit_all(arg_data, s)([&](auto input, auto si) { + const auto* in_ptr = device_cast(input.data()); + auto* ptr = result.cast(); + gs_launch(stream, block_size, block_size)([=](auto, auto idx) __device__ { + // fill all output to 0 first + idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; }); + + block_scan( + idx, + sum{}, + 0, + elem_num, + [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; }, + [&](auto j, auto x) { + auto out_loc = x - 1; + if(float_equal(in_ptr[j], 0)) + return; + + auto index = si.multi(j); + for(size_t k = 0; k < index.size(); ++k) + { + ptr[k * elem_num + out_loc] = index[k]; + } + }); + }); + }); + + return result; +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/pad.cpp b/src/targets/gpu/device/pad.cpp index 8ed27424f0692337a46ff6bc1d8a1e96c3fd47af..0f66d86bd92703fd6a596c317e9b59a516b79f0e 100644 --- a/src/targets/gpu/device/pad.cpp +++ b/src/targets/gpu/device/pad.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -18,17 +19,13 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector hip_visit_all(result, arg1)([&](auto output, auto input) { using type = typename decltype(output)::value_type; using hip_index = typename decltype(output)::hip_index; - type device_val = value; - if(float_equal(value, std::numeric_limits::lowest())) - { - device_val = device_cast(std::numeric_limits::lowest()); - } - gs_launch(stream, - result.get_shape().elements())([=](auto i) { output.data()[i] = device_val; }); + type device_val = pad_clamp>(value); + gs_launch(stream, result.get_shape().elements())( + [=](auto i) __device__ { output.data()[i] = device_val; }); hip_index offsets; std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin()); - gs_launch(stream, nelements)([=](auto i) { + gs_launch(stream, nelements)([=](auto i) __device__ { auto idx = input.get_shape().multi(i); for(std::size_t j = 0; j < offsets.size(); j++) { diff --git a/src/targets/gpu/device/pow.cpp b/src/targets/gpu/device/pow.cpp index 5578ee8ad7cceab63f0e71b33dd22c50b82c9836..e6a79d5aef79f8b0aa750ddcb08b8c4006aa0ea6 100644 --- a/src/targets/gpu/device/pow.cpp +++ b/src/targets/gpu/device/pow.cpp @@ -9,7 +9,7 @@ namespace device { void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { nary(stream, result, arg1, arg2)( - [](auto b, auto e) { return ::pow(to_hip_type(b), to_hip_type(e)); }); + [](auto b, auto e) __device__ { return ::pow(to_hip_type(b), to_hip_type(e)); }); } } // namespace device diff --git a/src/targets/gpu/device/prefix_scan_sum.cpp b/src/targets/gpu/device/prefix_scan_sum.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f8d2c569f337dc00d3bfc1bf2f876db79c0c1a6 --- /dev/null +++ b/src/targets/gpu/device/prefix_scan_sum.cpp @@ -0,0 +1,120 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void prefix_scan_sum(hipStream_t stream, + const argument& result, + const argument& arg, + int32_t axis, + bool exclusive, + bool reverse) +{ + const index_int max_block_size = 256; + const index_int n = arg.get_shape().lens()[axis]; + auto rlens = result.get_shape().lens(); + rlens[axis] = 1; + + hip_visit_all(result, arg, result.get_shape().with_lens(rlens))( + [=](auto output, auto input, auto rshape) { + const index_int block_size = compute_block_size(rshape.elements(), max_block_size); + if(reverse and exclusive) + { + gs_launch(stream, rshape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + const auto ridx = rshape.multi(i / block_size); + auto compute_idx = [&](auto j) { + auto k = ridx; + k[axis] = j; + return k; + }; + block_scan( + idx, + sum{}, + 0, + n, + reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }), + reverse_scan(n, [&](auto j, auto x) { + if(j == n - 1) + output[compute_idx(j)] = 0; + if(j > 0) + output[compute_idx(j - 1)] = x; + })); + }); + } + else if(reverse) + { + gs_launch(stream, rshape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + const auto ridx = rshape.multi(i / block_size); + auto compute_idx = [&](auto j) { + auto k = ridx; + k[axis] = j; + return k; + }; + block_scan( + idx, + sum{}, + 0, + n, + reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }), + reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; })); + }); + } + else if(exclusive) + { + gs_launch(stream, rshape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + const auto ridx = rshape.multi(i / block_size); + auto compute_idx = [&](auto j) { + auto k = ridx; + k[axis] = j; + return k; + }; + block_scan( + idx, + sum{}, + 0, + n, + [&](auto j) { return input[compute_idx(j)]; }, + [&](auto j, auto x) { + auto k = j + 1; + if(j == 0) + output[compute_idx(0)] = 0; + if(k < n) + output[compute_idx(k)] = x; + }); + }); + } + else + { + gs_launch(stream, rshape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + const auto ridx = rshape.multi(i / block_size); + auto compute_idx = [&](auto j) { + auto k = ridx; + k[axis] = j; + return k; + }; + block_scan( + idx, + sum{}, + 0, + n, + [&](auto j) { return input[compute_idx(j)]; }, + [&](auto j, auto x) { output[compute_idx(j)] = x; }); + }); + } + }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/prelu.cpp b/src/targets/gpu/device/prelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4be841f6f4461294de6499a70f11148f9ca0a90 --- /dev/null +++ b/src/targets/gpu/device/prelu.cpp @@ -0,0 +1,18 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void prelu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) +{ + nary(stream, result, arg1, arg2)([](auto x, auto slope) + __device__ { return ((x < 0) ? (x * slope) : x); }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/recip.cpp b/src/targets/gpu/device/recip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..548a32da6e074b32c422d2eb0e1a0f3791582c5b --- /dev/null +++ b/src/targets/gpu/device/recip.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void recip(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) __device__ { return 1 / x; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/reduce_prod.cpp b/src/targets/gpu/device/reduce_prod.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4fcf08593386b22c2b5ed8500e548990fc46ef8 --- /dev/null +++ b/src/targets/gpu/device/reduce_prod.cpp @@ -0,0 +1,18 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void reduce_prod(hipStream_t stream, const argument& result, const argument& arg) +{ + + reduce(stream, result, arg, product{}, 1, id{}, id{}); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/relu.cpp b/src/targets/gpu/device/relu.cpp index 2a438e600b6b9c208d44c1cb21c8eafdfe80b48a..c2db092f662e936ff28dbe790efa0347196362ca 100644 --- a/src/targets/gpu/device/relu.cpp +++ b/src/targets/gpu/device/relu.cpp @@ -8,7 +8,7 @@ namespace device { void relu(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return std::max(0, x); }); + nary(stream, result, arg)([](auto x) __device__ { return ::max(0, x); }); } } // namespace device diff --git a/src/targets/gpu/device/reverse.cpp b/src/targets/gpu/device/reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d97f5a19ce3794e93551d82f3d28c589a3cd3d07 --- /dev/null +++ b/src/targets/gpu/device/reverse.cpp @@ -0,0 +1,43 @@ +#include "migraphx/gpu/device/visit.hpp" +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument +reverse(hipStream_t stream, argument result, argument arg1, const std::vector& axes) +{ + auto s = arg1.get_shape(); + // auto lens = s.lens(); + std::vector axis_len(axes.begin(), axes.end()); + shape sa{shape::float_type, axis_len}; + std::size_t nelements = s.elements(); + visit_all(result, arg1)([&](auto output1, auto input1) { + hip_visit_views(output1, input1, s)([&](auto output, auto input, auto hs) { + hip_visit_views(sa)([&](auto daxes) { + auto lens = hs.lens; + gs_launch(stream, nelements)([=](auto i) __device__ { + auto idx = hs.multi(i); + auto in_idx = idx; + for(auto axis : daxes.lens) + in_idx[axis] = lens[axis] - 1 - idx[axis]; + output[idx] = input[in_idx]; + }); + }); + }); + }); + + return result; +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/rnn_variable_seq_lens.cpp b/src/targets/gpu/device/rnn_variable_seq_lens.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8540fd3822da59f0b517c04bdbaf0681e6ac891b --- /dev/null +++ b/src/targets/gpu/device/rnn_variable_seq_lens.cpp @@ -0,0 +1,117 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void rnn_var_sl_shift_sequence(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl) +{ + auto output_shape = result.get_shape(); + int64_t max_len = output_shape.lens()[0]; + visit_all(result, arg_hs)([&](auto output, auto input) { + const auto* in_data = device_cast(input.data()); + auto* out_data = device_cast(output.data()); + auto out_s = make_hip_shape<3>(output_shape); + arg_sl.visit([&](auto sl) { + const auto* sl_data = device_cast(sl.data()); + gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ { + auto idx = out_s.multi(i); + auto t = idx[0]; + auto b = idx[1]; + auto l = sl_data[b]; + auto val = in_data[0]; + val = 0; + if(t >= max_len - l) + { + auto in_idx = idx; + in_idx[0] -= (max_len - l); + val = in_data[out_s.index(in_idx)]; + } + out_data[i] = val; + }); + }); + }); +} + +void rnn_var_sl_shift_output(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl, + bool is_reverse) +{ + auto output_shape = result.get_shape(); + int64_t max_len = output_shape.lens()[0]; + visit_all(result, arg_hs)([&](auto output, auto input) { + const auto* in_data = device_cast(input.data()); + auto* out_data = device_cast(output.data()); + auto out_s = make_hip_shape<4>(output_shape); + arg_sl.visit([&](auto sl) { + const auto* sl_data = device_cast(sl.data()); + gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ { + auto idx = out_s.multi(i); + auto t = idx[0]; + auto d = idx[1]; + auto b = idx[2]; + auto l = sl_data[b]; + auto val = in_data[0]; + val = 0; + if(t < l) + { + int offset = (d == 1 or is_reverse) ? 1 : 0; + auto in_idx = idx; + in_idx[0] += offset * (max_len - l); + val = in_data[out_s.index(in_idx)]; + } + out_data[i] = val; + }); + }); + }); +} + +void rnn_var_sl_last_output(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl, + bool is_reverse) +{ + auto input_shape = arg_hs.get_shape(); + auto out_comp_lens = input_shape.lens(); + out_comp_lens[0] = 1; + shape out_comp_shape{input_shape.type(), out_comp_lens}; + + visit_all(result, arg_hs)([&](auto output, auto input) { + const auto* in_data = device_cast(input.data()); + auto* out_data = device_cast(output.data()); + arg_sl.visit([&](auto sl) { + const auto* sl_data = device_cast(sl.data()); + auto in_s = make_hip_shape<4>(input_shape); + auto out_s = make_hip_shape<4>(out_comp_shape); + gs_launch(stream, result.get_shape().elements(), 256)([=](auto i) __device__ { + auto idx = out_s.multi(i); + auto d = idx[1]; + auto b = idx[2]; + auto l = sl_data[b]; + if(is_reverse or d == 1) + { + idx[0] = 0; + } + else + { + idx[0] = l - 1; + } + out_data[i] = in_data[in_s.index(idx)]; + }); + }); + }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/round.cpp b/src/targets/gpu/device/round.cpp index 295bcf46f45b4e04cc523a767481e1f37c69c400..e742d0e82c5702a7d452da19a968cb5e4277d08c 100644 --- a/src/targets/gpu/device/round.cpp +++ b/src/targets/gpu/device/round.cpp @@ -9,7 +9,7 @@ namespace device { void round(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::round(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::round(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/scatter.cpp b/src/targets/gpu/device/scatter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eecadce7dab9a3e76bc66766f09b0eee1fe7dacf --- /dev/null +++ b/src/targets/gpu/device/scatter.cpp @@ -0,0 +1,42 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument scatter( + hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis) +{ + auto ds = arg0.get_shape(); + auto inds = arg1.get_shape(); + auto axis_dim_size = ds.lens()[axis]; + hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) { + auto* output_ptr = device_cast(output.data()); + const auto* data_ptr = device_cast(data.data()); + gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; }); + hip_visit_all(arg1, arg2)([&](auto indices, auto update) { + const auto* upd_ptr = device_cast(update.data()); + const auto* indices_ptr = device_cast(indices.data()); + gs_launch(stream, inds.elements())([=](auto i) __device__ { + auto out_idx = s1.multi(i); + auto index = indices_ptr[i]; + index = index < 0 ? index + axis_dim_size : index; + out_idx[axis] = index; + output[out_idx] = upd_ptr[i]; + }); + }); + }); + + return result; +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/sigmoid.cpp b/src/targets/gpu/device/sigmoid.cpp index b3296a4045017ca8dbe97f8b8cf8d3bdac50dd75..12744442b6d49012b06cf0af0c332c0090b91379 100644 --- a/src/targets/gpu/device/sigmoid.cpp +++ b/src/targets/gpu/device/sigmoid.cpp @@ -9,7 +9,8 @@ namespace device { void sigmoid(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return 1.f / (1.f + ::exp(to_hip_type(-x))); }); + nary(stream, result, arg)([](auto x) + __device__ { return 1.f / (1.f + ::exp(to_hip_type(-x))); }); } } // namespace device diff --git a/src/targets/gpu/device/sign.cpp b/src/targets/gpu/device/sign.cpp index d336f3c30eab50530c4e12b6c90f54b2cd792106..cc272572049ab0b80e6138494f36955b9568545b 100644 --- a/src/targets/gpu/device/sign.cpp +++ b/src/targets/gpu/device/sign.cpp @@ -9,7 +9,7 @@ namespace device { void sign(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); }); + nary(stream, result, arg)([](auto x) __device__ { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); }); } } // namespace device diff --git a/src/targets/gpu/device/sin.cpp b/src/targets/gpu/device/sin.cpp index 17372ce142211ed3a3c1727215dbd64d2d4df37e..a9c4b5fa9ddbf1e8fc7e4e48ea5a89c49d342b2e 100644 --- a/src/targets/gpu/device/sin.cpp +++ b/src/targets/gpu/device/sin.cpp @@ -9,7 +9,7 @@ namespace device { void sin(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::sin(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::sin(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/sinh.cpp b/src/targets/gpu/device/sinh.cpp index 1e69e63502aa1d82e2b08e6814548c3c2d7002ae..b05bce5e18b6f675ffcf5eef7aa507e8bae11b0a 100644 --- a/src/targets/gpu/device/sinh.cpp +++ b/src/targets/gpu/device/sinh.cpp @@ -9,7 +9,7 @@ namespace device { void sinh(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::sinh(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::sinh(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/softmax.cpp b/src/targets/gpu/device/softmax.cpp index c0674e9410be17719266a94139f49702d3c3aef4..eab6f7b2ab927eb48f4050dc86ea95c33a3390d9 100644 --- a/src/targets/gpu/device/softmax.cpp +++ b/src/targets/gpu/device/softmax.cpp @@ -12,43 +12,66 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis) +void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) { - auto lens = result.get_shape().lens(); - auto batch_lens = lens; - index_int batch_item_num = lens[axis]; + auto batch_lens = result.get_shape().lens(); + index_int batch_item_num = batch_lens[axis]; batch_lens[axis] = 1; migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { - const index_int max_block_size = 256; + const index_int max_block_size = 128; const index_int block_size = compute_block_size(batch_item_num, max_block_size); - gs_launch(stream, - batch_shape.elements() * block_size, - block_size)([=](auto i, auto idx) __device__ { - auto data_idx = batch.multi(i / block_size); - using type = device_type>; - type init = lowest(); - - auto batch_max = block_reduce( - idx, max{}, init, batch_item_num, [&](auto j) __device__ { - data_idx[axis] = j; - return input[data_idx]; - }); + using type = device_type>; + type init = lowest(); + + if(axis == batch_lens.size() - 1) + { + gs_launch(stream, batch_shape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + auto start_loc = i / block_size * batch_item_num; + auto batch_max = block_reduce( + idx, max{}, init, batch_item_num, [&](auto j) __device__ { + return input[start_loc + j]; + }); + + auto batch_sum = block_reduce( + idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { + auto val = input[start_loc + j] - batch_max; + return ::exp(to_hip_type(val)); + }); - auto batch_sum = - block_reduce(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { - data_idx[axis] = j; - auto val = input[data_idx] - batch_max; - return ::exp(to_hip_type(val)); + idx.local_stride(batch_item_num, [&](auto j) __device__ { + auto val = input[start_loc + j] - batch_max; + output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum; + }); }); + } + else + { + gs_launch(stream, batch_shape.elements() * block_size, block_size)( + [=](auto i, auto idx) __device__ { + auto data_idx = batch.multi(i / block_size); + auto batch_max = block_reduce( + idx, max{}, init, batch_item_num, [&](auto j) __device__ { + data_idx[axis] = j; + return input[data_idx]; + }); - idx.local_stride(batch_item_num, [&](auto j) { - data_idx[axis] = j; - auto val = input[data_idx] - batch_max; - output[data_idx] = ::exp(to_hip_type(val)) / batch_sum; - }); - }); + auto batch_sum = block_reduce( + idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { + data_idx[axis] = j; + auto val = input[data_idx] - batch_max; + return ::exp(to_hip_type(val)); + }); + + idx.local_stride(batch_item_num, [&](auto j) __device__ { + data_idx[axis] = j; + auto val = input[data_idx] - batch_max; + output[data_idx] = ::exp(to_hip_type(val)) / batch_sum; + }); + }); + } }); } diff --git a/src/targets/gpu/device/sqdiff.cpp b/src/targets/gpu/device/sqdiff.cpp index 100dfa1af6badc5f64e50e43ccfe1bf19f53cff1..7a6a595a352caa8dc6510009b296ab221eaa68ed 100644 --- a/src/targets/gpu/device/sqdiff.cpp +++ b/src/targets/gpu/device/sqdiff.cpp @@ -8,7 +8,7 @@ namespace device { void sqdiff(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return (x - y) * (x - y); }); + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return (x - y) * (x - y); }); } } // namespace device diff --git a/src/targets/gpu/device/sqrt.cpp b/src/targets/gpu/device/sqrt.cpp index c6f436090a13df1037fee9986e45b15c1cfbca77..b888f8ca2a1d6bd6f3df6ddf76589170bebbc8c4 100644 --- a/src/targets/gpu/device/sqrt.cpp +++ b/src/targets/gpu/device/sqrt.cpp @@ -9,7 +9,7 @@ namespace device { void sqrt(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::sqrt(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::sqrt(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/sub.cpp b/src/targets/gpu/device/sub.cpp index d7fa635ddba9da2435c45708d3e832935d661b57..3f2f2d26796147b8c8cbbabe100c942219503c3c 100644 --- a/src/targets/gpu/device/sub.cpp +++ b/src/targets/gpu/device/sub.cpp @@ -8,7 +8,7 @@ namespace device { void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) { - nary(stream, result, arg1, arg2)([](auto x, auto y) { return x - y; }); + nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x - y; }); } } // namespace device diff --git a/src/targets/gpu/device/tan.cpp b/src/targets/gpu/device/tan.cpp index 6e9de40b11f6a171a32ae79eb10c3346ba114f07..bed9545567c8667c683f4f921056804ad5c5367b 100644 --- a/src/targets/gpu/device/tan.cpp +++ b/src/targets/gpu/device/tan.cpp @@ -9,7 +9,7 @@ namespace device { void tan(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::tan(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::tan(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/tanh.cpp b/src/targets/gpu/device/tanh.cpp index eb5be8adf322aff7043042db41fe7ff0a5e8264f..9550ae14a3e19892deaaed37854d4c95fce3a5bd 100644 --- a/src/targets/gpu/device/tanh.cpp +++ b/src/targets/gpu/device/tanh.cpp @@ -9,7 +9,7 @@ namespace device { void tanh(hipStream_t stream, const argument& result, const argument& arg) { - nary(stream, result, arg)([](auto x) { return ::tanh(to_hip_type(x)); }); + nary(stream, result, arg)([](auto x) __device__ { return ::tanh(to_hip_type(x)); }); } } // namespace device diff --git a/src/targets/gpu/device/topk.cpp b/src/targets/gpu/device/topk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b30c7886e8eefa381a96d2c290c3727a85d59320 --- /dev/null +++ b/src/targets/gpu/device/topk.cpp @@ -0,0 +1,216 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template +struct hip_heap_vector +{ + MIGRAPHX_DEVICE_CONSTEXPR hip_heap_vector(T* val, index_int n, Index v_idx, Compare comp) + : data(val), size(n), data_index(v_idx), compare(comp) + { + make_heap(size); + } + + MIGRAPHX_DEVICE_CONSTEXPR void try_push(const T val) + { + if(compare(val, data[data_index(0)])) + return; + + pop_heap(size - 1); + data[data_index(size - 1)] = val; + push_heap(size - 1); + } + + MIGRAPHX_DEVICE_CONSTEXPR void sort() { sort_heap(size); } + + private: + MIGRAPHX_DEVICE_CONSTEXPR inline static void swap(T& v1, T& v2) + { + T v = v1; + v1 = v2; + v2 = v; + } + + MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_down(index_int n, index_int index) + { + while(index < n) + { + auto pre_index = index; + index_int l = 2 * index + 1; + index_int r = 2 * index + 2; + + if(l < n && compare(data[data_index(l)], data[data_index(index)])) + { + index = l; + } + + if(r < n && compare(data[data_index(r)], data[data_index(index)])) + { + index = r; + if(compare(data[data_index(l)], data[data_index(r)])) + { + index = l; + } + } + + if(index == pre_index) + { + break; + } + + swap(data[data_index(index)], data[data_index(pre_index)]); + } + } + + MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_up(index_int index) + { + while(index > 0) + { + auto parent_idx = (index - 1) / 2; + + if(not compare(data[data_index(index)], data[data_index(parent_idx)])) + { + break; + } + + swap(data[data_index(index)], data[data_index(parent_idx)]); + index = parent_idx; + } + } + + MIGRAPHX_DEVICE_CONSTEXPR inline void make_heap(index_int n) + { + for(int j = n / 2 - 1; j >= 0; --j) + { + heapify_down(n, j); + } + } + + MIGRAPHX_DEVICE_CONSTEXPR inline void push_heap(index_int loc) { heapify_up(loc); } + + MIGRAPHX_DEVICE_CONSTEXPR inline void pop_heap(index_int loc) + { + swap(data[data_index(0)], data[data_index(loc)]); + heapify_down(loc, 0); + } + + MIGRAPHX_DEVICE_CONSTEXPR inline void sort_heap(index_int n) + { + for(int j = n - 1; j > 0; --j) + { + swap(data[data_index(0)], data[data_index(j)]); + heapify_down(j, 0); + } + } + + T* data = nullptr; + index_int size; + Index data_index; + Compare compare; +}; + +template +__device__ hip_heap_vector +make_heap(T* data, index_int n, Index idx, Compare compare) +{ + return {data, n, idx, compare}; +} + +template +std::vector topk(hipStream_t stream, + const argument& val_res, + const argument& ind_res, + const argument& arg, + int64_t k, + int64_t axis, + Compare compare) +{ + auto in_s = arg.get_shape(); + auto in_lens = in_s.lens(); + auto out_s = val_res.get_shape(); + auto axis_dim = in_s.lens()[axis]; + auto comp_lens = in_lens; + comp_lens[axis] = 1; + shape comp_s{in_s.type(), comp_lens}; + std::size_t elem_num = comp_s.elements(); + + hip_visit_all(val_res, arg, out_s, in_s, comp_s)( + [&](auto out_val, auto input, auto oss, auto iss, auto css) { + auto* data = device_cast(input.data()); + auto* out = device_cast(out_val.data()); + auto* const ind = ind_res.cast(); + gs_launch(stream, elem_num)([=](auto i) __device__ { + auto idx = css.multi(i); + + auto in_idx = [&](int ii) { + auto iidx = idx; + iidx[axis] = ii; + return iss.index(iidx); + }; + + auto out_idx = [&](int ii) { + auto iidx = idx; + iidx[axis] = ii; + return oss.index(iidx); + }; + + auto data_compare = [=](auto ii, auto jj) { + return compare(data[in_idx(ii)], data[in_idx(jj)]); + }; + + for(int j = 0; j < k; ++j) + { + ind[out_idx(j)] = j; + } + + auto hp = make_heap(ind, k, out_idx, data_compare); + for(int j = k; j < axis_dim; ++j) + { + hp.try_push(j); + } + hp.sort(); + + for(int j = 0; j < k; ++j) + { + out[out_idx(j)] = data[in_idx(ind[out_idx(j)])]; + } + }); + }); + + return {val_res, ind_res}; +} + +argument topk_largest(hipStream_t stream, + const argument& val_res, + const argument& ind_res, + const argument& arg, + int64_t k, + int64_t axis) +{ + return {topk(stream, val_res, ind_res, arg, k, axis, std::less<>{})}; +} + +argument topk_smallest(hipStream_t stream, + const argument& val_res, + const argument& ind_res, + const argument& arg, + int64_t k, + int64_t axis) +{ + return {topk(stream, val_res, ind_res, arg, k, axis, std::greater<>{})}; +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/unary_not.cpp b/src/targets/gpu/device/unary_not.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97b3e34b75d4d09d8db0fc9d8f761edfa8c0c899 --- /dev/null +++ b/src/targets/gpu/device/unary_not.cpp @@ -0,0 +1,18 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void unary_not(hipStream_t stream, const argument& result, const argument& arg) +{ + nary(stream, result, arg)([](auto x) __device__ { return not x; }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device/where.cpp b/src/targets/gpu/device/where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..befd2fbcd08362a6c312205e7fdf47430e7bfcaf --- /dev/null +++ b/src/targets/gpu/device/where.cpp @@ -0,0 +1,39 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +template +constexpr auto get_rank(const Shape&) +{ + return decltype(typename Shape::hip_index{}.size()){}; +} + +void where(hipStream_t stream, + const argument& result, + const argument& arg0, + const argument& arg1, + const argument& arg2) +{ + hip_visit_all(result, arg1, arg2)([&](auto output, auto x, auto y) { + hip_visit_all(arg0)([&](auto cond) { + if constexpr(get_rank(cond.get_shape()) == get_rank(output.get_shape())) + { + gs_launch(stream, arg1.get_shape().elements())([=](auto idx) __device__ { + auto i = output.get_shape().multi(idx); + output[i] = cond[i] ? x[i] : y[i]; + }); + } + }); + }); +} + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp new file mode 100755 index 0000000000000000000000000000000000000000..bdb651ed0c7c0c0b39cdad21c11a59b64b8d02c5 --- /dev/null +++ b/src/targets/gpu/device_name.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +template +std::string get_arch_name(rank<0>, const HipDeviceProp& props) +{ + return "gfx" + std::to_string(props.gcnArch); +} + +template +auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName)) +{ + return std::string(props.gcnArchName); +} + +int get_device_id() +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + MIGRAPHX_THROW("No device"); + return device; +} + +std::string get_device_name() +{ + hipDeviceProp_t props{}; + auto status = hipGetDeviceProperties(&props, get_device_id()); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to get device properties"); + return get_arch_name(rank<1>{}, props); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/driver/CMakeLists.txt b/src/targets/gpu/driver/CMakeLists.txt new file mode 100755 index 0000000000000000000000000000000000000000..4bf5cce544557014dfce191c299d1fb4b7fdfbeb --- /dev/null +++ b/src/targets/gpu/driver/CMakeLists.txt @@ -0,0 +1,7 @@ + +file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) +add_executable(gpu-driver + ${GPU_DRIVER_SRCS} +) +target_include_directories(gpu-driver PRIVATE include) +target_link_libraries(gpu-driver PRIVATE migraphx_gpu) diff --git a/src/targets/gpu/driver/action.cpp b/src/targets/gpu/driver/action.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ba5480fbeb0b6743213924b3f913298d1ee7c1a9 --- /dev/null +++ b/src/targets/gpu/driver/action.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +auto& action_map() +{ + static std::unordered_map m; + return m; +} + +action_function get_action(const std::string& name) +{ + if(action_map().count(name) == 0) + MIGRAPHX_THROW("Missing action: " + name); + return action_map().at(name); +} + +void register_action(const std::string& name, const action_function& a) { action_map()[name] = a; } + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/driver/compile_op.cpp b/src/targets/gpu/driver/compile_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db5354823f384e469a5dec8cd5b6c15895c799b7 --- /dev/null +++ b/src/targets/gpu/driver/compile_op.cpp @@ -0,0 +1,26 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +struct compile_op : action +{ + static void apply(const parser& p, const value& v) + { + context ctx; + auto inputs = p.parse_shapes(v.at("inputs")); + auto op = gpu::compile_op(v.at("name").to(), ctx, inputs, v); + double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); + std::cout << op << ": " << t << "ms" << std::endl; + } +}; + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp b/src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd6abc601a6a41d5b805ecb14f4ad0e9f336ca75 --- /dev/null +++ b/src/targets/gpu/driver/include/migraphx/gpu/driver/action.hpp @@ -0,0 +1,37 @@ +#ifndef MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP +#define MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +using action_function = std::function; + +action_function get_action(const std::string& name); +void register_action(const std::string& name, const action_function& a); + +struct auto_register_action +{ + template + static void apply() + { + auto name = get_type_name(); + register_action(name.substr(name.rfind("::") + 2), + [](auto&&... xs) { T::apply(std::forward(xs)...); }); + } +}; + +template +using action = auto_register; + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_DRIVER_ACTION_HPP diff --git a/src/targets/gpu/driver/include/migraphx/gpu/driver/parser.hpp b/src/targets/gpu/driver/include/migraphx/gpu/driver/parser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3d6a43d9f45bc184f7a5c9a527df0454df53b372 --- /dev/null +++ b/src/targets/gpu/driver/include/migraphx/gpu/driver/parser.hpp @@ -0,0 +1,45 @@ +#ifndef MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP +#define MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP + +#include +#include + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +[[noreturn]] void error(const std::string& msg); + +struct parser +{ + parser() = default; + + template + T get(const value& v, const std::string& key, const T& default_value) const + { + return v.get(key, settings.get(key, default_value)); + } + + shape parse_shape(const value& v) const; + + std::vector parse_shapes(const value& v) const; + + void load_settings(const value& v); + + static void process(const value& v); + + private: + value settings = value::object{}; +}; + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_DRIVER_PARSER_HPP diff --git a/src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp b/src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp new file mode 100755 index 0000000000000000000000000000000000000000..004af8bd448561a17a4b359c5ab23401b68409c4 --- /dev/null +++ b/src/targets/gpu/driver/include/migraphx/gpu/driver/perf.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP +#define MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +double time_op(context& ctx, operation op, const std::vector& inputs, int n = 100); + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_DRIVER_PERF_HPP diff --git a/src/targets/gpu/driver/main.cpp b/src/targets/gpu/driver/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0696355c2687914ca833401e92483e431f5ceaeb --- /dev/null +++ b/src/targets/gpu/driver/main.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include + +using namespace migraphx; // NOLINT +using namespace migraphx::gpu; // NOLINT +using namespace migraphx::gpu::driver; // NOLINT + +int main(int argc, char const* argv[]) +{ + std::vector args(argv, argv + argc); + if(args.size() < 2) + { + std::cout << "Usage: gpu-driver " << std::endl; + std::abort(); + } + auto v = from_json_string(convert_to_json(read_string(args[1]))); + parser::process(v); +} diff --git a/src/targets/gpu/driver/parser.cpp b/src/targets/gpu/driver/parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fec9c7b1138258c6f92bc9c96606501c1e20e977 --- /dev/null +++ b/src/targets/gpu/driver/parser.cpp @@ -0,0 +1,58 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +[[noreturn]] void error(const std::string& msg) +{ + std::cout << msg << std::endl; + std::abort(); +} + +shape parser::parse_shape(const value& v) const +{ + auto lens = get(v, "lens", std::vector{}); + auto strides = get(v, "strides", std::vector{}); + auto type = shape::parse_type(get(v, "type", "float")); + if(strides.empty()) + return shape{type, lens}; + else + return shape{type, lens, strides}; +} + +std::vector parser::parse_shapes(const value& v) const +{ + std::vector result; + std::transform( + v.begin(), v.end(), std::back_inserter(result), [&](auto&& x) { return parse_shape(x); }); + return result; +} + +void parser::load_settings(const value& v) +{ + if(v.contains("settings")) + settings = v.at("settings"); +} + +void parser::process(const value& v) +{ + if(not v.is_object()) + error("Input is not an object"); + parser p{}; + p.load_settings(v); + for(auto&& pp : v) + { + if(pp.get_key() == "settings") + continue; + get_action(pp.get_key())(p, pp.without_key()); + } +} + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/driver/perf.cpp b/src/targets/gpu/driver/perf.cpp new file mode 100755 index 0000000000000000000000000000000000000000..f4f498100b9d1a604b1f6c4085883baf00ca6cd4 --- /dev/null +++ b/src/targets/gpu/driver/perf.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +std::vector generate_arguments(const std::vector& shapes, unsigned long seed = 0) +{ + std::vector args; + std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](auto& s) { + return to_gpu(generate_argument(s, seed++)); + }); + return args; +} + +using milliseconds = std::chrono::duration; +double time_op(context& ctx, operation op, const std::vector& inputs, int n) +{ + // TODO: Use std::ref + migraphx::context gctx = ctx; + auto output = op.compute_shape(inputs); + op.finalize(gctx, output, inputs); + auto args = generate_arguments(inputs); + auto run = [&] { + op.compute(gctx, output, args); + gctx.finish(); + }; + run(); + auto r = range(n); + double t = std::accumulate( + r.begin(), r.end(), double{0.0}, [&](auto x, auto) { return x + time(run); }); + return t / n; +} + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/driver/run_op.cpp b/src/targets/gpu/driver/run_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d523d34054ccd9eba9d19f098410b987f569411 --- /dev/null +++ b/src/targets/gpu/driver/run_op.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace driver { + +struct run_op : action +{ + static void apply(const parser& p, const value& v) + { + context ctx; + auto inputs = p.parse_shapes(v.at("inputs")); + auto name = v.at("name").to(); + if(not contains(name, "::")) + name = "gpu::" + name; + auto op = make_op(name); + if(v.contains("fields")) + op.from_value(v.at("fields")); + double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); + std::cout << op << ": " << t << "ms" << std::endl; + } +}; + +} // namespace driver +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/eliminate_workspace.cpp b/src/targets/gpu/eliminate_workspace.cpp index 665361a259d72c4766fb95c8475ec1c5bed7b831..9f09aec186e8945c3cf3d1d435350e355f1530c0 100644 --- a/src/targets/gpu/eliminate_workspace.cpp +++ b/src/targets/gpu/eliminate_workspace.cpp @@ -11,11 +11,11 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { -void eliminate_workspace::apply(program& p) const +void eliminate_workspace::apply(module& m) const { std::size_t n = 0; std::vector allocs; - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(m)) { if(ins->outputs().size() != 1) continue; @@ -30,11 +30,11 @@ void eliminate_workspace::apply(program& p) const } if(n > 0) { - auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}}); + auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}}); for(auto&& a : allocs) { - p.replace_instruction(a, ws); - p.remove_instruction(a); + m.replace_instruction(a, ws); + m.remove_instruction(a); } } } diff --git a/src/targets/gpu/elu.cpp b/src/targets/gpu/elu.cpp index 804577fdcd1f9f63550d4e28462044f137bcb45a..b205bdd5a11efe4b798396f491bea051a2744a75 100644 --- a/src/targets/gpu/elu.cpp +++ b/src/targets/gpu/elu.cpp @@ -31,6 +31,11 @@ argument miopen_elu::compute(context& ctx, return args[1]; } +void miopen_elu::finalize(context&, const shape&, const std::vector&) +{ + ad = make_elu(op.alpha); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index c16d5dfa9e7c791238e0c9a82fcfdb996e62e9f4..f74a71ecc766b0d36274f4c13f2c9dc8cd484569 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -1,9 +1,17 @@ +#include +#include #include #include #include #include #include +#include #include +#include +#include +#include +#include +#include #include #include #include @@ -11,9 +19,15 @@ #include #include #include +#include +#include +#include #include +#include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -37,16 +51,22 @@ struct fusion return result; } + fusion() = default; + fusion(const shape& input) - // : fp(make_fusion_plan(input)) { + assert(input.standard()); auto t = make_tensor(input); fp = make_fusion_plan(t); + assert(fp); keep_alive(std::move(t)); } + bool empty() const { return fp == nullptr; } + op_t operator[](std::size_t i) const { + assert(fp); op_t result; auto status = miopenFusionPlanGetOp(fp.get(), i, &result); if(status != miopenStatusSuccess) @@ -54,10 +74,15 @@ struct fusion return result; } - auto get() const { return fp.get(); } + auto get() const + { + assert(fp); + return fp.get(); + } op_t create_bias(const shape& bias) { + assert(fp); op_t result; auto b = shape{bias.type(), {1, bias.lens().at(1), 1, 1}}; auto t = keep_alive(make_tensor(b)); @@ -69,6 +94,7 @@ struct fusion op_t create_relu() { + assert(fp); op_t result; auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU); if(status != miopenStatusSuccess) @@ -78,6 +104,7 @@ struct fusion op_t create_conv(const op::convolution& op, const shape& weights) { + assert(fp); op_t result; auto cd = keep_alive(make_conv(op)); auto t = keep_alive(make_tensor(weights)); @@ -89,6 +116,7 @@ struct fusion shape get_workspace(context&) { + // assert(fp); // TODO: Use zero workspace for now std::size_t ws_size = 0; // int algo_count = 1; @@ -99,11 +127,11 @@ struct fusion return shape{shape::int8_type, {ws_size}}; } - void compile(context& ctx) + bool compile(context& ctx) { - auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()); - if(status != miopenStatusSuccess) - MIGRAPHX_THROW("Compiling fusion plan failed"); + assert(fp); + return miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()) == + miopenStatusSuccess; } argument execute(context& ctx, @@ -111,6 +139,7 @@ struct fusion const argument& x, const argument& y) const { + assert(fp); auto x_td = make_tensor(x.get_shape()); auto y_td = make_tensor(y.get_shape()); auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(), @@ -126,6 +155,12 @@ struct fusion } }; +const std::unordered_set& get_supported_archs() +{ + static std::unordered_set supported_archs{"gfx900", "gfx906", "gfx908", "gfx1030"}; + return supported_archs; +} + MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins) { auto&& s = ins->get_shape(); @@ -135,6 +170,9 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins) MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) { + const auto device_name = trim(split_string(get_device_name(), ':').front()); + if(not contains(get_supported_archs(), device_name)) + return false; if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{})) return false; if(ins->name() != "gpu::convolution") @@ -148,155 +186,109 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) return false; if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd) return false; + + // Do not fuse non-symmetric input + auto input_lens = ins->inputs().at(0)->get_shape().lens(); + if(input_lens[2] != input_lens[3] or wei.lens()[2] != wei.lens()[3]) + return false; + auto op = conv.op; // Dont fuse winograd for non-3x3s since there is no fused windograd for those configs if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and - wei.lens()[3] != 3 and op.stride == make_array(1, 1)) + wei.lens()[3] != 3 and contains({{1, 1}}, op.stride)) return false; - return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and - contains({{0, 0}, {1, 1}}, op.stride) and op.dilation == make_array(1, 1); + return contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding) and + contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation); } -struct hip_triadd +struct hip_triadd : ternary_device { - std::string name() const { return "hip::triadd"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(4); - return inputs.front(); - } - argument compute(context& ctx, const shape&, const std::vector& args) const - { - device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); - return args.at(3); - } - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } }; +MIGRAPHX_REGISTER_OP(hip_triadd) -struct hip_triadd_clip +struct hip_triadd_clip : quinary_device { - op::clip op; - - template - static auto reflect(Self& self, F f) - { - return op::clip::reflect(self.op, f); - } - std::string name() const { return "hip::triadd_clip"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(4); - return inputs.front(); - } - argument compute(context& ctx, const shape&, const std::vector& args) const - { - device::add_clip(ctx.get_stream().get(), - args.at(3), - args.at(0), - args.at(1), - args.at(2), - op.max_val, - op.min_val); - return args.at(3); - } - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } }; +MIGRAPHX_REGISTER_OP(hip_triadd_clip) -struct hip_add_clip +struct hip_add_clip : quaternary_device { - op::clip op; - - template - static auto reflect(Self& self, F f) - { - return op::clip::reflect(self.op, f); - } - std::string name() const { return "hip::add_clip"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(3); - return inputs.front(); - } - argument compute(context& ctx, const shape&, const std::vector& args) const - { - device::add_clip( - ctx.get_stream().get(), args.at(2), args.at(0), args.at(1), op.max_val, op.min_val); - return args.at(2); - } - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } }; +MIGRAPHX_REGISTER_OP(hip_add_clip) struct hip_triadd_relu : ternary_device { }; +MIGRAPHX_REGISTER_OP(hip_triadd_relu) struct hip_triadd_sigmoid : ternary_device { }; +MIGRAPHX_REGISTER_OP(hip_triadd_sigmoid) struct hip_triadd_tanh : ternary_device { }; +MIGRAPHX_REGISTER_OP(hip_triadd_tanh) struct hip_add_relu : binary_device { }; +MIGRAPHX_REGISTER_OP(hip_add_relu) struct hip_add_sigmoid : binary_device { }; +MIGRAPHX_REGISTER_OP(hip_add_sigmoid) struct hip_add_tanh : binary_device { }; +MIGRAPHX_REGISTER_OP(hip_add_tanh) -struct hip_mul_add +struct hip_layernorm : unary_device { - std::string name() const { return "hip::mul_add"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(4); - return inputs.front(); - } - argument compute(context& ctx, const shape&, const std::vector& args) const - { - device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); - return args.at(3); - } - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } + // Empty finalize to skip dimension reduction + void finalize(context&, const shape&, const std::vector&) {} }; +MIGRAPHX_REGISTER_OP(hip_layernorm) -struct hip_mul_add_relu +struct hip_triadd_layernorm : ternary_device +{ + // Empty finalize to skip dimension reduction + void finalize(context&, const shape&, const std::vector&) {} +}; +MIGRAPHX_REGISTER_OP(hip_triadd_layernorm) + +struct hip_gelu : unary_device +{ +}; +MIGRAPHX_REGISTER_OP(hip_gelu) + +struct hip_add_gelu : binary_device +{ +}; +MIGRAPHX_REGISTER_OP(hip_add_gelu) + +struct hip_gelu_new : unary_device +{ +}; +MIGRAPHX_REGISTER_OP(hip_gelu_new) + +struct hip_add_gelu_new : binary_device { - std::string name() const { return "hip::mul_add_relu"; } - shape compute_shape(const std::vector& inputs) const - { - check_shapes{inputs, *this}.has(4); - return inputs.front(); - } - argument compute(context& ctx, const shape&, const std::vector& args) const - { - device::mul_add_relu( - ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); - return args.at(3); - } - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } }; +MIGRAPHX_REGISTER_OP(hip_add_gelu_new) + +struct hip_mul_add : ternary_device +{ +}; +MIGRAPHX_REGISTER_OP(hip_mul_add) + +struct hip_mul_add_relu : ternary_device +{ +}; +MIGRAPHX_REGISTER_OP(hip_mul_add_relu) void move_broadcasted_back(std::vector& args) { @@ -318,32 +310,147 @@ void move_standard_front(std::vector& args) std::swap(*it, args.front()); } -struct find_add_clip +auto gpu_name(const std::string& s) { return match::name("gpu::" + s); } + +struct find_layernorm +{ + auto matcher() const { return match::layernorm(&gpu_name); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto args = ins->inputs(); + + // We dont fuse for non-standard layouts + if(not x_ins->get_shape().standard()) + return; + + auto relements = x_ins->get_shape().lens().back(); + + if(relements > 1024 or (relements % 4 != 0 and relements > 256)) + return; + + m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back()); + } +}; + +struct find_triadd_layernorm { auto matcher() const { - return match::name(std::unordered_set{"gpu::clip", "gpu::clipped_relu"})( - match::arg(0)(match::any_of(match::name("gpu::add"), - match::name("hip::triadd"), - match::any_of[match::inputs()](match::standard_shape())) - .bind("add"))); + return match::name("gpu::layernorm")(match::arg(0)(match::name("gpu::triadd")( + match::used_once(), match::all_of[match::inputs()](match::standard_shape())))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto triadd = ins->inputs().front(); + m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs()); + } +}; + +struct find_gelu +{ + auto matcher() const { return match::gelu_erf(&gpu_name); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto args = ins->inputs(); + + m.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); + } +}; + +struct find_add_gelu +{ + auto matcher() const + { + return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add"))); + } + + void apply(module& m, const match::matcher_result& r) const { auto add_ins = r.instructions["add"]; auto ins = r.result; - auto&& op = any_cast(ins->get_operator()).op; auto args = add_ins->inputs(); move_standard_front(args); move_broadcasted_back(args); - // Use the allocation from the relu operator args.back() = ins->inputs().back(); + m.replace_instruction(ins, hip_add_gelu{}, args); + } +}; + +struct find_gelu_new +{ + bool fast_math = true; + + auto matcher() const { return match::gelu_tanh(&gpu_name); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + auto args = ins->inputs(); + + if(fast_math) + m.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); + else + m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back()); + } +}; + +struct find_add_gelu_new +{ + auto matcher() const + { + return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto add_ins = r.instructions["add"]; + auto ins = r.result; + auto args = add_ins->inputs(); + move_standard_front(args); + move_broadcasted_back(args); + + args.back() = ins->inputs().back(); + m.replace_instruction(ins, hip_add_gelu_new{}, args); + } +}; + +struct find_add_clip +{ + auto matcher() const + { + return match::name(std::unordered_set{"gpu::clip", "gpu::clipped_relu"})( + match::arg(0)(match::any_of(match::name("gpu::add"), + match::name("gpu::triadd"), + match::any_of[match::inputs()](match::standard_shape())) + .bind("add"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto add_ins = r.instructions["add"]; + auto ins = r.result; + auto ins_args = ins->inputs(); + auto add_args = add_ins->inputs(); + move_standard_front(add_args); + move_broadcasted_back(add_args); + + // Use the allocation from the clip operator + add_args.pop_back(); + add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end()); if(add_ins->name() == "gpu::add") - p.replace_instruction(ins, hip_add_clip{op}, args); - else if(add_ins->name() == "hip::triadd") - p.replace_instruction(ins, hip_triadd_clip{op}, args); + m.replace_instruction(ins, hip_add_clip{}, add_args); + else if(add_ins->name() == "gpu::triadd") + m.replace_instruction(ins, hip_triadd_clip{}, add_args); } }; @@ -357,13 +464,13 @@ struct find_add_unary return match::name(op_name)(match::arg(0)( match::used_once(), match::any_of(match::name("gpu::add"), - match::name("hip::triadd"), + match::name("gpu::triadd"), match::any_of(match::name("@literal"), match::any_of[match::inputs()](match::standard_shape()))) .bind("add"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto add_ins = r.instructions["add"]; auto ins = r.result; @@ -374,9 +481,9 @@ struct find_add_unary // Use the allocation from the relu operator args.back() = ins->inputs().back(); if(add_ins->name() == "gpu::add") - p.replace_instruction(ins, binary_add_op, args); - else if(add_ins->name() == "hip::triadd") - p.replace_instruction(ins, ternary_add_op, args); + m.replace_instruction(ins, binary_add_op, args); + else if(add_ins->name() == "gpu::triadd") + m.replace_instruction(ins, ternary_add_op, args); } }; @@ -391,23 +498,22 @@ struct find_triadd .bind("input"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto add_ins = r.instructions["add"]; auto input_ins = r.instructions["input"]; auto ins = r.result; auto args = add_ins->inputs(); - assert(add_ins != input_ins); auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); }; - if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1) + if(std::count_if(args.begin(), args.end(), is_broadcasted) > 2) return; args.insert(args.begin(), input_ins); move_standard_front(args); move_broadcasted_back(args); args.back() = ins->inputs().back(); - p.replace_instruction(ins, hip_triadd{}, args); + m.replace_instruction(ins, hip_triadd{}, args); } }; @@ -419,7 +525,7 @@ struct find_mul_add match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto mul_ins = r.instructions["mul"]; auto b_ins = r.instructions["b"]; @@ -432,7 +538,7 @@ struct find_mul_add args.insert(std::prev(args.end()), b_ins); args.back() = ins->inputs().back(); - p.replace_instruction(ins, hip_mul_add{}, args); + m.replace_instruction(ins, hip_mul_add{}, args); } }; @@ -441,10 +547,10 @@ struct find_mul_add_relu auto matcher() const { return match::name("gpu::relu")( - match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add"))); + match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add"))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { auto mul_add_ins = r.instructions["mul_add"]; auto ins = r.result; @@ -452,16 +558,132 @@ struct find_mul_add_relu // Use the allocation from the relu operator args.back() = ins->inputs().back(); - p.replace_instruction(ins, hip_mul_add_relu{}, args); + m.replace_instruction(ins, hip_mul_add_relu{}, args); + } +}; + +struct miopen_fusion +{ + struct fuse_op_data + { + operation op; + float alpha = 1; + float beta = 0; + }; + struct fuse_op : fuse_op_data, reflect_equality, reflect_stream + { + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op"), f(self.alpha, "alpha"), f(self.beta, "beta")); + } + }; + std::vector ops = {}; + fusion f = {}; + std::function&)> execute; + template + static auto reflect(Self& self, F f) + { + return pack(f(self.ops, "ops")); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + + value compile(context& ctx, const shape&, std::vector inputs) + { + // Compensate for allocation + inputs.pop_back(); + std::size_t i = 0; + f = fusion(inputs[i]); + i++; + std::vector&)>> + invokers; + for(auto&& fop : ops) + { + if(i > inputs.size()) + { + f = {}; + return {}; + } + if(fop.op.name() == "convolution") + { + auto* mop = f.create_conv(any_cast(fop.op), inputs[i]); + invokers.push_back( + [=](const fused_operator_args& fargs, const std::vector& args) { + miopenSetOpArgsConvForward( + fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit()); + }); + i++; + } + else if(fop.op.name() == "add") + { + auto* mop = f.create_bias(inputs[i]); + invokers.push_back( + [=](const fused_operator_args& fargs, const std::vector& args) { + miopenSetOpArgsBiasForward( + fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit()); + }); + i++; + } + else if(fop.op.name() == "relu") + { + auto* mop = f.create_relu(); + invokers.push_back([=](const fused_operator_args& fargs, + const std::vector&) { + miopenSetOpArgsActivForward(fargs.get(), mop, &fop.alpha, &fop.beta, 0, 0, 0); + }); + } + else + { + f = {}; + return {}; + } + } + if(not f.compile(ctx)) + { + f = {}; + return {}; + } + execute = [invokers](context& c, const fusion& ff, const std::vector& args) { + auto fargs = make_fused_args(); + for(auto&& invoker : invokers) + invoker(fargs, args); + ff.execute(c, fargs, args.front(), args.back()); + }; + return {{"workspace", f.get_workspace(ctx).bytes()}}; + } + void finalize(context& ctx, const shape& output_shape, const std::vector& inputs) + { + if(not f.empty()) + return; + auto v = compile(ctx, output_shape, inputs); + if(not v.is_object()) + MIGRAPHX_THROW("Failed to compile fusion plan"); + } + std::string name() const { return "gpu::miopen_fusion"; } + shape compute_shape(const std::vector& inputs) const + { + if(ops.empty()) + return {}; + // TODO: Check number of arguments + return ops.front().op.compute_shape({inputs[0], inputs[1]}); + } + argument compute(context& ctx, const shape&, const std::vector& args) const + { + execute(ctx, f, args); + return args.back(); } }; struct miopen_conv_bias { op::convolution op; - fusion f; - fusion::op_t conv; - fusion::op_t bias; + fusion fp = {}; + fusion::op_t conv = {}; + fusion::op_t bias = {}; template static auto reflect(Self& self, F f) @@ -469,19 +691,12 @@ struct miopen_conv_bias return op::convolution::reflect(self.op, f); } - miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b) - : op(c), f(input) - { - conv = f.create_conv(op, weights); - bias = f.create_bias(b); - } - std::string name() const { return "gpu::conv_bias"; } shape compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(5); // TODO: Check slices - return op.compute_shape({inputs.at(0), inputs.at(1)}); + return op.normalize_compute_shape({inputs.at(0), inputs.at(1)}); } argument compute(context& ctx, const shape&, const std::vector& args) const { @@ -490,24 +705,33 @@ struct miopen_conv_bias float beta = 0; miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit()); miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit()); - return f.execute(ctx, fargs, args[0], args[4]); + return fp.execute(ctx, fargs, args[0], args[4]); + } + + void finalize(context& ctx, const shape&, const std::vector& inputs) + { + fp = fusion(inputs[0]); + conv = fp.create_conv(op, inputs[1]); + bias = fp.create_bias(inputs[3]); + if(not fp.compile(ctx)) + MIGRAPHX_THROW("Failed to compile fusion plan"); } - void finalize(context& ctx, const shape&, const std::vector&) { f.compile(ctx); } - shape get_workspace(context& ctx) { return f.get_workspace(ctx); } + shape get_workspace(context& ctx) { return fp.get_workspace(ctx); } std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; } }; +MIGRAPHX_REGISTER_OP(miopen_conv_bias) struct miopen_conv_bias_relu { op::convolution op; - fusion f; - fusion::op_t conv; - fusion::op_t bias; - fusion::op_t relu; + fusion fp = {}; + fusion::op_t conv = {}; + fusion::op_t bias = {}; + fusion::op_t relu = {}; template static auto reflect(Self& self, F f) @@ -515,23 +739,12 @@ struct miopen_conv_bias_relu return op::convolution::reflect(self.op, f); } - miopen_conv_bias_relu(op::convolution c, - const shape& input, - const shape& weights, - const shape& b) - : op(c), f(input) - { - conv = f.create_conv(op, weights); - bias = f.create_bias(b); - relu = f.create_relu(); - } - std::string name() const { return "gpu::conv_bias_relu"; } shape compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(5); // TODO: Check slices - return op.compute_shape({inputs.at(0), inputs.at(1)}); + return op.normalize_compute_shape({inputs.at(0), inputs.at(1)}); } argument compute(context& ctx, const shape&, const std::vector& args) const { @@ -541,15 +754,24 @@ struct miopen_conv_bias_relu miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit()); miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit()); miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0); - return f.execute(ctx, fargs, args[0], args[4]); + return fp.execute(ctx, fargs, args[0], args[4]); } - void finalize(context& ctx, const shape&, const std::vector&) { f.compile(ctx); } - shape get_workspace(context& ctx) { return f.get_workspace(ctx); } + void finalize(context& ctx, const shape&, const std::vector& inputs) + { + fp = fusion(inputs[0]); + conv = fp.create_conv(op, inputs[1]); + bias = fp.create_bias(inputs[3]); + relu = fp.create_relu(); + fp.compile(ctx); + } + + shape get_workspace(context& ctx) { return fp.get_workspace(ctx); } std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; } }; +MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu) template auto conv_bias(Ms... ms) @@ -561,7 +783,7 @@ auto conv_bias(Ms... ms) } template -void apply_conv_bias(context& ctx, program& p, match::matcher_result r) +void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r) { auto conv_ins = r.instructions["conv"]; auto bias_ins = r.instructions["bias"]; @@ -572,11 +794,30 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r) auto alloc_ins = ins->inputs().back(); auto old_ws_ins = conv_ins->inputs().at(2); - Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()}; + Op cb{conv_op}; // TODO: Insert ws allocation auto ws = cb.get_workspace(ctx); (void)ws; - p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); + m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); +} + +inline auto precompile_name(std::string s) // NOLINT +{ + return match::make_basic_pred_matcher([=](instruction_ref ins) { + if(ins->name() != "gpu::precompile_op") + return false; + auto op = from_value(ins->get_operator().to_value().at("op")); + return (op.name() == s); + }); +} + +template +auto conv_bias_pointwise(Ms... ms) +{ + return precompile_name("pointwise")( + match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"), + fusable_conv(match::used_once()).bind("conv")), + ms...); } struct find_conv_bias @@ -588,9 +829,9 @@ struct find_conv_bias match::output(match::name(std::unordered_set{"gpu::relu"})))); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const { - apply_conv_bias(*ctx, p, std::move(r)); + apply_conv_bias(*ctx, m, r); } }; @@ -599,27 +840,179 @@ struct find_conv_bias_relu context* ctx = nullptr; auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); } - void apply(program& p, match::matcher_result r) const + void apply(module& m, const match::matcher_result& r) const + { + apply_conv_bias(*ctx, m, r); + } +}; + +struct find_conv_pointwise +{ + context* ctx = nullptr; + auto matcher() const + { + return precompile_name("pointwise")( + match::nargs(3), + match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"), + fusable_conv(match::used_once()).bind("conv"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto conv_ins = r.instructions["conv"]; + auto bias_ins = r.instructions["bias"]; + auto ins = r.result; + auto input_ins = conv_ins->inputs().at(0); + auto weights_ins = conv_ins->inputs().at(1); + auto conv_op = any_cast(conv_ins->get_operator()).op; + auto alloc_ins = ins->inputs().back(); + + module_ref pm = ins->module_inputs().front(); + + miopen_fusion op{}; + op.ops.push_back({{conv_op}}); + for(auto&& i : *pm) + { + if(i.name()[0] == '@') + continue; + op.ops.push_back({{i.get_operator()}}); + } + std::vector inputs = {input_ins, weights_ins, bias_ins, alloc_ins}; + auto v = op.compile(*ctx, ins->get_shape(), to_shapes(inputs)); + if(not v.is_object()) + return; + m.replace_instruction(ins, op, inputs); + } +}; + +struct find_gemm_add +{ + auto matcher() const + { + return match::name("gpu::add")( + match::all_of[match::inputs()](match::standard_shape()), + match::either_arg(0, 1)(match::used_once().bind("c"), + match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto gemm_ins = r.instructions["gemm"]; + auto c_ins = r.instructions["c"]; + + auto gemm = any_cast>(gemm_ins->get_operator()); + + // Already fused gemm + if(not float_equal(gemm.beta, 0)) + return; + + auto inputs = gemm_ins->inputs(); + inputs.pop_back(); + + auto copy_ins = c_ins; + + // Insert copy + if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty()) + { + copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back()); + } + inputs.push_back(copy_ins); + inputs.push_back(copy_ins); + + gemm.beta = 1; + m.replace_instruction(ins, gemm, inputs); + } +}; + +auto pointwise_name(const std::string& s) +{ + return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) { + module_ref pm = ins->module_inputs().front(); + auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; }); + if(n != 1) + return false; + return std::all_of(pm->begin(), pm->end(), [&](auto& i) { + return starts_with(i.name(), "@") or i.name() == s; + }); + })); +} + +struct find_gemm_pointwise +{ + auto matcher() const { - apply_conv_bias(*ctx, p, std::move(r)); + return pointwise_name("add")( + match::nargs(3), + match::all_of[match::inputs()](match::standard_shape()), + match::either_arg(0, 1)(match::used_once().bind("c"), + match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto gemm_ins = r.instructions["gemm"]; + auto c_ins = r.instructions["c"]; + + auto gemm = any_cast>(gemm_ins->get_operator()); + + // Already fused gemm + if(not float_equal(gemm.beta, 0)) + return; + + auto inputs = gemm_ins->inputs(); + inputs.pop_back(); + + inputs.push_back(c_ins); + inputs.push_back(ins->inputs().back()); + + gemm.beta = 1; + m.replace_instruction(ins, gemm, inputs); + } +}; + +struct find_commutative_broadcast +{ + auto matcher() const + { + return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape())); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto args = ins->inputs(); + move_broadcasted_back(args); + + m.replace_instruction(ins, ins->get_operator(), args); } }; -void fuse_ops::apply(program& p) const +void fuse_ops::apply(module& m) const { - // clang-format off - match::find_matches(p, find_triadd{}); - match::find_matches(p, - find_conv_bias_relu{ctx}, - find_conv_bias{ctx}, - find_mul_add{}, - find_mul_add_relu{}, - find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}}, - find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}}, - find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, - find_add_clip{} - ); - // clang-format on + match::find_matches(m, find_gelu{}, find_gelu_new{fast_math}); + run_passes(m, {dead_code_elimination{}}); + match::find_matches(m, find_triadd{}); + match::find_matches(m, + find_layernorm{}, + find_conv_pointwise{ctx}, + find_conv_bias_relu{ctx}, + find_conv_bias{ctx}, + find_add_gelu{}, + find_add_gelu_new{}, + find_mul_add{}, + find_mul_add_relu{}, + find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}}, + find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}}, + find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, + find_add_clip{}); + run_passes(m, {dead_code_elimination{}}); + match::find_matches(m, + find_triadd_layernorm{}, + find_gemm_add{}, + find_gemm_pointwise{}, + find_commutative_broadcast{}); } } // namespace gpu diff --git a/src/targets/gpu/gather.cpp b/src/targets/gpu/gather.cpp index aa6cd22121b1857206e58f740e4f053bda23981a..59d6ec8642ace00f0d0e241fc45966e3f1db22a7 100644 --- a/src/targets/gpu/gather.cpp +++ b/src/targets/gpu/gather.cpp @@ -9,7 +9,7 @@ namespace gpu { shape hip_gather::compute_shape(std::vector inputs) const { inputs.pop_back(); - return op.compute_shape(inputs); + return op.normalize_compute_shape(inputs); } argument hip_gather::compute(context& ctx, const shape&, const std::vector& args) const diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 9ee24487cb79438aec2a635ba0ff92f74ab90a6d..c2d7dabe2903901f1718af89ec9b356320380772 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -1,5 +1,6 @@ -#include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -16,6 +17,8 @@ rocblas_datatype get_type(shape::type_t type) case shape::uint8_type: return rocblas_datatype_u8_r; case shape::int32_type: return rocblas_datatype_i32_r; case shape::uint32_type: return rocblas_datatype_u32_r; + case shape::tuple_type: + case shape::bool_type: case shape::uint16_type: case shape::int16_type: case shape::int64_type: @@ -25,12 +28,54 @@ rocblas_datatype get_type(shape::type_t type) MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); } +void blas_shape(const shape& s) +{ + if(s.lens().size() < 2) + return; + if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; })) + MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1"); + if(s.lens().size() < 3) + return; + shape batch_shape{s.type(), + {s.lens().begin(), s.lens().end() - 2}, + {s.strides().begin(), s.strides().end() - 2}}; + auto batch_shapes = reduce_dims({batch_shape}); + if(batch_shapes.front().lens().size() != 1) + MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); +} + +template +R rocblas_invoke(R (*f)(Ts...), Us... xs) +{ + if constexpr(sizeof...(Ts) == sizeof...(Us)) + return f(xs...); + else + return f(xs..., nullptr, nullptr); +} + +static bool is_transposed(const shape& s) +{ + if(not s.transposed()) + return false; + return s.strides().back() != 1; +} + +static rocblas_int get_batch_stride(const argument& a) +{ + return a.get_shape().strides()[a.get_shape().strides().size() - 3]; +} + template -void gemm_impl( - context& ctx, const shape& output_shape, const std::vector& args, T alpha, T beta) +void gemm_impl(context& ctx, + const shape& output_shape, + const std::vector& args, + T alpha, + T beta, + bool int8_x4_format, + bool compute_fp32) { - bool transa = args[0].get_shape().transposed(); - bool transb = args[1].get_shape().transposed(); + bool transa = is_transposed(args[0].get_shape()); + bool transb = is_transposed(args[1].get_shape()); auto n_dim = output_shape.lens().size(); auto dim_1 = n_dim - 1; auto dim_0 = n_dim - 2; @@ -50,18 +95,42 @@ void gemm_impl( output_type = rocblas_datatype_i32_r; } auto compute_type = output_type; + if(compute_fp32) + { + if(arg_type == rocblas_datatype_f16_r) + compute_type = rocblas_datatype_f32_r; + } + +#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 + rocblas_gemm_flags flag = + int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; +#else + (void)int8_x4_format; + int flag = 0; +#endif auto a_lens = args[0].get_shape().lens(); auto b_lens = args[1].get_shape().lens(); output_shape.visit_type([&](auto as) { - auto alpha_r = as(alpha); - auto beta_r = as(beta); + auto alpha_r = as(alpha); + auto beta_r = as(beta); + + // use void pointer to select different data type if using fp32 mode + void* alpha_v = &alpha_r; + void* beta_v = &beta_r; + + if(compute_fp32) + { + alpha_v = α + beta_v = β + } + auto out_lens = output_shape.lens(); rocblas_int m = out_lens[dim_0]; rocblas_int n = out_lens[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1]; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); }; - if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0) + if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format) { MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!"); } @@ -74,67 +143,67 @@ void gemm_impl( // column-major format. When doing a C = A * B, we actually do // C^T = (B^T) * (A^T). That is the reason we input args[1] as // A and args[0] as B in calling the rocblas_gemm. - rocblas_gemm_ex(ctx.get_stream().get_rocblas(), - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - n, - m, - k, - &alpha_r, - to_pointer(args.at(1)), - arg_type, - ldb, - to_pointer(args.at(0)), - arg_type, - lda, - &beta_r, - to_pointer(args[2]), - output_type, - ldc, - is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), - output_type, - ldc, - compute_type, - rocblas_gemm_algo_standard, - 0, - 0, - nullptr, - nullptr); + rocblas_invoke(&rocblas_gemm_ex, + ctx.get_stream().get_rocblas(), + transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + n, + m, + k, + alpha_v, + to_pointer(args.at(1)), + arg_type, + ldb, + to_pointer(args.at(0)), + arg_type, + lda, + beta_v, + to_pointer(args[2]), + output_type, + ldc, + is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), + output_type, + ldc, + compute_type, + rocblas_gemm_algo_standard, + 0, + flag); } else { - rocblas_gemm_strided_batched_ex( - ctx.get_stream().get_rocblas(), - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - n, - m, - k, - &alpha_r, - to_pointer(args.at(1)), - arg_type, - ldb, - k * n, - to_pointer(args.at(0)), - arg_type, - lda, - m * k, - &beta_r, - to_pointer(args[2]), - output_type, - ldc, - m * n, - is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), - output_type, - ldc, - m * n, - num_matrices, - compute_type, - rocblas_gemm_algo_standard, - 0, - 0, - nullptr, - nullptr); + auto a_stride = get_batch_stride(args[0]); + auto b_stride = get_batch_stride(args[1]); + auto c_stride = get_batch_stride(args[2]); + rocblas_invoke(&rocblas_gemm_strided_batched_ex, + ctx.get_stream().get_rocblas(), + transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + n, + m, + k, + alpha_v, + to_pointer(args.at(1)), + arg_type, + ldb, + b_stride, + to_pointer(args.at(0)), + arg_type, + lda, + a_stride, + beta_v, + to_pointer(args[2]), + output_type, + ldc, + c_stride, + is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), + output_type, + ldc, + c_stride, + num_matrices, + compute_type, + rocblas_gemm_algo_standard, + 0, + flag); } }); } @@ -143,18 +212,22 @@ void gemm(context& ctx, const shape& output_shape, const std::vector& args, float alpha, - float beta) + float beta, + bool int8_x4_format, + bool compute_fp32) { - gemm_impl(ctx, output_shape, args, alpha, beta); + gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32); } void gemm(context& ctx, const shape& output_shape, const std::vector& args, int32_t alpha, - int32_t beta) + int32_t beta, + bool int8_x4_format, + bool compute_fp32) { - gemm_impl(ctx, output_shape, args, alpha, beta); + gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32); } } // namespace gpu diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index 8e667c84cabe4e44d92050a60f345ce20bb00534..6cb830712b9b078f680c4662ce8f0fa49ce82399 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -12,11 +13,29 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +MIGRAPHX_REGISTER_OP(hip_allocate) +MIGRAPHX_REGISTER_OP(hip_sync_device) +MIGRAPHX_REGISTER_OP(hip_sync_stream) +MIGRAPHX_REGISTER_OP(hip_copy_to_gpu) +MIGRAPHX_REGISTER_OP(hip_copy_from_gpu) +MIGRAPHX_REGISTER_OP(hip_copy) +MIGRAPHX_REGISTER_OP(hip_allocate_memory) +MIGRAPHX_REGISTER_OP(hip_copy_literal) + using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree); using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister); std::string hip_error(int error) { return hipGetErrorString(static_cast(error)); } +bool is_device_ptr(const void* ptr) +{ + hipPointerAttribute_t attr; + auto status = hipPointerGetAttributes(&attr, ptr); + if(status != hipSuccess) + return false; + return attr.memoryType == hipMemoryTypeDevice; +} + std::size_t get_available_gpu_memory() { size_t free; @@ -40,15 +59,16 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false) { if(sz > get_available_gpu_memory()) MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz)); - void* result; - auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz); + void* result = nullptr; + auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz); if(status != hipSuccess) { if(host) MIGRAPHX_THROW("Gpu allocation failed: " + hip_error(status)); else - allocate_gpu(sz, true); + return allocate_gpu(sz, true); } + assert(result != nullptr); return hip_ptr{result}; } @@ -65,6 +85,8 @@ std::vector read_from_gpu(const void* x, std::size_t sz) { gpu_sync(); std::vector result(sz); + assert(not is_device_ptr(result.data())); + assert(is_device_ptr(x)); auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost); if(status != hipSuccess) MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT @@ -75,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false) { gpu_sync(); auto result = allocate_gpu(sz, host); + assert(is_device_ptr(result.get())); + assert(not is_device_ptr(x)); auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); if(status != hipSuccess) MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status)); @@ -99,10 +123,9 @@ argument register_on_gpu(const argument& arg) { auto arg_shared = arg.share(); auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes())); - return {arg_shared.get_shape(), - [ p, a = std::move(arg_shared) ]() mutable {return get_device_ptr(p.get()); -} -}; // namespace gpu + return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable { + return get_device_ptr(p.get()); + }}; // namespace gpu } // namespace MIGRAPHX_INLINE_NS argument to_gpu(const argument& arg, bool host) @@ -130,7 +153,14 @@ void set_device(std::size_t id) MIGRAPHX_THROW("Error setting device"); } -void gpu_sync() { hipDeviceSynchronize(); } +void gpu_sync() +{ + auto status = hipDeviceSynchronize(); + if(status != hipSuccess) + MIGRAPHX_THROW("hip device synchronization failed: " + hip_error(status)); +} + +void gpu_sync(const context& ctx) { ctx.finish(); } void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind) { @@ -152,12 +182,26 @@ void gpu_copy(context& ctx, const argument& src, const argument& dst) void copy_to_gpu(context& ctx, const argument& src, const argument& dst) { - gpu_copy(ctx, register_on_gpu(src), dst); + if(src.get_shape() == dst.get_shape() and dst.get_shape().packed()) + { + hip_async_copy(ctx, src, dst, hipMemcpyHostToDevice); + } + else + { + gpu_copy(ctx, register_on_gpu(src), dst); + } } void copy_from_gpu(context& ctx, const argument& src, const argument& dst) { - gpu_copy(ctx, src, register_on_gpu(dst)); + if(src.get_shape() == dst.get_shape() and dst.get_shape().packed()) + { + hip_async_copy(ctx, src, dst, hipMemcpyDeviceToHost); + } + else + { + gpu_copy(ctx, src, register_on_gpu(dst)); + } } argument get_preallocation(context& ctx, const std::string& id) @@ -165,6 +209,11 @@ argument get_preallocation(context& ctx, const std::string& id) return ctx.get_current_device().preallocations.at(id); } +void store_preallocated_param(context& ctx, const std::string& id, const argument& a) +{ + ctx.get_current_device().preallocations[id] = a; +} + // clang-format off } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/include/migraphx/gpu/abs.hpp b/src/targets/gpu/include/migraphx/gpu/abs.hpp index ffa3252689c5981c759fd1b574d05542210e3341..230e6e4e54d2cb0159a02b89be8e938f817ab68a 100644 --- a/src/targets/gpu/include/migraphx/gpu/abs.hpp +++ b/src/targets/gpu/include/migraphx/gpu/abs.hpp @@ -2,7 +2,9 @@ #define MIGRAPHX_GUARD_RTGLIB_ABS_HPP #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -12,18 +14,20 @@ struct context; struct miopen_abs { + op::abs op; shared ad; template static auto reflect(Self& self, F f) { - return gpu::reflect(self.ad.get(), f); + return migraphx::reflect(self.op, f); } std::string name() const { return "gpu::abs"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; + void finalize(context&, const shape&, const std::vector&); std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; diff --git a/src/targets/gpu/include/migraphx/gpu/acosh.hpp b/src/targets/gpu/include/migraphx/gpu/acosh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a17813bfafb2ef30ef74d9350df136b470d3c716 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/acosh.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP +#define MIGRAPHX_GUARD_RTGLIB_ACOSH_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_acosh : unary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/allocation_model.hpp b/src/targets/gpu/include/migraphx/gpu/allocation_model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9e1e3c55dce26db4d829be7e81eaf485c10ba41 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/allocation_model.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct gpu_allocation_model +{ + std::string name() const; + std::string copy() const; + operation allocate(const shape& s) const; + operation preallocate(const shape& s, const std::string& id) const; + bool needs_out_params() const { return true; } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/analyze_streams.hpp b/src/targets/gpu/include/migraphx/gpu/analyze_streams.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd7fa840ac584387d883bda6295d2866f889ccd5 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/analyze_streams.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP +#define MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace gpu { + +std::vector analyze_streams(const module& m); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/argmax.hpp b/src/targets/gpu/include/migraphx/gpu/argmax.hpp index 7692f33578cafe6eee2cab6c6f485ae1e56e6ddb..92814fa8c0f169f2bb5340af5a5bd03a26a27323 100644 --- a/src/targets/gpu/include/migraphx/gpu/argmax.hpp +++ b/src/targets/gpu/include/migraphx/gpu/argmax.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP -#include +#include +#include #include #include diff --git a/src/targets/gpu/include/migraphx/gpu/argmin.hpp b/src/targets/gpu/include/migraphx/gpu/argmin.hpp index 90dff5f6b20e60b883d8930ecb5b80f83d770d03..f43c2bd71369b76966937fa7ed447b8c64e449c9 100644 --- a/src/targets/gpu/include/migraphx/gpu/argmin.hpp +++ b/src/targets/gpu/include/migraphx/gpu/argmin.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP #define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP -#include +#include +#include #include #include diff --git a/src/targets/gpu/include/migraphx/gpu/asinh.hpp b/src/targets/gpu/include/migraphx/gpu/asinh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..099460fc66a4fbefb1efa5e73d558c37efec2be3 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/asinh.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ASINH_HPP +#define MIGRAPHX_GUARD_RTGLIB_ASINH_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_asinh : unary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/atanh.hpp b/src/targets/gpu/include/migraphx/gpu/atanh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b2f434970b1c37608fa24ecb19009260cf2f231 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/atanh.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ATANH_HPP +#define MIGRAPHX_GUARD_RTGLIB_ATANH_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_atanh : unary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/batchnorm.hpp b/src/targets/gpu/include/migraphx/gpu/batch_norm_inference.hpp similarity index 88% rename from src/targets/gpu/include/migraphx/gpu/batchnorm.hpp rename to src/targets/gpu/include/migraphx/gpu/batch_norm_inference.hpp index bb040398ddeab7d15ea7493c3a90fc3b12c26229..9d682131f0e126ea836508bdb3c43362b792767f 100644 --- a/src/targets/gpu/include/migraphx/gpu/batchnorm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/batch_norm_inference.hpp @@ -1,8 +1,9 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP -#include -#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/include/migraphx/gpu/clip.hpp b/src/targets/gpu/include/migraphx/gpu/clip.hpp index 5c123a825ecc9f2ea53e7a49cb6c19a63548d335..a3bb7b916d0dbce6e3689d301e2d35f3bb1147e9 100644 --- a/src/targets/gpu/include/migraphx/gpu/clip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/clip.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_CLIP_HPP #define MIGRAPHX_GUARD_RTGLIB_CLIP_HPP -#include +#include +#include #include namespace migraphx { diff --git a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp new file mode 100755 index 0000000000000000000000000000000000000000..7e014a45179c144ac42729b40e3f954119dca724 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp @@ -0,0 +1,67 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CODE_OBJECT_OP_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_CODE_OBJECT_OP_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct code_object_op +{ + value::binary code_object; + std::string symbol_name; + std::size_t global; + std::size_t local; + std::vector expected_inputs; + shape output; + kernel k{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.code_object, "code_object"), + f(self.symbol_name, "symbol_name"), + f(self.global, "global"), + f(self.local, "local"), + f(self.expected_inputs, "expected_inputs"), + f(self.output, "output")); + } + + value attributes() const { return {{"group", group()}}; } + + std::string group() const { return "gpu::code_object::" + symbol_name; } + + std::string name() const { return "gpu::code_object"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + void finalize(context&, const shape&, const std::vector&); + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + + friend std::ostream& operator<<(std::ostream& os, const code_object_op& op) + { + os << op.name() << "["; + os << "code_object=" << op.code_object.size() << ","; + os << "symbol_name=" << op.symbol_name << ","; + os << "global=" << op.global << ","; + os << "local=" << op.local << ","; + os << "]"; + return os; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b28f922dacc30465ffc5643c8e42a3c2ad7368b0 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_gen.hpp @@ -0,0 +1,46 @@ +#ifndef MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP +#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct shape; + +namespace gpu { +namespace gen { + +struct vectorize +{ + std::size_t size = 1; + std::size_t axis = 0; + static vectorize elements(std::size_t axis, const std::vector& inputs); + std::string str() const; +}; +struct preload +{ + std::vector args = {}; + static preload broadcasts(std::size_t axis, const std::vector& inputs); + bool is_preloading() const; + std::string str() const; +}; + +std::size_t find_fast_axis(const std::vector& inputs); + +std::string make_transformer_args(std::vector transformers); + +template +std::string make_transformer_args(Ts... xs) +{ + return make_transformer_args({xs.str()...}); +} + +} // namespace gen +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp new file mode 100755 index 0000000000000000000000000000000000000000..556f7006595a685861a054c13f813aa874f0d776 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_hip.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP +#define MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +std::vector> +compile_hip_src(const std::vector& srcs, std::string params, const std::string& arch); + +std::string enum_params(std::size_t count, std::string param); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp b/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp new file mode 100644 index 0000000000000000000000000000000000000000..882baadc6ecc30ae2c71a54c32fea8937076d8e0 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp @@ -0,0 +1,54 @@ +#ifndef MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP +#define MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_compile_options +{ + std::size_t global; + std::size_t local; + std::vector inputs; + shape output; + std::string kernel_name = "kernel"; + std::string params = ""; + std::vector virtual_inputs = {}; + + /** + * @brief Set the launch parameters but allow v to override the values + * + * @param v A value class which can have a "global" and/or "local" keys to override the default + * global and local + * @param compute_global A function used to compute the global based on the local + * @param default_local The defaul local to use if its missing from the v parameter + */ + void set_launch_params(const value& v, + const std::function& compute_global, + std::size_t default_local = 1024); + + void + set_launch_params(const value& v, std::size_t default_global, std::size_t default_local = 1024) + { + set_launch_params( + v, [=](auto) { return default_global; }, default_local); + } +}; + +/// Compute global for n elements, but max out on target-specific upper limit +std::function +compute_global_for(context& ctx, std::size_t n, std::size_t over = 1); + +operation compile_hip_code_object(const std::string& content, hip_compile_options options); + +std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/compile_ops.hpp b/src/targets/gpu/include/migraphx/gpu/compile_ops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6861c96b4365b31f8b08631fdb47422db9bc9fbd --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compile_ops.hpp @@ -0,0 +1,27 @@ +#ifndef MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP +#define MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace gpu { + +struct context; + +struct compile_ops +{ + context* ctx = nullptr; + std::string name() const { return "gpu::compile_ops"; } + void apply(module& m) const; +}; + +} // namespace gpu + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/compiler.hpp b/src/targets/gpu/include/migraphx/gpu/compiler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d641415d319c3b0f1c219b32c3b404144f2bbd90 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/compiler.hpp @@ -0,0 +1,70 @@ +#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP +#define MIGRAPHX_GUARD_GPU_COMPILER_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +using compiler_replace = std::function; +using compiler_compile = std::function; +using compiler_compile_op = + std::function& inputs, const value&)>; + +void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop); + +bool has_compiler_for(const std::string& name); +compiler_replace compile(context& ctx, instruction_ref ins, const operation& op); +operation +compile_op(const std::string& name, context& ctx, const std::vector& inputs, const value& v); + +template +void register_compiler() +{ + T c; + for(auto&& name : c.names()) + { + register_compiler( + name, + [=](auto&&... xs) { return c.compile(std::forward(xs)...); }, + [=](auto&&... xs) { return c.compile_op(std::forward(xs)...); }); + } +} + +struct register_compiler_action +{ + template + static void apply() + { + register_compiler(); + } +}; + +template +using auto_register_compiler = auto_register; + +template +struct compiler : auto_register_compiler +{ + auto replace(const operation& op) const + { + return + [=](module& m, instruction_ref ins) { m.replace_instruction(ins, op, ins->inputs()); }; + } + operation compile_op(context&, const std::vector&, const value&) const { return {}; } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_GPU_COMPILER_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/concat.hpp b/src/targets/gpu/include/migraphx/gpu/concat.hpp index 08c776bd05dd090e41f9bcba250e269688c03602..0292cb06c33aa3f447a8e40ebf86bb0c8d27b921 100644 --- a/src/targets/gpu/include/migraphx/gpu/concat.hpp +++ b/src/targets/gpu/include/migraphx/gpu/concat.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP #define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP -#include +#include +#include #include namespace migraphx { diff --git a/src/targets/gpu/include/migraphx/gpu/concat_gpu_opt.hpp b/src/targets/gpu/include/migraphx/gpu/concat_gpu_opt.hpp index 3931e19ffb021f61be19f82bf8f192310c06d505..6cb772ba82f638978e16e4dc4d24340f74f4418a 100644 --- a/src/targets/gpu/include/migraphx/gpu/concat_gpu_opt.hpp +++ b/src/targets/gpu/include/migraphx/gpu/concat_gpu_opt.hpp @@ -2,6 +2,7 @@ #define MIGRAPHX_GUARD_RTGLIB_CONCAT_GPU_OPT_HPP #include +#include namespace migraphx { namespace gpu { diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 64cc24c9ebb20934ce629d7951453a3b61d691c7..f9da7a90d8701baea678a8fa27e30f6aa463b2db 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -1,12 +1,14 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP #define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP +#include #include #include #include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -19,10 +21,20 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); struct hip_device { - hip_device() { add_stream(); } + hip_device() + { + device_props.gcnArchName[0] = '\0'; + device_props.gcnArch = 0; + device_props.multiProcessorCount = 0; + add_stream(); + } hip_device(std::size_t id, std::size_t n) : device_id(id) { + auto status = hipGetDeviceProperties(&device_props, device_id); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to allocate stream"); + for(std::size_t i = 0; i < n; i++) add_stream(); } @@ -35,7 +47,7 @@ struct hip_device stream(std::size_t device_number) : id(device_number) {} - void setup() { set_device(id); } + void setup() const { set_device(id); } static hip_stream_ptr create_stream() { @@ -85,6 +97,16 @@ struct hip_device return rbhandle.get(); } + void wait() const + { + if(s == nullptr) + return; + setup(); + auto status = hipStreamSynchronize(s.get()); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to wait."); + } + void wait(hipEvent_t event) { setup(); @@ -114,16 +136,36 @@ struct hip_device stream& get_stream(std::size_t n) { return streams.at(n); } + const stream& get_stream() const { return streams.at(current_stream); } + + const stream& get_stream(std::size_t n) const { return streams.at(n); } + void set_stream(std::size_t n) { current_stream = n; } std::size_t nstreams() const { return streams.size(); } std::size_t stream_id() const { return current_stream; } + std::string get_device_name() const { return device_props.gcnArchName; } + + std::size_t get_device_major() const { return device_props.major; } + + std::size_t get_device_minor() const { return device_props.minor; } + + std::size_t get_cu_count() const { return device_props.multiProcessorCount; } + + std::size_t get_max_workitems_per_cu() const + { + return device_props.maxThreadsPerMultiProcessor; + } + + std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; } + private: std::size_t device_id = 0; std::size_t current_stream = 0; std::vector streams; + hipDeviceProp_t device_props; public: std::unordered_map preallocations{}; @@ -142,9 +184,21 @@ struct context return *current_device; } + const hip_device& get_current_device() const + { + assert(current_device != nullptr); + return *current_device; + } + hip_device::stream& get_stream() { return get_current_device().get_stream(); } hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); } + const hip_device::stream& get_stream() const { return get_current_device().get_stream(); } + const hip_device::stream& get_stream(std::size_t n) const + { + return get_current_device().get_stream(n); + } + void set_stream(std::size_t n) { get_current_device().set_stream(n); } void create_events(std::size_t num_of_events) @@ -156,7 +210,7 @@ struct context hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); } std::vector literals{}; - void finish() const { gpu_sync(); } + void finish() const { get_stream().wait(); } static hip_event_ptr create_event() { @@ -167,11 +221,38 @@ struct context return hip_event_ptr{event}; } + value to_value() const + { + value result; + result["events"] = events.size(); + result["streams"] = current_device->nstreams(); + + return result; + } + + void from_value(const value& v) + { + auto v_events = v.at("events"); + std::size_t n_events = v_events.without_key().to(); + this->create_events(n_events - 1); + + auto v_streams = v.at("streams"); + std::size_t n_streams = v_streams.without_key().to(); + + this->current_device = std::make_shared(0, n_streams); + } + + any_ptr get_queue() { return get_stream().get(); } + private: // TODO: Make this a vector to support multiple devices std::shared_ptr current_device; std::vector> events; }; + +inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } +inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/contiguous.hpp b/src/targets/gpu/include/migraphx/gpu/contiguous.hpp index 5e471a3f1ba788d301cd35ded4996663f2527216..d8294ebde5a4cd0242f23ee1eeeb229721343b4a 100644 --- a/src/targets/gpu/include/migraphx/gpu/contiguous.hpp +++ b/src/targets/gpu/include/migraphx/gpu/contiguous.hpp @@ -3,6 +3,8 @@ #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -10,22 +12,17 @@ namespace gpu { struct context; -struct miopen_contiguous +struct miopen_contiguous : unary_device { - op::contiguous op; - - template - static auto reflect(Self& self, F f) - { - return migraphx::reflect(self.op, f); - } - std::string name() const { return "gpu::contiguous"; } - shape compute_shape(const std::vector& inputs) const; - argument compute(context&, shape output_shape, const std::vector& args) const; - std::ptrdiff_t output_alias(const std::vector& shapes) const + shape compute_shape(const std::vector& inputs) const { - return shapes.size() - 1; + check_shapes{inputs, *this}.has(2); + if(inputs.front().standard()) + return inputs.front(); + auto lens = inputs.at(0).lens(); + auto t = inputs.at(0).type(); + return {t, lens}; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/convert.hpp b/src/targets/gpu/include/migraphx/gpu/convert.hpp index d119b85ec6172008a710592e6ff55fa39d503d22..d276c5e2fca6c83cf963541d11f6dd136db937b5 100644 --- a/src/targets/gpu/include/migraphx/gpu/convert.hpp +++ b/src/targets/gpu/include/migraphx/gpu/convert.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP #define MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP -#include +#include +#include #include namespace migraphx { diff --git a/src/targets/gpu/include/migraphx/gpu/convolution.hpp b/src/targets/gpu/include/migraphx/gpu/convolution.hpp index 111f0a58a8e66f2ec59ea97f05c5b6f3e761dc2d..54c2590d8a51c4725411f4c909ec9a223c98f524 100644 --- a/src/targets/gpu/include/migraphx/gpu/convolution.hpp +++ b/src/targets/gpu/include/migraphx/gpu/convolution.hpp @@ -14,22 +14,26 @@ struct context; struct miopen_convolution { op::convolution op; - shared cd; + shared cd = nullptr; miopenConvFwdAlgorithm_t algo{}; - miopenHandle_t handle = nullptr; + uint64_t solution_id = 0; template static auto reflect(Self& self, F f) { - // TODO: Add algo - return op::convolution::reflect(self.op, f); + return pack(f(self.op.padding, "padding"), + f(self.op.stride, "stride"), + f(self.op.dilation, "dilation"), + f(self.op.group, "group"), + f(self.op.padding_mode, "padding_mode"), + f(self.solution_id, "solution_id")); } std::string name() const { return "gpu::convolution"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; - shape compile(context& ctx, const shape& output_shape, std::vector inputs); + shape find(context& ctx, const shape& output_shape, std::vector inputs); void finalize(context& ctx, const shape& output_shape, std::vector inputs); std::ptrdiff_t output_alias(const std::vector& shapes) const { diff --git a/src/targets/gpu/include/migraphx/gpu/deconvolution.hpp b/src/targets/gpu/include/migraphx/gpu/deconvolution.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a45d4ecce8b6a9f2137773c34e62029291e8583e --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/deconvolution.hpp @@ -0,0 +1,44 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DECONVOLUTION_HPP +#define MIGRAPHX_GUARD_RTGLIB_DECONVOLUTION_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct miopen_deconvolution +{ + op::deconvolution op; + shared cd; + miopenConvFwdAlgorithm_t algo{}; + miopenHandle_t handle = nullptr; + + template + static auto reflect(Self& self, F f) + { + // TODO: Add algo + return op::convolution::reflect(self.op, f); + } + + std::string name() const { return "gpu::deconv"; } + shape compute_shape(const std::vector& inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + shape compile(context& ctx, const shape& output_shape, std::vector inputs); + void finalize(context& ctx, const shape& output_shape, std::vector inputs); + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/acosh.hpp b/src/targets/gpu/include/migraphx/gpu/device/acosh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4bf210453aaad8d7e860dd29597f5c3c688a263e --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/acosh.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ACOSH_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ACOSH_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void acosh(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/add_clip.hpp b/src/targets/gpu/include/migraphx/gpu/device/add_clip.hpp index 9e056dbf13c3b9c4db1a3d3da1cfd02221373fca..0c6431d48929378adc119693ec3dbd4e813459c9 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/add_clip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/add_clip.hpp @@ -15,16 +15,16 @@ void add_clip(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, - float max, - float min); + const argument& min_arg, + const argument& max_arg); void add_clip(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3, - float max, - float min); + const argument& min_arg, + const argument& max_arg); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp b/src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp index cbcd06642def0dfbeff8fb0ce13128dca6daf2c7..28117e8f532792c940f69e928569deb86f0fb355 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp @@ -72,15 +72,15 @@ template void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis) { auto arg_shape = arg.get_shape(); - auto lens = arg_shape.lens(); - auto batch_lens = lens; - size_t batch_item_num = lens[axis]; + auto batch_lens = arg_shape.lens(); + size_t batch_item_num = batch_lens[axis]; batch_lens[axis] = 1; migraphx::shape batch_shape{arg_shape.type(), batch_lens}; + migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()}; - hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) { - auto output = device_cast(result.get().data()); - using type = device_type>; + hip_visit_all(arg, std_arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) { + auto* output = device_cast(result.get().data()); + using type = device_type>; // use one block for items in one batch. const size_t max_block_size = 256; const std::size_t block_size = compute_block_size(batch_item_num, max_block_size); diff --git a/src/targets/gpu/include/migraphx/gpu/device/asinh.hpp b/src/targets/gpu/include/migraphx/gpu/device/asinh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fdcbac7c9010011286255cc89ea6e6cc76b271ac --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/asinh.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ASINH_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ASINH_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void asinh(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/atanh.hpp b/src/targets/gpu/include/migraphx/gpu/device/atanh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..014c40fdab39e233c629075b9846c09f06fd5142 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/atanh.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ATANH_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ATANH_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void atanh(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/clip.hpp b/src/targets/gpu/include/migraphx/gpu/device/clip.hpp index 8f9a5c440915815301e9cce15fd35d4d3f17a394..3a028b692f8bbff39cbfedd8fbb288955059a34e 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/clip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/clip.hpp @@ -10,7 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void clip(hipStream_t stream, const argument& result, const argument& arg1, float max, float min); +void clip(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& min_val, + const argument& max_val); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/contiguous.hpp b/src/targets/gpu/include/migraphx/gpu/device/contiguous.hpp index 668b271be623a9a4b3d598dcd45cdfa29da8d45a..7d2c4b2dab017ca6e8afd426235a067da4f0f40e 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/contiguous.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/contiguous.hpp @@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void contiguous(hipStream_t stream, argument result, argument arg); +void contiguous(hipStream_t stream, const argument& result, const argument& arg); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/equal.hpp b/src/targets/gpu/include/migraphx/gpu/device/equal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..559678dac56b4587fa89517480c61e3897da5a25 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/equal.hpp @@ -0,0 +1,21 @@ + +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_EQUAL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void equal(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/fill.hpp b/src/targets/gpu/include/migraphx/gpu/device/fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d7d4d5ae01d0ac19369ef3778bddc05afecd8ce --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/fill.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_FILL_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_FILL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void fill(hipStream_t stream, const argument& result, unsigned long val); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/gather.hpp b/src/targets/gpu/include/migraphx/gpu/device/gather.hpp index b10bb6909f4f7fff27252ffd6cb05a4b1f60f65f..02fb6bdf5e4dfcdbb92b30a91cde745a631cbb88 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/gather.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/gather.hpp @@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis); +argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/gelu.hpp b/src/targets/gpu/include/migraphx/gpu/device/gelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..469e63c38c1a8473a6c80a19fe45c1b199553206 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/gelu.hpp @@ -0,0 +1,32 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GELU_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GELU_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void gelu(hipStream_t stream, const argument& result, const argument& arg); + +void gelu_new(hipStream_t stream, const argument& result, const argument& arg); + +void add_gelu(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +void add_gelu_new(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/greater.hpp b/src/targets/gpu/include/migraphx/gpu/device/greater.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03588ddcbaa5ba2c1a66b0a49d4d4ea9311b07ba --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/greater.hpp @@ -0,0 +1,24 @@ + +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GREATER_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GREATER_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void greater(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp b/src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6508b0031e3f78149698901ba01b71bf193af004 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp @@ -0,0 +1,26 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void layernorm(hipStream_t stream, const argument& result, const argument& arg1); + +void triadd_layernorm(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2, + const argument& arg3); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/less.hpp b/src/targets/gpu/include/migraphx/gpu/device/less.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8e7c99d51adaeb187d98f6881891dc79e0fac099 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/less.hpp @@ -0,0 +1,21 @@ + +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LESS_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LESS_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void less(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/logical_and.hpp b/src/targets/gpu/include/migraphx/gpu/device/logical_and.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f7182ac879cc86f553f9d827f73872095fb865e4 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/logical_and.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_AND_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_AND_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_and(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/logical_or.hpp b/src/targets/gpu/include/migraphx/gpu/device/logical_or.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df10071712c69789bc12619c1260e69ec9d8a2d9 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/logical_or.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_OR_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_OR_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_or(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/logical_xor.hpp b/src/targets/gpu/include/migraphx/gpu/device/logical_xor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..698ddd890948a8901cfca90d63a74955e71ca7b1 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/logical_xor.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_XOR_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGICAL_XOR_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void logical_xor(hipStream_t stream, + const argument& result, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp b/src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp index ab558c50fdd377c3ae6d5e11fbd0ab9404df967e..e1217bce90c2757361c988a17bce893b58eb3db4 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/logsoftmax.hpp @@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis); +void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/multinomial.hpp b/src/targets/gpu/include/migraphx/gpu/device/multinomial.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbd956a29169fc555e7ef7a2786a60fae1a0ae81 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/multinomial.hpp @@ -0,0 +1,23 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MULTINOMIAL_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MULTINOMIAL_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void multinomial(hipStream_t stream, + const argument& result, + const argument& arg0, + const argument& arg1); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/nonzero.hpp b/src/targets/gpu/include/migraphx/gpu/device/nonzero.hpp new file mode 100644 index 0000000000000000000000000000000000000000..172dcde85019fba485eac888087c8a01de88da7c --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/nonzero.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NONZERO_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NONZERO_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument nonzero(hipStream_t stream, const argument& result, const argument& arg_data); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/prefix_scan_sum.hpp b/src/targets/gpu/include/migraphx/gpu/device/prefix_scan_sum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a336e459ec6627e177c4acbf40383eb65d63f5f5 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/prefix_scan_sum.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP +#define MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void prefix_scan_sum(hipStream_t stream, + const argument& result, + const argument& arg, + int32_t axis, + bool exclusive, + bool reverse); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/device/prelu.hpp b/src/targets/gpu/include/migraphx/gpu/device/prelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..671a5f16be177fb2c3b3b1fb3b8f733207de87cb --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/prelu.hpp @@ -0,0 +1,21 @@ + +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PRELU_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PRELU_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void prelu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/recip.hpp b/src/targets/gpu/include/migraphx/gpu/device/recip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6b9ea961abe690e129e7dfb906123f9002dac27 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/recip.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RECIP_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RECIP_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void recip(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/reduce_prod.hpp b/src/targets/gpu/include/migraphx/gpu/device/reduce_prod.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd40dd045f48bb28fbffaee3cdd7499650904116 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/reduce_prod.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_PROD_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_PROD_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void reduce_prod(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/reverse.hpp b/src/targets/gpu/include/migraphx/gpu/device/reverse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4802d0494d7f005e8ae888d46738731b268945e4 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/reverse.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REVERSE_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REVERSE_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument +reverse(hipStream_t stream, argument result, argument arg1, const std::vector& axes); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/rnn_variable_seq_lens.hpp b/src/targets/gpu/include/migraphx/gpu/device/rnn_variable_seq_lens.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1259e31365524c5cb760fce764138740c8bb17cf --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/rnn_variable_seq_lens.hpp @@ -0,0 +1,35 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void rnn_var_sl_shift_sequence(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl); + +void rnn_var_sl_shift_output(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl, + bool is_reverse); + +void rnn_var_sl_last_output(hipStream_t stream, + const argument& result, + const argument& arg_hs, + const argument& arg_sl, + bool is_reverse); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/scatter.hpp b/src/targets/gpu/include/migraphx/gpu/device/scatter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..61373698fbd2051af5658d8e51926b2f6a215249 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/scatter.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument scatter( + hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/softmax.hpp b/src/targets/gpu/include/migraphx/gpu/device/softmax.hpp index e5795b4c4d53e9d8543e6d3d9dedf3f45bde4337..1bc1d980bbba91e6aca709768af74d6cfca21f33 100644 --- a/src/targets/gpu/include/migraphx/gpu/device/softmax.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device/softmax.hpp @@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { -void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis); +void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis); } // namespace device } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/device/topk.hpp b/src/targets/gpu/include/migraphx/gpu/device/topk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eac7cad265a3a338f1dff700112e4ef8ab9d80cd --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/topk.hpp @@ -0,0 +1,32 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_TOPK_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_TOPK_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +argument topk_smallest(hipStream_t stream, + const argument& val_res, + const argument& ind_res, + const argument& arg, + int64_t k, + int64_t axis); + +argument topk_largest(hipStream_t stream, + const argument& val_res, + const argument& ind_res, + const argument& arg, + int64_t k, + int64_t axis); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/unary_not.hpp b/src/targets/gpu/include/migraphx/gpu/device/unary_not.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1ed690285fe28878f55e996c03c2f2bb13191107 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/unary_not.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_UNARY_NOT_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_UNARY_NOT_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void unary_not(hipStream_t stream, const argument& result, const argument& arg); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device/where.hpp b/src/targets/gpu/include/migraphx/gpu/device/where.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3726380b816be3beb6a6513c5e477bdeae0f3c0f --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device/where.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP +#define MIGRAPHX_GUARD_RTGLIB_DEVICE_WHERE_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { +namespace device { + +void where(hipStream_t stream, + const argument& result, + const argument& arg0, + const argument& arg1, + const argument& arg2); + +} // namespace device +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/device_name.hpp b/src/targets/gpu/include/migraphx/gpu/device_name.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0d4a543889ce72838b949a956aa8bce2e9dc85f1 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/device_name.hpp @@ -0,0 +1,16 @@ +#ifndef MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP +#define MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +std::string get_device_name(); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp b/src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp index 1f144a475199dffcd33b9887a4c162e39feac885..be1a47d220f89992ed1bdb76636d52723fc51bc6 100644 --- a/src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp +++ b/src/targets/gpu/include/migraphx/gpu/eliminate_workspace.hpp @@ -7,14 +7,14 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; namespace gpu { struct eliminate_workspace { std::string name() const { return "eliminate_workspace"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/include/migraphx/gpu/elu.hpp b/src/targets/gpu/include/migraphx/gpu/elu.hpp index f97459756a8f448006607630d31e88c17c85825d..cc114d55c5f30ff67905c3fcddc63c3900f30a72 100644 --- a/src/targets/gpu/include/migraphx/gpu/elu.hpp +++ b/src/targets/gpu/include/migraphx/gpu/elu.hpp @@ -1,7 +1,9 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_ELU_HPP #define MIGRAPHX_GUARD_RTGLIB_ELU_HPP +#include #include +#include #include namespace migraphx { @@ -12,18 +14,20 @@ struct context; struct miopen_elu { + op::elu op; shared ad; template static auto reflect(Self& self, F f) { - return gpu::reflect(self.ad.get(), f); + return migraphx::reflect(self.op, f); } std::string name() const { return "gpu::elu"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; + void finalize(context&, const shape&, const std::vector&); std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; diff --git a/src/targets/gpu/include/migraphx/gpu/equal.hpp b/src/targets/gpu/include/migraphx/gpu/equal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90d7eb4f85374515b9e4ae3641b6781bb8c758d2 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/equal.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_EQUAL_HPP +#define MIGRAPHX_GUARD_RTGLIB_EQUAL_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_equal : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_ops.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_ops.hpp index 6f5f2d062f8e7be94e43751d9f89a737902fe9af..4ea17a851c12e7cdfdbff0944eb4bae707d8f781 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_ops.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_ops.hpp @@ -1,20 +1,22 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP #define MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP -#include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct module; + namespace gpu { struct fuse_ops { - context* ctx = nullptr; + context* ctx = nullptr; + bool fast_math = true; std::string name() const { return "gpu::fuse_ops"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/gather.hpp b/src/targets/gpu/include/migraphx/gpu/gather.hpp index 7abc6740e10e6a66585f4dec4b0db0e7271232d9..023bae0085c058c40850c154d125badc43a1f59d 100644 --- a/src/targets/gpu/include/migraphx/gpu/gather.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gather.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP #define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP -#include +#include +#include #include #include diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index 5f69ed3d302aeb325c4b55a74da4b8aa4d9c6fa9..5ce3b20f7acb6b673fc8ea9eac50f3595f23d8a0 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -1,7 +1,11 @@ -#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP -#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP +#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP +#define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP +#include +#include +#include #include +#include #include #include #include @@ -14,15 +18,24 @@ namespace gpu { struct context; +void blas_shape(const shape& s); + template struct rocblas_gemm { Op op; + float alpha = 1; + float beta = 0; + bool int8_x4_format = true; + bool compute_fp32 = false; template static auto reflect(Self& self, F f) { - return migraphx::reflect(self.op, f); + return pack_join(migraphx::reflect(self.op, f), + pack(f(self.alpha, "alpha"), + f(self.beta, "beta"), + f(self.int8_x4_format, "int8_x4_format"))); } std::string name() const @@ -38,9 +51,31 @@ struct rocblas_gemm { std::vector in_shapes(inputs); in_shapes.pop_back(); - check_shapes{in_shapes}.not_broadcasted(); - batch_not_transposed(inputs[0].strides()); - batch_not_transposed(inputs[1].strides()); + check_shapes{in_shapes, *this}.not_broadcasted(); + blas_shape(inputs[0]); + blas_shape(inputs[1]); + // if gemm and add are fused + if(in_shapes.size() > 2) + { + auto cmat_shape = in_shapes.back(); + in_shapes.pop_back(); + blas_shape(cmat_shape); + auto op_out_shape = op.compute_shape(in_shapes); + if(cmat_shape.lens() != op_out_shape.lens()) + { + MIGRAPHX_THROW(this->name() + " : dimension mismatch, operand C: {" + + to_string_range(cmat_shape.lens()) + + "}, cannot add to operand A * B: {" + + to_string_range(op_out_shape.lens()) + "}"); + } + if(cmat_shape.type() != op_out_shape.type()) + { + MIGRAPHX_THROW(this->name() + " : operand C type mismatch, operand C is of type: " + + to_string(cmat_shape.type()) + + ", it must be: " + to_string(op_out_shape.type())); + } + return op_out_shape; + } return op.compute_shape(in_shapes); } @@ -48,24 +83,21 @@ struct rocblas_gemm argument compute(context& ctx, const shape& output_shape, const std::vector& args) const { - gemm(ctx, output_shape, args, op.alpha, op.beta); - return args.back(); - } - - void batch_not_transposed(const std::vector& strides) const - { - if(strides.size() <= 2) - return; - auto dim_0 = strides.size() - 2; - auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]); - std::vector batch(strides.begin(), strides.begin() + dim_0); - if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) { - return (i < j or i < matrix_size or j < matrix_size); - }) != batch.end()) + if(this->name() == "gpu::gemm") { - MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) + - "} is transposed!"); + gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32); } + else + { + gemm(ctx, + output_shape, + args, + int32_t(alpha), + int32_t(beta), + int8_x4_format, + compute_fp32); + } + return args.back(); } std::ptrdiff_t output_alias(const std::vector& shapes) const diff --git a/src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp b/src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp index e28716a84cd57ec3b1b5186783f27f15601bfe97..bd314224e30ecef04afcafef9face1a289f456cb 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp @@ -13,12 +13,16 @@ void gemm(context& ctx, const shape& output_shape, const std::vector& args, float alpha, - float beta); + float beta, + bool int8_x4_format, + bool compute_fp32); void gemm(context& ctx, const shape& output_shape, const std::vector& args, int32_t alpha, - int32_t beta); + int32_t beta, + bool int8_x4_format, + bool compute_fp32); } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/include/migraphx/gpu/greater.hpp b/src/targets/gpu/include/migraphx/gpu/greater.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c88b5cf720817551e334334854a4d0a7f8021f59 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/greater.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GREATER_HPP +#define MIGRAPHX_GUARD_RTGLIB_GREATER_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_greater : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index e295588d04aedff5743a8ee1a2b57c2d035aeb6b..d150373033cb678f6b911d3fb8d89d81cdcfe539 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -3,7 +3,9 @@ #include #include +#include #include +#include #include namespace migraphx { @@ -23,6 +25,7 @@ argument from_gpu(const argument& arg); void set_device(std::size_t id); void gpu_sync(); +void gpu_sync(const context& ctx); void gpu_copy(context& ctx, const argument& src, const argument& dst); void copy_to_gpu(context& ctx, const argument& src, const argument& dst); @@ -44,7 +47,7 @@ struct hip_allocate std::string name() const { return "hip::allocate"; } shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs}.has(0); + check_shapes{inputs, *this}.has(0); return s; } argument compute(context&, const shape& output_shape, const std::vector&) const @@ -53,7 +56,7 @@ struct hip_allocate } }; -struct hip_sync +struct hip_sync_device { std::string tag{}; @@ -63,21 +66,47 @@ struct hip_sync return pack(f(self.tag, "tag")); } - std::string name() const { return "hip::sync"; } + std::string name() const { return "hip::sync_device"; } shape compute_shape(const std::vector& inputs) const { if(inputs.empty()) return {}; - else - return inputs.front(); + return inputs.front(); } + argument compute(context&, const shape&, const std::vector& args) const { gpu_sync(); if(args.empty()) return {}; - else - return args.front(); + return args.front(); + } +}; + +struct hip_sync_stream +{ + std::string tag{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.tag, "tag")); + } + + std::string name() const { return "hip::sync_stream"; } + shape compute_shape(const std::vector& inputs) const + { + if(inputs.empty()) + return {}; + return inputs.front(); + } + + argument compute(context& ctx, const shape&, const std::vector& args) const + { + gpu_sync(ctx); + if(args.empty()) + return {}; + return args.front(); } }; @@ -86,7 +115,7 @@ struct hip_copy_to_gpu std::string name() const { return "hip::copy_to_gpu"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1, 2); + check_shapes{inputs, *this}.has(1, 2); return inputs.at(0); } argument compute(context& ctx, const shape&, const std::vector& args) const @@ -112,7 +141,7 @@ struct hip_copy_from_gpu std::string name() const { return "hip::copy_from_gpu"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(1, 2); + check_shapes{inputs, *this}.has(1, 2); return inputs.at(0); } argument @@ -140,7 +169,7 @@ struct hip_copy std::string name() const { return "hip::copy"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs}.has(2).standard(); + check_shapes{inputs, *this}.has(2); return inputs.at(1); } argument compute(context& ctx, const shape&, std::vector args) const @@ -151,7 +180,9 @@ struct hip_copy std::ptrdiff_t output_alias(const std::vector&) const { return 1; } }; -struct hip_load_memory +void store_preallocated_param(context& ctx, const std::string& id, const argument& a); + +struct hip_allocate_memory { shape s; std::string id{}; @@ -162,16 +193,58 @@ struct hip_load_memory return pack(f(self.s, "shape"), f(self.id, "id")); } - std::string name() const { return "hip::hip_load_memory"; } + std::string name() const { return "hip::hip_allocate_memory"; } shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs}.has(0); + check_shapes{inputs, *this}.has(0); return s; } + argument compute(context& ctx, const shape&, const std::vector&) const { return get_preallocation(ctx, id); } + + void finalize(context& ctx, const shape&, const std::vector&) const + { + argument a = allocate_gpu(s); + store_preallocated_param(ctx, id, a); + } +}; + +struct hip_copy_literal +{ + literal l; + std::string id{}; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.l, "literal"), f(self.id, "id")); + } + + std::string name() const { return "hip::hip_copy_literal"; } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(0); + return l.get_shape(); + } + + argument compute(context& ctx, const shape&, const std::vector&) const + { + return get_preallocation(ctx, id); + } + + void finalize(context& ctx, const shape&, const std::vector&) const + { + argument a = to_gpu(l.get_argument()); + store_preallocated_param(ctx, id, a); + } + friend std::ostream& operator<<(std::ostream& os, const hip_copy_literal& x) + { + os << x.name() << "[id=" << x.id << "]"; + return os; + } }; } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/kernel.hpp b/src/targets/gpu/include/migraphx/gpu/kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a7298751e9c7ed7d44d98a62a8067ae06417a283 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/kernel.hpp @@ -0,0 +1,52 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_KERNEL_HPP +#define MIGRAPHX_GUARD_RTGLIB_KERNEL_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct kernel_impl; + +struct kernel +{ + kernel() = default; + kernel(const char* image, const std::string& name); + template + kernel(const std::vector& image, const std::string& name) + : kernel(reinterpret_cast(image.data()), name) + { + } + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const; + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const; + + auto launch(hipStream_t stream, std::size_t global, std::size_t local) const + { + return [=](auto&&... xs) { + launch(stream, global, local, std::vector{xs...}); + }; + } + + private: + std::shared_ptr impl; +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/leaky_relu.hpp b/src/targets/gpu/include/migraphx/gpu/leaky_relu.hpp index 622245ab3237545d35895469c5daa9709a4c2deb..3e9d4838f82c2e094641652d9dcb026c2584d75a 100644 --- a/src/targets/gpu/include/migraphx/gpu/leaky_relu.hpp +++ b/src/targets/gpu/include/migraphx/gpu/leaky_relu.hpp @@ -1,7 +1,9 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_LEAKY_RELU_HPP #define MIGRAPHX_GUARD_RTGLIB_LEAKY_RELU_HPP +#include #include +#include #include namespace migraphx { @@ -12,18 +14,20 @@ struct context; struct miopen_leaky_relu { + op::leaky_relu op; shared ad; template static auto reflect(Self& self, F f) { - return gpu::reflect(self.ad.get(), f); + return migraphx::reflect(self.op, f); } std::string name() const { return "gpu::leaky_relu"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; + void finalize(context&, const shape&, const std::vector&); std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; diff --git a/src/targets/gpu/include/migraphx/gpu/less.hpp b/src/targets/gpu/include/migraphx/gpu/less.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cab4b2fc28f61482cf4e961c2816966de3b3e223 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/less.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LESS_HPP +#define MIGRAPHX_GUARD_RTGLIB_LESS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_less : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/logical_and.hpp b/src/targets/gpu/include/migraphx/gpu/logical_and.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e603f2f1ad064a18469c16c75b2322d5b9ee3ab --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/logical_and.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LOGICLA_AND_HPP +#define MIGRAPHX_GUARD_RTGLIB_LOGICLA_AND_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_logical_and : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/logical_or.hpp b/src/targets/gpu/include/migraphx/gpu/logical_or.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b51a68177327365057b630b4f38eb0058559f0b2 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/logical_or.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LOGICAL_OR_HPP +#define MIGRAPHX_GUARD_RTGLIB_LOGICAL_OR_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_logical_or : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/logical_xor.hpp b/src/targets/gpu/include/migraphx/gpu/logical_xor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56290c62321ce8a67b86b1f407527ff05df24b70 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/logical_xor.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LOGICAL_XOR_HPP +#define MIGRAPHX_GUARD_RTGLIB_LOGICAL_XOR_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_logical_xor : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/loop.hpp b/src/targets/gpu/include/migraphx/gpu/loop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f31daf3f16e78c3cf6da914b6919b238b09c26a --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/loop.hpp @@ -0,0 +1,43 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_LOOP_HPP +#define MIGRAPHX_GUARD_RTGLIB_LOOP_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_loop +{ + op::loop op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::loop"; } + shape compute_shape(std::vector inputs, std::vector mods) const; + argument + compute(context& ctx, + const shape& output_shape, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/lowering.hpp b/src/targets/gpu/include/migraphx/gpu/lowering.hpp index 8a4fd1267e9c557913392ba4731ff250598a9872..04cd19af8f9134f5ac01e56c145eee7cc6bfc30b 100644 --- a/src/targets/gpu/include/migraphx/gpu/lowering.hpp +++ b/src/targets/gpu/include/migraphx/gpu/lowering.hpp @@ -1,20 +1,21 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_MIOPEN_LOWERING_HPP #define MIGRAPHX_GUARD_RTGLIB_MIOPEN_LOWERING_HPP -#include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { +struct module; + +namespace gpu { struct lowering { context* ctx; bool offload_copy; std::string name() const { return "gpu::lowering"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace gpu diff --git a/src/targets/gpu/include/migraphx/gpu/lrn.hpp b/src/targets/gpu/include/migraphx/gpu/lrn.hpp index be6a67f786f5da78393da9e9d514b66facd40eb7..db8111c39577eeb50f09c2f636c0b8e8abf7ee08 100644 --- a/src/targets/gpu/include/migraphx/gpu/lrn.hpp +++ b/src/targets/gpu/include/migraphx/gpu/lrn.hpp @@ -2,6 +2,7 @@ #define MIGRAPHX_GUARD_RTGLIB_LRN_HPP #include +#include #include namespace migraphx { @@ -12,18 +13,20 @@ struct context; struct miopen_lrn { + op::lrn op; shared ldesc; template static auto reflect(Self& self, F f) { - return gpu::reflect(self.ldesc.get(), f); + return migraphx::reflect(self.op, f); } std::string name() const { return "gpu::lrn"; } shape compute_shape(const std::vector& inputs) const; argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; + void finalize(context&, const shape&, const std::vector&); std::ptrdiff_t output_alias(const std::vector& shapes) const { return shapes.size() - 1; diff --git a/src/targets/gpu/include/migraphx/gpu/miopen.hpp b/src/targets/gpu/include/migraphx/gpu/miopen.hpp index 2c341003f661302de3f1146bedf58b7acfa92df6..e505ca5b3513c7494d75acd90914b304c0dc4af0 100644 --- a/src/targets/gpu/include/migraphx/gpu/miopen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/miopen.hpp @@ -2,12 +2,21 @@ #define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP #include +#include #include #include #include #include #include +#include + +#ifdef MIGRAPHX_HAS_FIND_MODE_API +extern "C" miopenStatus_t +miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc, // NOLINT + int findMode); // NOLINT +#endif + namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { @@ -38,8 +47,9 @@ Result make_obj(F f, Ts... xs) return r; } -inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false) +inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = false) { + auto s = os.normalize_standard(); auto t = make_obj(&miopenCreateTensorDescriptor); // Convert to ints std::vector lens(s.lens().begin(), s.lens().end()); @@ -81,14 +91,42 @@ inline convolution_descriptor make_conv(const T& op) miopenConvolutionMode_t c_mode = miopenConvolution; if(op.group > 1) c_mode = miopenGroupConv; - miopenInitConvolutionDescriptor(c.get(), - c_mode, - op.padding[0], - op.padding[1], - op.stride[0], - op.stride[1], - op.dilation[0], - op.dilation[1]); + + int kdims = op.kdims(); + std::vector padding(std::max(2, kdims), 0); + std::vector stride(std::max(2, kdims), 1); + std::vector dilation(std::max(2, kdims), 1); + + std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end()); + std::copy_backward(op.stride.begin(), op.stride.end(), stride.end()); + std::copy_backward(op.dilation.begin(), op.dilation.end(), dilation.end()); + + miopenInitConvolutionNdDescriptor( + c.get(), padding.size(), padding.data(), stride.data(), dilation.data(), c_mode); + if(op.group > 1) + miopenSetConvolutionGroupCount(c.get(), op.group); +#ifdef MIGRAPHX_HAS_FIND_MODE_API + miopenHiddenSetConvolutionFindMode(c.get(), 1); // Normal mode +#endif + return c; +} + +template +inline convolution_descriptor make_deconv(const T& op) +{ + auto c = make_obj(&miopenCreateConvolutionDescriptor); + miopenConvolutionMode_t c_mode = miopenTranspose; + int kdims = op.kdims(); + std::vector padding(std::max(2, kdims), 0); + std::vector stride(std::max(2, kdims), 1); + std::vector dilation(std::max(2, kdims), 1); + + std::copy_backward(op.padding.begin(), op.padding.end(), padding.end()); + std::copy_backward(op.stride.begin(), op.stride.end(), stride.end()); + std::copy_backward(op.dilation.begin(), op.dilation.end(), dilation.end()); + + miopenInitConvolutionNdDescriptor( + c.get(), padding.size(), padding.data(), stride.data(), dilation.data(), c_mode); if(op.group > 1) miopenSetConvolutionGroupCount(c.get(), op.group); return c; @@ -97,19 +135,29 @@ inline convolution_descriptor make_conv(const T& op) inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) { miopenPoolingMode_t mode; - if(op.mode == "max") + if(op.mode == op::pooling_mode::max) mode = miopenPoolingMax; - else + else if(op.mode == op::pooling_mode::average) mode = miopenPoolingAverage; + else + { + std::stringstream ss("Unknown mode for pooling: "); + ss << op.mode; + MIGRAPHX_THROW(ss.str()); + } auto p = make_obj(&miopenCreatePoolingDescriptor); - miopenSet2dPoolingDescriptor(p.get(), - mode, - op.lengths[0], - op.lengths[1], - op.padding[0], - op.padding[1], - op.stride[0], - op.stride[1]); + + int kdims = op.kdims(); + std::vector padding(std::max(2, kdims), 0); + std::vector stride(std::max(2, kdims), 1); + std::vector lengths(std::max(2, kdims), 1); + + std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end()); + std::copy_backward(op.stride.begin(), op.stride.end(), stride.end()); + std::copy_backward(op.lengths.begin(), op.lengths.end(), lengths.end()); + + miopenSetNdPoolingDescriptor( + p.get(), mode, padding.size(), lengths.data(), padding.data(), stride.data()); return p; } diff --git a/src/targets/gpu/include/migraphx/gpu/mlir_conv.hpp b/src/targets/gpu/include/migraphx/gpu/mlir_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..78543f1de1757444957e11b06692749900cde72a --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/mlir_conv.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_MIOPEN_MLIR_CONV_HPP +#define MIGRAPHX_GUARD_RTGLIB_MIOPEN_MLIR_CONV_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace gpu { +struct mlir_conv +{ + context* ctx; + std::string name() const { return "mlir::convolution"; } + void apply(module& m) const; +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/multinomial.hpp b/src/targets/gpu/include/migraphx/gpu/multinomial.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee23603cc0b8ccd66f00d1741e60139d5ed6b067 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/multinomial.hpp @@ -0,0 +1,36 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_MULTINOMIAL_HPP +#define MIGRAPHX_GUARD_RTGLIB_MULTINOMIAL_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_multinomial +{ + op::multinomial op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::multinomial"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/nonzero.hpp b/src/targets/gpu/include/migraphx/gpu/nonzero.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aee57f6f0df5fb92c8ce2d834d791cbcab425265 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/nonzero.hpp @@ -0,0 +1,39 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_NONZERO_HPP +#define MIGRAPHX_GUARD_RTGLIB_NONZERO_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_nonzero +{ + op::nonzero op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::nonzero"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/oper.hpp b/src/targets/gpu/include/migraphx/gpu/oper.hpp index c920251a1478c5fba2b060737d30b60a285621d5..f1584d9812cf4e69cd15eb6c24b96e7e0608060a 100644 --- a/src/targets/gpu/include/migraphx/gpu/oper.hpp +++ b/src/targets/gpu/include/migraphx/gpu/oper.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -15,95 +16,126 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { -template -struct unary_device : oper +template +struct device_base : oper { - shape compute_shape(const std::vector& inputs) const + template + static auto reflect(Self&, F) { - check_shapes{inputs, *this}.has(2); - auto s = inputs.at(0); - if(s.packed()) - { - return s; - } - else - { - return {s.type(), s.lens()}; - } + return pack(); } - argument compute(context& ctx, const shape&, const std::vector& args) const + std::vector reduce_shapes; + + void finalize(context&, const shape&, const std::vector& inputs) { - F(ctx.get_stream().get(), args[1], args[0]); - return args[1]; + reduce_shapes = reduce_dims(inputs); } - std::ptrdiff_t output_alias(const std::vector& shapes) const + argument get_arg(const std::vector& args, std::size_t i) const { - return shapes.size() - 1; + if(reduce_shapes.empty()) + return args[i]; + return args.at(i).reshape(reduce_shapes.at(i)); } -}; -template -struct binary_device : oper -{ shape compute_shape(const std::vector& inputs) const { - check_shapes{inputs, *this}.has(3); + check_shapes{inputs, *this}.has(N + 1); auto s0 = inputs.at(0); - auto s1 = inputs.at(1); - if(s0 == s1 and s0.packed()) - { + if(std::all_of(inputs.begin(), inputs.end() - 1, [&](auto s) { return s == s0; }) and + s0.packed()) return s0; - } else - { return {s0.type(), s0.lens()}; - } } + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +template +struct unary_device : device_base +{ argument compute(context& ctx, const shape&, const std::vector& args) const { - F(ctx.get_stream().get(), args[2], args[0], args[1]); - return args[2]; + F(ctx.get_stream().get(), this->get_arg(args, 1), this->get_arg(args, 0)); + return args[1]; } +}; - std::ptrdiff_t output_alias(const std::vector& shapes) const +template +struct binary_device : device_base +{ + argument compute(context& ctx, const shape&, const std::vector& args) const { - return shapes.size() - 1; + F(ctx.get_stream().get(), + this->get_arg(args, 2), + this->get_arg(args, 0), + this->get_arg(args, 1)); + return args[2]; } }; template -struct ternary_device : oper +struct ternary_device : device_base { - shape compute_shape(const std::vector& inputs) const + argument compute(context& ctx, const shape&, const std::vector& args) const { - check_shapes{inputs, *this}.has(4); - auto s0 = inputs.at(0); - auto s1 = inputs.at(1); - auto s2 = inputs.at(2); - if(s0 == s1 and s1 == s2 and s0.packed()) - { - return s0; - } - else - { - return {s0.type(), s0.lens()}; - } + F(ctx.get_stream().get(), + this->get_arg(args, 3), + this->get_arg(args, 0), + this->get_arg(args, 1), + this->get_arg(args, 2)); + return args[3]; } +}; +template +struct quaternary_device : device_base +{ argument compute(context& ctx, const shape&, const std::vector& args) const { - F(ctx.get_stream().get(), args[3], args[0], args[1], args[2]); - return args[3]; + F(ctx.get_stream().get(), + this->get_arg(args, 4), + this->get_arg(args, 0), + this->get_arg(args, 1), + this->get_arg(args, 2), + this->get_arg(args, 3)); + return args[4]; } +}; - std::ptrdiff_t output_alias(const std::vector& shapes) const +template +struct quinary_device : device_base +{ + argument compute(context& ctx, const shape&, const std::vector& args) const { - return shapes.size() - 1; + F(ctx.get_stream().get(), + this->get_arg(args, 5), + this->get_arg(args, 0), + this->get_arg(args, 1), + this->get_arg(args, 2), + this->get_arg(args, 3), + this->get_arg(args, 4)); + return args[5]; } }; diff --git a/src/targets/gpu/include/migraphx/gpu/pack_args.hpp b/src/targets/gpu/include/migraphx/gpu/pack_args.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3f4a24b9ec49212cde949fdea7bc236f4f26e1e --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/pack_args.hpp @@ -0,0 +1,32 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_ARGS_HPP +#define MIGRAPHX_GUARD_RTGLIB_PACK_ARGS_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct kernel_argument +{ + template , + MIGRAPHX_REQUIRES(not std::is_base_of{})> + kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(&x) // NOLINT + { + } + std::size_t size; + std::size_t align; + void* data; +}; + +std::vector pack_args(const std::vector& args); + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/pack_int8_args.hpp b/src/targets/gpu/include/migraphx/gpu/pack_int8_args.hpp index 69dd3f83660f9b5e461bc00ffd9a31951c10c866..68e611bc736d7e077c619e6278fa3ea48c7b3c4e 100644 --- a/src/targets/gpu/include/migraphx/gpu/pack_int8_args.hpp +++ b/src/targets/gpu/include/migraphx/gpu/pack_int8_args.hpp @@ -13,7 +13,7 @@ namespace gpu { struct pack_int8_args { std::string name() const { return "gpu::pack_int8_args"; } - void apply(program& p) const; + void apply(module& m) const; shape pack_int8_shape(const shape& s) const; }; diff --git a/src/targets/gpu/include/migraphx/gpu/pad.hpp b/src/targets/gpu/include/migraphx/gpu/pad.hpp index e54532c38a4ff309e1cec808ffa510d1f86c2696..b5b3362cca1881009f9f11ede328f43006aafd7b 100644 --- a/src/targets/gpu/include/migraphx/gpu/pad.hpp +++ b/src/targets/gpu/include/migraphx/gpu/pad.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP #define MIGRAPHX_GUARD_RTGLIB_PAD_HPP -#include +#include +#include #include namespace migraphx { diff --git a/src/targets/gpu/include/migraphx/gpu/pooling.hpp b/src/targets/gpu/include/migraphx/gpu/pooling.hpp index a6f4541ad69b1febca48e26579fed651a7a958c2..8d409f7efb8d8c8ef49265d37eddfad50edba2ea 100644 --- a/src/targets/gpu/include/migraphx/gpu/pooling.hpp +++ b/src/targets/gpu/include/migraphx/gpu/pooling.hpp @@ -1,7 +1,8 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_POOLING_HPP #define MIGRAPHX_GUARD_RTGLIB_POOLING_HPP -#include +#include +#include #include #include @@ -24,6 +25,7 @@ struct miopen_pooling std::string name() const { return "gpu::pooling"; } shape compute_shape(const std::vector& inputs) const; + void finalize(context&, const shape&, const std::vector&); argument compute(context& ctx, const shape& output_shape, const std::vector& args) const; std::ptrdiff_t output_alias(const std::vector& shapes) const diff --git a/src/targets/gpu/include/migraphx/gpu/preallocate_param.hpp b/src/targets/gpu/include/migraphx/gpu/preallocate_param.hpp deleted file mode 100644 index 3e1fc64a8999b8c33b627cf0ae039827946b7236..0000000000000000000000000000000000000000 --- a/src/targets/gpu/include/migraphx/gpu/preallocate_param.hpp +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP -#define MIGRAPHX_GUARD_RTGLIB_GPU_PREALLOCATE_PARAM_HPP - -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -struct program; - -namespace gpu { - -struct preallocate_param -{ - std::string param{}; - context* ctx = nullptr; - std::string name() const { return "preallocate_param"; } - void apply(program& p) const; -}; -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif diff --git a/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp b/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dbde086b08b02425541745a8f072d683d069fdc0 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp @@ -0,0 +1,57 @@ +#ifndef MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP +#define MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_prefix_scan_sum : oper +{ + op::prefix_scan_sum op; + + template + static auto reflect(Self& self, T f) + { + return migraphx::reflect(self.op, f); + } + + shape compute_shape(const std::vector& inputs) const + { + std::vector in_shapes{inputs}; + in_shapes.pop_back(); + check_shapes{in_shapes, *this}.standard(); + return op.normalize_compute_shape(in_shapes); + } + + argument compute(context& ctx, const shape&, const std::vector& args) const + { + device::prefix_scan_sum( + ctx.get_stream().get(), args[1], args[0], op.axis, op.exclusive, op.reverse); + return args[1]; + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a24e9e095af5799df95a453a39e1c8fc748c6ef5 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP +#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +namespace gpu { + +struct prefuse_ops +{ + std::string name() const { return "gpu::prefuse_ops"; } + void apply(module& m) const; +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/prelu.hpp b/src/targets/gpu/include/migraphx/gpu/prelu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fbf75ad24a7faf0b16a90f7bad422baec0c0d1f --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/prelu.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_PRELU_HPP +#define MIGRAPHX_GUARD_RTGLIB_PRELU_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_prelu : binary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp b/src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp index 58ece093ab1d7046afefab9685519f664d719d4e..fb2dd542c2fdc4f1ca03f8b8ebd496ed97d67fe2 100644 --- a/src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp +++ b/src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp @@ -2,6 +2,7 @@ #define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP #include +#include #include #include @@ -14,6 +15,7 @@ struct context; struct miopen_quant_convolution { op::quant_convolution op; + bool int8_x4_format = false; shared cd; miopenConvFwdAlgorithm_t algo{}; miopenHandle_t handle = nullptr; @@ -22,7 +24,8 @@ struct miopen_quant_convolution static auto reflect(Self& self, F f) { // TODO: Add algo - return op::quant_convolution::reflect(self.op, f); + return pack_join(migraphx::reflect(self.op, f), + pack(f(self.int8_x4_format, "int8_x4_format"))); } std::string name() const { return "gpu::quant_convolution"; } diff --git a/src/targets/gpu/include/migraphx/gpu/recip.hpp b/src/targets/gpu/include/migraphx/gpu/recip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2be4882fa452ac19b258ccdf7667e409e224d793 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/recip.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_RECIP_HPP +#define MIGRAPHX_GUARD_RTGLIB_RECIP_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_recip : unary_device +{ +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp b/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp index 682503bbe2ff6ef7134f4ec4b02b5678c8168399..e212a770e4d56eeedc8d3aa31b31dce35d93c49b 100644 --- a/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/reduce_op.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,8 @@ struct reduce_op : oper { std::vector in_shapes{inputs}; in_shapes.pop_back(); - return op.compute_shape(in_shapes); + check_shapes{in_shapes, *this}.standard(); + return op.normalize_compute_shape(in_shapes); } argument compute(context& ctx, const shape&, const std::vector& args) const diff --git a/src/targets/gpu/include/migraphx/gpu/reduce_prod.hpp b/src/targets/gpu/include/migraphx/gpu/reduce_prod.hpp new file mode 100644 index 0000000000000000000000000000000000000000..448e13db3dfb91af90a1a9fd27b13090fd746de5 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/reduce_prod.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_PROD_HPP +#define MIGRAPHX_GUARD_RTGLIB_REDUCE_PROD_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_reduce_prod : reduce_op +{ + hip_reduce_prod() {} + hip_reduce_prod(const op::reduce_prod& op_ref) : reduce_op(op_ref) {} +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/reverse.hpp b/src/targets/gpu/include/migraphx/gpu/reverse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c0cdee4d516cba7e84a104057582951b2a3a49f --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/reverse.hpp @@ -0,0 +1,39 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP +#define MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_reverse +{ + op::reverse op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::reverse"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp b/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ce8115fc55fd95a21def2b06f887ca25bb74c538 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/rnn_variable_seq_lens.hpp @@ -0,0 +1,78 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP +#define MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_rnn_var_sl_shift_sequence +{ + op::rnn_var_sl_shift_sequence op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::rnn_var_sl_shift_sequence"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +struct hip_rnn_var_sl_shift_output +{ + op::rnn_var_sl_shift_output op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::rnn_var_sl_shift_output"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +struct hip_rnn_var_sl_last_output +{ + op::rnn_var_sl_last_output op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::" + op.name(); } + shape compute_shape(std::vector inputs) const; + argument compute(context& ctx, const shape&, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/scatter.hpp b/src/targets/gpu/include/migraphx/gpu/scatter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..379d4ebfd4895f57d462ddde0022965491088c8e --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/scatter.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP +#define MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_scatter +{ + // scatter_none is an exact replacement for previous op::scatter, + // renamed to match an Onnx option. Don't use base class op::scatter + op::scatter_none op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::scatter_none"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/schedule_model.hpp b/src/targets/gpu/include/migraphx/gpu/schedule_model.hpp index 3dac47c9ba0cca83aa8792562bded1f66c8ecfb3..0540e6c918cc8b07c758000ec24483af87ba68fc 100644 --- a/src/targets/gpu/include/migraphx/gpu/schedule_model.hpp +++ b/src/targets/gpu/include/migraphx/gpu/schedule_model.hpp @@ -8,7 +8,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; struct operation; namespace gpu { @@ -17,9 +17,9 @@ struct schedule_model { std::size_t streams = 0; std::size_t concurrency() const; - void sched(program& p, instruction_ref ins, std::size_t n) const; - void wait(program& p, instruction_ref ins, std::size_t wait_id) const; - void record(program& p, instruction_ref ins, std::size_t wait_id) const; + void sched(module& m, instruction_ref ins, std::size_t n) const; + void wait(module& m, instruction_ref ins, std::size_t wait_id) const; + void record(module& m, instruction_ref ins, std::size_t wait_id) const; std::size_t weight(const operation& op) const; }; diff --git a/src/targets/gpu/include/migraphx/gpu/sync_device.hpp b/src/targets/gpu/include/migraphx/gpu/sync_device.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54ead79d159464967ded0bfc51f4cd5dd4923887 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/sync_device.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP +#define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +struct module; + +namespace gpu { + +struct sync_device +{ + std::string name() const { return "sync_device"; } + void apply(module& m) const; +}; +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/topk.hpp b/src/targets/gpu/include/migraphx/gpu/topk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4782ec8b813934641a9356b549079e82c6448dc7 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/topk.hpp @@ -0,0 +1,39 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_TOPK_HPP +#define MIGRAPHX_GUARD_RTGLIB_TOPK_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct context; + +struct hip_topk +{ + op::topk op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "gpu::topk"; } + shape compute_shape(std::vector inputs) const; + argument + compute(context& ctx, const shape& output_shape, const std::vector& args) const; + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/unary_not.hpp b/src/targets/gpu/include/migraphx/gpu/unary_not.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ca31385d993eec158f0bfd2cdccd9723c085e902 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/unary_not.hpp @@ -0,0 +1,20 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_UNARY_NOT_HPP +#define MIGRAPHX_GUARD_RTGLIB_UNARY_NOT_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_unary_not : unary_device +{ + std::string name() const { return "gpu::not"; } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/where.hpp b/src/targets/gpu/include/migraphx/gpu/where.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1847aa34b2868a48f99980ba281b5789c781aeb6 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/where.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP +#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct hip_where : ternary_device +{ + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(4).same_dims(); + auto s1 = inputs.at(1); + auto s2 = inputs.at(2); + if(s1 == s2 and s1.packed()) + { + return s1; + } + else if(s1.packed() != s2.packed()) + { + return s1.packed() ? s1 : s2; + } + else if(s1.broadcasted() != s2.broadcasted()) + { + return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens()); + } + else + { + return {s1.type(), s1.lens()}; + } + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/include/migraphx/gpu/write_literals.hpp b/src/targets/gpu/include/migraphx/gpu/write_literals.hpp index 79f71350abb9578f1ea00d51bc680f52dc11dad5..0c7444e49a5e1944180b85f4daeb793e2ecd2a85 100644 --- a/src/targets/gpu/include/migraphx/gpu/write_literals.hpp +++ b/src/targets/gpu/include/migraphx/gpu/write_literals.hpp @@ -1,11 +1,11 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP #define MIGRAPHX_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP -#include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +struct module; namespace gpu { @@ -14,7 +14,7 @@ struct write_literals context* ctx = nullptr; std::string name() const { return "gpu::write_literals"; } - void apply(program& p) const; + void apply(module& m) const; }; } // namespace gpu diff --git a/src/targets/gpu/int8_conv_pack.cpp b/src/targets/gpu/int8_conv_pack.cpp old mode 100644 new mode 100755 index 4e9a933463c0fa932f20a130a3ce9b41008ec3fd..4c54ab0c745b1e8fde788de448fddecc9f20ff10 --- a/src/targets/gpu/int8_conv_pack.cpp +++ b/src/targets/gpu/int8_conv_pack.cpp @@ -5,10 +5,25 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +shape pack_int8_shape(const shape& s) +{ + if(s.type() != shape::int8_type) + { + MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type"); + } + + auto lens = s.lens(); + auto strides = s.strides(); + lens[1] = (lens[1] + 3) / 4 * 4; + strides[0] = strides[1] * lens[1]; + + return {s.type(), lens, strides}; +} + shape miopen_int8_conv_pack::compute_shape(const std::vector& inputs) const { check_shapes{{inputs.at(0)}, *this}.has(1).standard(); - return inputs.at(0); + return pack_int8_shape(inputs.at(0)); } argument diff --git a/src/targets/gpu/jit/gathernd.cpp b/src/targets/gpu/jit/gathernd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..68ceeed133471b887cd189364b00e5a430f9ea3e --- /dev/null +++ b/src/targets/gpu/jit/gathernd.cpp @@ -0,0 +1,75 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +// NOLINTNEXTLINE +static const char* const gathernd_kernel = R"__migraphx__( +#include +#include +#include +#include +#include + +namespace migraphx { + +extern "C" { + +__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output) +{ + make_tensors()(in_data, in_indices, output)([](auto&&... xs) { + auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS})); + gathernd(xs..., settings); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct gathernd_compiler : compiler +{ + std::vector names() const { return {"gathernd"}; } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + auto out_s = inputs.back(); + options.set_launch_params(v, compute_global_for(ctx, out_s.elements())); + options.inputs = inputs; + options.output = out_s; + options.kernel_name = "gathernd_kernel"; + options.virtual_inputs = inputs; + + // batch_dims + assert(v.contains("batch_dims")); + auto batch_dims = v.at("batch_dims").to(); + options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims); + + return compile_hip_code_object(gathernd_kernel, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/jit/pointwise.cpp b/src/targets/gpu/jit/pointwise.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffe4453ccdd1dea4eb09f84530cdb80c87213b1a --- /dev/null +++ b/src/targets/gpu/jit/pointwise.cpp @@ -0,0 +1,126 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +using namespace migraphx::gpu::gen; // NOLINT + +static const char* const pointwise_kernel = R"__migraphx__( +#include +#include +#include + +namespace migraphx { + +${preamble} + +extern "C" { +__global__ void ${kernel}(${params}) +{ + auto idx = make_index(); + pointwise(idx, ${transformers})(${lambda}, ${args}); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +static std::vector get_op_names(const module& m) +{ + std::vector result; + for(auto& ins : m) + { + if(starts_with(ins.name(), "@")) + continue; + result.push_back(ins.name()); + } + return result; +} + +struct pointwise_compiler : compiler +{ + std::vector names() const { return {"pointwise"}; } + + static std::size_t oversubscribe_if(bool b) + { + if(b) + return 256; + else + return 1; + } + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + options.inputs = inputs; + options.output = inputs.back(); + options.virtual_inputs = reduce_dims(inputs); + options.params = "-Wno-float-equal"; + auto axis = find_fast_axis(options.virtual_inputs); + auto vec = vectorize::elements(axis, options.virtual_inputs); + auto preloads = preload::broadcasts(axis, options.virtual_inputs); + options.kernel_name = v.get("kernel", "kernel"); + options.set_launch_params( + v, + compute_global_for(ctx, + options.output.elements() / vec.size, + oversubscribe_if(not preloads.is_preloading()))); + auto src = interpolate_string(pointwise_kernel, + {{"kernel", options.kernel_name}, + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(preloads, vec)}, + {"preamble", v.get("preamble", std::string{})}}); + return compile_hip_code_object(src, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const + { + assert(not ins->module_inputs().empty()); + auto* pm = ins->module_inputs().front(); + run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}}); + cpp_generator g; + g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); + g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); + g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); + g.add_point_op("sign", + "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"); + g.add_point_op("equal", "migraphx::abs(${0} == ${1})"); + g.add_point_op("less", "migraphx::abs(${0} < ${1})"); + g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); + g.add_point_op("not", "migraphx::abs(not ${0})"); + // Add explict conversions + g.fresult( + [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; }); + auto name = g.create_function( + g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm)); + std::string lambda = "MIGRAPHX_LIFT(" + name + ")"; + auto op_names = get_op_names(*pm); + op_names.push_back("kernel"); + auto op_name_string = join_strings(op_names, "_"); + return replace( + compile_op(ctx, + to_shapes(ins->inputs()), + {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}})); + } +}; +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c628500f1b5932e7eab6f3e0d7d7f6c8935f7c6 --- /dev/null +++ b/src/targets/gpu/jit/reduce.cpp @@ -0,0 +1,179 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +using namespace migraphx::gpu::gen; // NOLINT + +static const char* const simple_reduce_kernel = R"__migraphx__( +#include +#include +#include +#include + +namespace migraphx { + +${preamble} + +extern "C" { +__global__ void reduce_kernel(void* input_p, void* output_p) +{ + + transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) { + + simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write}); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +static std::size_t get_reduce_elements(const std::vector& inputs) +{ + return inputs.front().elements() / inputs.back().elements(); +} +static std::size_t get_reduce_elements(const std::vector& inputs) +{ + return get_reduce_elements(to_shapes(inputs)); +} + +static std::vector get_reduce_lens(const std::vector& input_lens, + const std::vector& output_lens) +{ + std::vector reduce_lens; + std::transform(output_lens.begin(), + output_lens.end(), + input_lens.begin(), + std::back_inserter(reduce_lens), + [](auto x, auto y) -> std::size_t { + if(x == y) + return 1; + else + return y; + }); + return reduce_lens; +} + +static std::string get_reduce_algo(const std::vector& inputs) +{ + auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens()); + const auto init = std::numeric_limits::max(); + // The minimum stride + auto min_stride = std::inner_product( + rlens.begin(), + rlens.end(), + inputs.front().strides().begin(), + init, + [](auto x, auto y) { return std::min(x, y); }, + [](auto len, auto stride) { return len == 1 ? init : stride; }); + if(min_stride > 2) + return "lane"; + return "block"; +} + +struct reduce_compiler : compiler +{ + std::vector names() const + { + return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"}; + } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + options.inputs = inputs; + options.output = inputs.back(); + options.virtual_inputs = reduce_dims(inputs); + auto faxis = find_fast_axis({options.virtual_inputs.front()}); + vectorize vec{}; + // Vectorize if the axis is a reduction axis + if(options.virtual_inputs.back().lens()[faxis] == 1) + { + vec = vectorize::elements(faxis, options.virtual_inputs); + } + auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; + auto nelements = options.virtual_inputs.back().elements(); + auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); + if(algo == "block") + { + auto block_size = compute_block_size(relements, 256); + options.set_launch_params( + v, compute_global_for(ctx, nelements * block_size, 256), block_size); + } + else if(algo == "lane") + { + options.set_launch_params(v, compute_global_for(ctx, nelements, 256)); + } + else + { + MIGRAPHX_THROW("Unknown reduce algo: " + algo); + } + options.kernel_name = "reduce_kernel"; + std::string identity = "[](auto x) { return x; }"; + auto src = interpolate_string(simple_reduce_kernel, + {{"reduction", v.at("reduction").to()}, + {"init", v.get("init", std::string{"0"})}, + {"read", v.get("read", identity)}, + {"write", v.get("write", identity)}, + {"algo", algo}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); + options.params += "-Wno-float-equal"; + return compile_hip_code_object(src, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + value v = value::object{}; + auto reduce_elements = get_reduce_elements(ins->inputs()); + if(op.name() == "reduce_sum") + { + v["reduction"] = "op::sum{}"; + } + else if(op.name() == "reduce_mean") + { + v["reduction"] = "op::sum{}"; + v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}"; + } + else if(op.name() == "reduce_max") + { + v["reduction"] = "op::max{}"; + v["init"] = "lowest{}"; + } + else if(op.name() == "reduce_min") + { + v["reduction"] = "op::min{}"; + v["init"] = "highest{}"; + } + else if(op.name() == "reduce_prod") + { + v["reduction"] = "op::product{}"; + v["init"] = "1"; + } + else + { + MIGRAPHX_THROW("Unsupported reduce"); + } + return replace(compile_op(ctx, to_shapes(ins->inputs()), v)); + } +}; +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/jit/roialign.cpp b/src/targets/gpu/jit/roialign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db60465b7caeb01ba12f33457976d4cd32035dff --- /dev/null +++ b/src/targets/gpu/jit/roialign.cpp @@ -0,0 +1,87 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +// NOLINTNEXTLINE +static const char* const roialign_kernel = R"__migraphx__( +#include +#include +#include +#include + +namespace migraphx { + +extern "C" { + +__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y) +{ + make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) { + auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}), + _c, + _c, + MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE})); + roialign(xs..., settings); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct roialign_compiler : compiler +{ + std::vector names() const { return {"roialign"}; } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 128); + options.output = inputs.back(); + options.inputs = inputs; + options.kernel_name = "roialign_kernel"; + + // sampling_ratio + options.params += " -DSAMPLING_RATIO=" + v.at("sampling_ratio").to(); + + // pooling_mode + auto mode = v.at("mode").to(); + std::string is_avg_pooling = + (mode == migraphx::op::pooling_mode::average) ? "true" : "false"; + options.params += " -DIS_AVG_POOLING=" + is_avg_pooling; + + // coord_trans_mode + auto ctm = v.at("coordinate_transformation_mode").to(); + float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f; + options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset); + + // spatial_scale + options.params += " -DSPATIAL_SCALE=" + v.at("spatial_scale").to(); + + return compile_hip_code_object(roialign_kernel, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/jit/scatternd.cpp b/src/targets/gpu/jit/scatternd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc34656ca84afaef7a67c7ea92a5ca9aaa43d1b5 --- /dev/null +++ b/src/targets/gpu/jit/scatternd.cpp @@ -0,0 +1,86 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +// NOLINTNEXTLINE +static const char* const scatternd_kernel = R"__migraphx__( +#include +#include +#include +#include + +namespace migraphx { + +extern "C" { + +__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output) +{ + make_tensors()(in_indices, in_updates, output)([](auto&&... xs) { + scatternd(xs..., ${reduction}{}); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct scatternd_compiler : compiler +{ + std::vector names() const + { + return {"scatternd_none", "scatternd_add", "scatternd_mul"}; + } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + hip_compile_options options; + options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements())); + options.inputs = inputs; + options.output = inputs.back(); + options.kernel_name = "scatternd_kernel"; + options.virtual_inputs = inputs; + auto reduction = "assign_" + v.get("reduction", std::string{"none"}); + auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}}); + return compile_hip_code_object(src, options); + } + + compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const + { + assert(starts_with(op.name(), "scatternd_")); + auto reduction = op.name().substr(10); + return insert(compile_op(ctx, + to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}), + {{"reduction", reduction}})); + } + + compiler_replace insert(const operation& op) const + { + return [=](module& m, instruction_ref ins) { + auto args = ins->inputs(); + args.back() = + m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back()); + args.erase(args.begin()); + return m.replace_instruction(ins, op, args); + }; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/kernel.cpp b/src/targets/gpu/kernel.cpp new file mode 100755 index 0000000000000000000000000000000000000000..582bd19b93e2f71fd8811da5164de4bcf772f1e8 --- /dev/null +++ b/src/targets/gpu/kernel.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include + +// extern declare the function since hip/hip_ext.h header is broken +extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + size_t, + hipStream_t, + void**, + void**, + hipEvent_t = nullptr, + hipEvent_t = nullptr, + uint32_t = 0); + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +extern std::string hip_error(int error); + +using hip_module_ptr = MIGRAPHX_MANAGE_PTR(hipModule_t, hipModuleUnload); + +struct kernel_impl +{ + hip_module_ptr module = nullptr; + hipFunction_t fun = nullptr; +}; + +hip_module_ptr load_module(const char* image) +{ + hipModule_t raw_m; + auto status = hipModuleLoadData(&raw_m, image); + hip_module_ptr m{raw_m}; + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to load module: " + hip_error(status)); + return m; +} + +kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared()) +{ + impl->module = load_module(image); + auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str()); + if(hipSuccess != status) + MIGRAPHX_THROW("Failed to get function: " + name + ": " + hip_error(status)); +} + +void launch_kernel(hipFunction_t fun, + hipStream_t stream, + std::size_t global, + std::size_t local, + void* kernargs, + std::size_t size) +{ + assert(global > 0); + assert(local > 0); + void* config[] = { +// HIP_LAUNCH_PARAM_* are macros that do horrible things +#ifdef MIGRAPHX_USE_CLANG_TIDY + nullptr, kernargs, nullptr, &size, nullptr +#else + HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END +#endif + }; + + auto status = hipExtModuleLaunchKernel( + fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast(&config)); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const +{ + assert(impl != nullptr); + void* kernargs = args.data(); + std::size_t size = args.size() * sizeof(void*); + + launch_kernel(impl->fun, stream, global, local, kernargs, size); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b702c05612a77a3f8fea0df89bcfe0ed34d1965 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp @@ -0,0 +1,150 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ALGORITHM_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ALGORITHM_HPP + +namespace migraphx { + +struct less +{ + template + constexpr auto operator()(T x, U y) const + { + return x < y; + } +}; + +struct greater +{ + template + constexpr auto operator()(T x, U y) const + { + return x > y; + } +}; + +template +constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op) +{ + for(; first != last; ++first) + { + init = op(std::move(init), *first); + } + return init; +} + +template +constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) +{ + while(first != last) + { + *d_first++ = *first++; + } + return d_first; +} + +template +constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) +{ + if(first != last) + { + Iterator next = first; + while(++next != last) + { + if(comp(*next, *first)) + return next; + first = next; + } + } + return last; +} + +template +constexpr bool is_sorted(Iterator first, Iterator last, Compare comp) +{ + return is_sorted_until(first, last, comp) == last; +} + +template +constexpr F for_each(Iterator first, Iterator last, F f) +{ + for(; first != last; ++first) + { + f(*first); + } + return f; +} + +template +constexpr Iterator find_if(Iterator first, Iterator last, Predicate p) +{ + for(; first != last; ++first) + { + if(p(*first)) + { + return first; + } + } + return last; +} + +template +constexpr Iterator find(Iterator first, Iterator last, const T& value) +{ + return find_if(first, last, [&](const auto& x) { return x == value; }); +} + +template +constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last) +{ + for(;; ++first) + { + Iterator1 it = first; + for(Iterator2 s_it = s_first;; ++it, ++s_it) + { + if(s_it == s_last) + { + return first; + } + if(it == last) + { + return last; + } + if(!(*it == *s_it)) + { + break; + } + } + } +} + +template +constexpr T inner_product(InputIt1 first1, + InputIt1 last1, + InputIt2 first2, + T init, + BinaryOperation1 op1, + BinaryOperation2 op2) +{ + while(first1 != last1) + { + init = op1(init, op2(*first1, *first2)); + ++first1; + ++first2; + } + return init; +} + +template +constexpr T inner_product(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init) +{ + return inner_product( + first1, + last1, + first2, + init, + [](auto x, auto y) { return x + y; }, + [](auto x, auto y) { return x * y; }); +} + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/args.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/args.hpp new file mode 100755 index 0000000000000000000000000000000000000000..7586425a14369934ec5e423e0410d66135c5b0d2 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/args.hpp @@ -0,0 +1,27 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_ARGS_HPP +#define MIGRAPHX_GUARD_KERNELS_ARGS_HPP + +#include +#include + +namespace migraphx { + +// Use template specialization since ADL is broken on hcc +template +struct make_tensor; + +template +__device__ auto make_tensors_impl(F f, detail::seq, Ts*... xs) +{ + return f(make_tensor::apply(xs)...); +} + +inline __device__ auto make_tensors() +{ + return [](auto*... xs) { + return [=](auto f) { return make_tensors_impl(f, detail::gens{}, xs...); }; + }; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..39861988791624342202db529b3efe7abb502237 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/array.hpp @@ -0,0 +1,204 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP + +#include +#include +#include +#include + +namespace migraphx { + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \ + template \ + constexpr array& operator op(const array& x) \ + { \ + for(index_int i = 0; i < N; i++) \ + d[i] op x[i]; \ + return *this; \ + } \ + template {})> \ + constexpr array& operator op(const U& x) \ + { \ + for(index_int i = 0; i < N; i++) \ + d[i] op x; \ + return *this; \ + } \ + template \ + friend constexpr auto operator binary_op(const array& x, const array& y) \ + { \ + array z{}; \ + for(index_int i = 0; i < N; i++) \ + z[i] = x[i] binary_op y[i]; \ + return z; \ + } \ + template {})> \ + friend constexpr auto operator binary_op(const array& x, const U& y) \ + { \ + array z{}; \ + for(index_int i = 0; i < N; i++) \ + z[i] = x[i] binary_op y; \ + return z; \ + } \ + template {})> \ + friend constexpr auto operator binary_op(const U& x, const array& y) \ + { \ + array z{}; \ + for(index_int i = 0; i < N; i++) \ + z[i] = x binary_op y[i]; \ + return z; \ + } + +template +struct array +{ + T d[N]; + constexpr T& operator[](index_int i) + { + MIGRAPHX_ASSERT(i < N); + return d[i]; + } + constexpr const T& operator[](index_int i) const + { + MIGRAPHX_ASSERT(i < N); + return d[i]; + } + + constexpr T& front() { return d[0]; } + constexpr const T& front() const { return d[0]; } + + constexpr T& back() { return d[N - 1]; } + constexpr const T& back() const { return d[N - 1]; } + + constexpr T* data() { return d; } + constexpr const T* data() const { return d; } + + constexpr index_constant size() const { return {}; } + constexpr auto empty() const { return size() == _c<0>; } + + constexpr T* begin() { return d; } + constexpr const T* begin() const { return d; } + + constexpr T* end() { return d + size(); } + constexpr const T* end() const { return d + size(); } + + constexpr T dot(const array& x) const + { + T result = 0; + for(index_int i = 0; i < N; i++) + result += x[i] * d[i]; + return result; + } + + constexpr T product() const + { + T result = 1; + for(index_int i = 0; i < N; i++) + result *= d[i]; + return result; + } + + constexpr T single(index_int width = 100) const + { + T result = 0; + T a = 1; + for(index_int i = 0; i < N; i++) + { + result += d[N - i - 1] * a; + a *= width; + } + return result; + } + + MIGRAPHX_DEVICE_ARRAY_OP(+=, +) + MIGRAPHX_DEVICE_ARRAY_OP(-=, -) + MIGRAPHX_DEVICE_ARRAY_OP(*=, *) + MIGRAPHX_DEVICE_ARRAY_OP(/=, /) + MIGRAPHX_DEVICE_ARRAY_OP(%=, %) + MIGRAPHX_DEVICE_ARRAY_OP(&=, &) + MIGRAPHX_DEVICE_ARRAY_OP(|=, |) + MIGRAPHX_DEVICE_ARRAY_OP(^=, ^) + + friend constexpr bool operator==(const array& x, const array& y) + { + for(index_int i = 0; i < N; i++) + { + if(x[i] != y[i]) + return false; + } + return true; + } + + friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); } + // This uses the product order rather than lexical order + friend constexpr bool operator<(const array& x, const array& y) + { + for(index_int i = 0; i < N; i++) + { + if(not(x[i] < y[i])) + return false; + } + return true; + } + friend constexpr bool operator>(const array& x, const array& y) { return y < x; } + friend constexpr bool operator<=(const array& x, const array& y) { return (x < y) or (x == y); } + friend constexpr bool operator>=(const array& x, const array& y) { return (y < x) or (x == y); } + + constexpr array carry(array result) const + { + index_int overflow = 0; + for(diff_int i = result.size() - 1; i > 0; i--) + { + auto z = result[i] + overflow; + // Reset overflow + overflow = 0; + // Compute overflow using while loop instead of mod + while(z >= d[i]) + { + z -= d[i]; + overflow += 1; + } + result[i] = z; + } + result[0] += overflow; + return result; + } + + template + friend constexpr const Stream& operator<<(const Stream& ss, const array& a) + { + for(index_int i = 0; i < N; i++) + { + if(i > 0) + ss << ", "; + ss << a[i]; + } + return ss; + } +}; + +template +struct integral_const_array : array +{ + using base_array = array; + MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {} +}; + +template +constexpr auto transform(integral_const_array, F f) +{ + return integral_const_array{}; +} + +template +constexpr auto transform(integral_const_array, integral_const_array, F f) +{ + return integral_const_array{}; +} + +template +using index_ints = integral_const_array; + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3efc070ff338ace592353ce911539bb2349a1e7f --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp @@ -0,0 +1,159 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP +#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP + +#include + +namespace migraphx { + +#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__ +#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__) + +// Workaround hip's broken abort on device code +#ifdef __HIP_DEVICE_COMPILE__ +// NOLINTNEXTLINE +#define MIGRAPHX_HIP_NORETURN +#else +// NOLINTNEXTLINE +#define MIGRAPHX_HIP_NORETURN [[noreturn]] +#endif + +namespace debug { +struct swallow +{ + template + constexpr swallow(Ts&&...) + { + } +}; + +template +struct print_buffer +{ + char buffer[N + 1] = {0}; + char* pos = buffer; + + constexpr void append(char c) + { + if(c == 0) + return; + if(pos < buffer + N) + { + *pos = c; + pos++; + } + } + template + constexpr void append(T i) + { + if(i < 0) + { + append('-'); + i = -i; + } + char c = (i % 10) + '0'; + if(i > 9) + append(i / 10); + append(c); + } + + constexpr void append(const char* str) + { + if(str == nullptr) + return; + int i = 512; + while(*str != 0 and i > 0) + { + append(*str); + str++; + i--; + } + } + + template + constexpr void append(const char (&array)[M]) + { + for(int i = 0; i < M; i++) + append(array[i]); + } +}; + +template +__host__ __device__ void print(const Ts&... xs) +{ + print_buffer<1024> buffer; + swallow{(buffer.append(xs), 0)...}; + printf("%s", buffer.buffer); +} + +} // namespace debug + +struct source_location +{ + int line = __builtin_LINE(); + const char* file = __builtin_FILE(); + const char* function = __builtin_FUNCTION(); +}; + +template +struct source_location_capture +{ + T x; + source_location loc; + template + constexpr source_location_capture(U px, source_location ploc = source_location{}) + : x(px), loc(ploc) + { + } + + constexpr operator source_location() const { return loc; } + + constexpr operator T() const { return x; } +}; + +// noreturn cannot be used on this function because abort in hip is broken +template +MIGRAPHX_HIP_NORETURN inline __host__ __device__ void +assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function) +{ + // printf is broken on hip with more than one argument, so use a simple print functions instead + debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n"); + // printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion); + abort(); +} + +template +MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc, + Ts... xs) +{ + debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n"); + abort(); +} + +// NOLINTNEXTLINE +#define MIGRAPHX_ASSERT_FAIL(cond, ...) \ + ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \ + assert_fail(private_migraphx_xs...); \ + }(__VA_ARGS__)) + +// NOLINTNEXTLINE +#define MIGRAPHX_CHECK(cond) \ + MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__) + +#ifdef MIGRAPHX_DEBUG +// NOLINTNEXTLINE +#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture +#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__) +#define MIGRAPHX_ASSERT MIGRAPHX_CHECK +#define MIGRAPHX_ASSUME MIGRAPHX_CHECK +#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false) +#else +// NOLINTNEXTLINE +#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T +#define MIGRAPHX_ASSUME __builtin_assume +#define MIGRAPHX_UNREACHABLE __builtin_unreachable +#define MIGRAPHX_ASSERT(cond) +#define MIGRAPHX_WARN(...) +#endif + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58c6121bc636f46ec086fea29e3c85ef96004ca2 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/dfor.hpp @@ -0,0 +1,25 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP + +namespace migraphx { + +// Multidimensional for loop +inline constexpr auto dfor() +{ + return [](auto f) { f(); }; +} + +template +constexpr auto dfor(T x, Ts... xs) +{ + return [=](auto f) { + for(T i = 0; i < x; i++) + { + dfor(xs...)([&](Ts... is) { f(i, is...); }); + } + }; +} + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3d9ea7a29f9a6052e9a3ed5cec79f7bae31a4af --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp @@ -0,0 +1,55 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_DPP_HPP +#define MIGRAPHX_GUARD_KERNELS_DPP_HPP + +#include +#include +#include + +namespace migraphx { + +#ifndef MIGRAPHX_HAS_DPP +#define MIGRAPHX_HAS_DPP 1 +#endif + +#if MIGRAPHX_HAS_DPP +constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; } + +constexpr unsigned int dpp_row_bcast(unsigned int x) +{ + unsigned int y = 0; + switch(x) + { + case 15: y = 0x142; break; + case 31: y = 0x143; break; + default: MIGRAPHX_UNREACHABLE(); + } + return y; +} + +template +__device__ T dpp_mov(T& x) +{ + static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4; + union type + { + uint32_t reg[n]; + T data; + }; + type output{}; + type input{}; + // cppcheck-suppress unreadVariable + input.data = x; + for(index_int i = 0; i < n; i++) + { + output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); + } + return output.data; +} +#endif + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5722ab5c8ad040faec97599d3d31ad55d05f8242 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -0,0 +1,255 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP +#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP + +#include + +// NOLINTNEXTLINE +#define MIGRAPHX_RETURNS(...) \ + ->decltype(__VA_ARGS__) { return __VA_ARGS__; } + +// NOLINTNEXTLINE +#define MIGRAPHX_LIFT(...) \ + [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast(xs)...)) + +namespace migraphx { + +struct swallow +{ + template + constexpr swallow(Ts&&...) + { + } +}; + +template +using ignore = swallow; + +template +struct overloaded : Fs... +{ + using Fs::operator()...; + overloaded(Fs... fs) : Fs(fs)... {} +}; + +template +overloaded overload(Fs... fs) +{ + return {fs...}; +} + +namespace detail { + +template +struct eval_helper +{ + R result; + + template + constexpr eval_helper(const F& f, Ts&&... xs) : result(f(static_cast(xs)...)) + { + } +}; + +template <> +struct eval_helper +{ + int result; + template + constexpr eval_helper(const F& f, Ts&&... xs) : result((f(static_cast(xs)...), 0)) + { + } +}; + +template +struct seq +{ + using type = seq; +}; + +template +struct merge_seq; + +template +struct merge_seq, seq> : seq +{ +}; + +template +struct gens : merge_seq::type, typename gens::type> +{ +}; + +template <> +struct gens<0> : seq<> +{ +}; +template <> +struct gens<1> : seq<0> +{ +}; + +template +constexpr auto sequence_c_impl(F&& f, seq) +{ + return f(index_constant{}...); +} + +template +constexpr auto args_at(seq) +{ + return [](ignore..., auto x, auto...) { return x; }; +} + +} // namespace detail + +template +constexpr auto always(T x) +{ + return [=](auto&&...) { return x; }; +} + +template +constexpr auto sequence_c(F&& f) +{ + return detail::sequence_c_impl(f, detail::gens{}); +} + +template +constexpr auto sequence(IntegerConstant ic, F&& f) +{ + return sequence_c(f); +} + +template +constexpr auto by(F f, G g) +{ + return [=](auto... xs) { + return detail::eval_helper{g, f(xs)...}.result; + }; +} + +template +constexpr auto by(F f) +{ + return by([=](auto x) { return (f(x), 0); }, always(0)); +} + +template +constexpr void each_args(F f, Ts&&... xs) +{ + swallow{(f(static_cast(xs)), 0)...}; +} + +template +constexpr void each_args(F) +{ +} + +template +constexpr auto fold_impl(F&&, T&& x) +{ + return static_cast(x); +} + +template +constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs) +{ + return fold_impl(f, f(static_cast(x), static_cast(y)), static_cast(xs)...); +} + +template +constexpr auto fold(F f) +{ + return [=](auto&&... xs) { return fold_impl(f, static_cast(xs)...); }; +} + +template +constexpr auto pack(Ts... xs) +{ + return [=](auto f) { return f(xs...); }; +} + +template +constexpr auto join(G g, F f) +{ + return f([=](auto... xs) { return g(xs...); }); +} + +template +constexpr auto join(G g, F f, Fs... fs) +{ + return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); }); +} + +template +constexpr auto pack_compare(Compare compare, P1 p1, P2 p2) +{ + return p1([&](auto... xs) { + return p2([&](auto... ys) { + auto c = [&](auto x, auto y) -> int { + if(compare(x, y)) + return 1; + else if(compare(y, x)) + return -1; + else + return 0; + }; + return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0); + }); + }); +} + +template +constexpr auto arg_c() +{ + return [](auto... xs) { return detail::args_at(detail::gens{})(xs...); }; +} + +template +constexpr auto arg(IntegralConstant ic) +{ + return arg_c(); +} + +template +constexpr auto make_transform(F f) +{ + return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; }; +} + +// An arg transformation takes the arguments and then a function to take the new arguments: +// transform(xs...)([](auto... ys) { ... }) +// The transform_args function takes a list of transformations and continually applies them +template +constexpr auto transform_args(F f) +{ + return f; +} + +template +constexpr auto transform_args(F f, Fs... fs) +{ + return make_transform([=](auto g, auto... xs) { + return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); }); + }); +} + +// identity transform +inline constexpr auto transform_args() +{ + return make_transform([](auto f, auto... xs) { return f(xs...); }); +} + +// Rotate the first argument to the last argument +inline constexpr auto rotate_last() +{ + return make_transform([](auto f, auto... xs) { + return sequence_c([&](auto... is) { + constexpr auto size = sizeof...(is); + return f(arg_c<(is + size - 1) % size>()(xs...)...); + }); + }); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..22d49ac381178bb92888693defcf52bd0d7b03bb --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp @@ -0,0 +1,81 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP +#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP + +#include +#include + +namespace migraphx { + +template +struct gathernd_settings +{ + T batch_dims{}; +}; + +template +constexpr gathernd_settings make_gathernd_settings(Ts... xs) +{ + return {xs...}; +} + +template +__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s) +{ + auto ind = make_index(); + auto batch_dims = s.batch_dims; + auto output_shape = output_t.get_shape(); + auto indices_shape = indices_t.get_shape(); + auto data_shape = data_t.get_shape(); + + auto indices_shape_lens = indices_shape.lens; + auto data_shape_lens = data_shape.lens; + auto num_slice_dims = indices_shape_lens.back(); + std::size_t num_slices = accumulate(indices_shape_lens.begin(), + indices_shape_lens.end() - 1, + 1, + std::multiplies()); + std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + const std::size_t num_batches = accumulate(data_shape_lens.begin(), + data_shape_lens.begin() + batch_dims, + 1, + std::multiplies()); + const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims, + data_shape_lens.end(), + 1, + std::multiplies()); + const auto num_slices_per_batch = num_slices / num_batches; + + ind.global_stride(output_shape.elements(), [&](auto i) { + const auto* indices_ptr = indices_t.data(); + const std::size_t j = i / slice_size; + const std::size_t batch_idx = j / num_slices_per_batch; + + auto* slice_indices = indices_ptr + (j * num_slice_dims); + std::size_t relative_slice_offset = 0; + for(std::size_t idx = 0; idx < num_slice_dims; ++idx) + { + int64_t index = slice_indices[idx]; + const std::size_t input_dim_idx = batch_dims + idx; + const auto input_dim = data_shape_lens[input_dim_idx]; + assert(index >= -static_cast(input_dim) and + index < static_cast(input_dim)); + if(index < 0) + index += input_dim; + std::size_t size_from_slice_dims = + accumulate(data_shape_lens.begin() + batch_dims + idx + 1, + data_shape_lens.begin() + batch_dims + num_slice_dims, + slice_size, + std::multiplies()); + relative_slice_offset += index * size_from_slice_dims; + } + + auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset; + output_t[i] = data_t[slice_offset + i % slice_size]; + }); +} + +} // namespace migraphx +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/generic_constant.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/generic_constant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9b226c936add9906c71a6593dc9a24c5f17f4de --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/generic_constant.hpp @@ -0,0 +1,33 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP +#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP + +namespace migraphx { + +template +struct generic_constant +{ + static constexpr auto value = F{}(); + using value_type = decltype(value); + using type = generic_constant; + constexpr operator value_type() const noexcept { return value; } + constexpr value_type operator()() const noexcept { return value; } +}; + +template +constexpr generic_constant make_generic_constant(F) +{ + return {}; +} + +// NOLINTNEXTLINE +#define MIGRAPHX_MAKE_CONSTANT(x) \ + make_generic_constant([] { \ + struct fun \ + { \ + constexpr auto operator()() const { return x; } \ + }; \ + return fun{}; \ + }()) + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2017faeb0075dea02bfa525e747cc71c8f8af794 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp @@ -0,0 +1,11 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP +#define MIGRAPHX_GUARD_KERNELS_HIP_HPP + +// Workaround macro redefinition issue with clang tidy +#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY) +#undef __HIP_PLATFORM_HCC__ // NOLINT +#endif + +#include + +#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cef0c35ef08f24d7516072ea40eced1ec7c2dec3 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp @@ -0,0 +1,61 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP +#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP + +#include +#include +#include + +namespace migraphx { + +struct index +{ + index_int global = 0; + index_int local = 0; + index_int group = 0; + +#ifdef MIGRAPHX_NGLOBAL + constexpr index_constant nglobal() const { return {}; } +#else + __device__ index_int nglobal() const + { + return blockDim.x * gridDim.x; // NOLINT + } +#endif + +#ifdef MIGRAPHX_NLOCAL + constexpr index_constant nlocal() const { return {}; } +#else + __device__ index_int nlocal() const + { + return blockDim.x; // NOLINT + } +#endif + + template + __device__ void global_stride(index_int n, F f) const + { + const auto stride = nglobal(); + for(index_int i = global; i < n; i += stride) + { + f(i); + } + } + + template + __device__ void local_stride(index_int n, F f) const + { + const auto stride = nlocal(); + for(index_int i = local; i < n; i += stride) + { + f(i); + } + } +}; + +inline __device__ index make_index() +{ + return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_INDEX_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e0e32ca4447fc86c3ad7d9fd3a347719f765c1e4 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp @@ -0,0 +1,80 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP +#define MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP + +#include + +namespace migraphx { + +template +struct integral_constant +{ + static constexpr T value = V; + using value_type = T; + using type = integral_constant; + constexpr operator value_type() const noexcept { return value; } + constexpr value_type operator()() const noexcept { return value; } + static constexpr type to() { return {}; } +}; + +// NOLINTNEXTLINE +#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \ + template \ + constexpr inline integral_constant operator op( \ + integral_constant, integral_constant) noexcept \ + { \ + return {}; \ + } + +// NOLINTNEXTLINE +#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \ + template \ + constexpr inline integral_constant operator op( \ + integral_constant) noexcept \ + { \ + return {}; \ + } + +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(+) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(-) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(*) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(/) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(%) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>>) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<<) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<=) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&) +MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||) + +MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!) +MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~) +MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+) +MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-) + +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; + +template +using index_constant = integral_constant; + +template +static constexpr auto _c = integral_constant{}; // NOLINT + +template +constexpr auto return_c(F f) +{ + return _c; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f7dd46a987e8a7786cc28dddd5d2e4276ae5332 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp @@ -0,0 +1,145 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP +#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP + +#include +#include + +namespace migraphx { + +template +struct basic_iota_iterator +{ + Iterator index; + F f; + + using difference_type = diff_int; + using reference = decltype(f(declval())); + using value_type = remove_reference_t; + using pointer = add_pointer_t; + + constexpr basic_iota_iterator& operator+=(diff_int n) + { + index += n; + return *this; + } + + constexpr basic_iota_iterator& operator-=(diff_int n) + { + index -= n; + return *this; + } + + constexpr basic_iota_iterator& operator++() + { + index++; + return *this; + } + + constexpr basic_iota_iterator& operator--() + { + index--; + return *this; + } + + constexpr basic_iota_iterator operator++(int) // NOLINT + { + basic_iota_iterator it = *this; + index++; + return it; + } + + constexpr basic_iota_iterator operator--(int) // NOLINT + { + basic_iota_iterator it = *this; + index--; + return it; + } + // TODO: operator-> + constexpr reference operator*() const { return f(index); } + + template + constexpr reference operator[](T x) const + { + return f(index + x); + } +}; + +template +constexpr basic_iota_iterator make_basic_iota_iterator(T x, F f) +{ + return basic_iota_iterator{x, f}; +} + +template +constexpr basic_iota_iterator operator+(basic_iota_iterator x, diff_int y) +{ + return x += y; +} + +template +constexpr basic_iota_iterator operator+(diff_int x, basic_iota_iterator y) +{ + return y + x; +} + +template +constexpr diff_int operator-(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index - y.index; +} + +template +constexpr basic_iota_iterator operator-(basic_iota_iterator x, diff_int y) +{ + return x -= y; +} + +template +constexpr bool operator==(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index == y.index; +} + +template +constexpr bool operator!=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index != y.index; +} + +template +constexpr bool operator<(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index < y.index; +} + +template +constexpr bool operator>(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index > y.index; +} + +template +constexpr bool operator>=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index >= y.index; +} + +template +constexpr bool operator<=(basic_iota_iterator x, basic_iota_iterator y) +{ + return x.index <= y.index; +} + +struct defaul_iota_iterator +{ + template + constexpr auto operator()(T x) const + { + return x; + } +}; + +using iota_iterator = basic_iota_iterator; + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c1562baf47e18f5eb8194c50f2670e2bcae9ced --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -0,0 +1,227 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP +#define MIGRAPHX_GUARD_KERNELS_MATH_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { + +namespace math { +constexpr float as_float(migraphx::half x) { return x; } +template +constexpr T as_float(T x) +{ + return x; +} +} // namespace math + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH(name, fname) \ + template ())> \ + auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...)) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_VEC(name) \ + template ())> \ + auto __device__ name(Ts... xs) \ + { \ + return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \ + } + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ + template ())> \ + auto __device__ name(type x, Ts... xs)->type \ + { \ + return fname(x, xs...); \ + } + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ + inline auto __device__ name(type x, type y)->type { return fname(x, y); } + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ + template ())> \ + auto __device__ name(migraphx::half x, Ts... xs) \ + MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) + +// Template with two overloads for math functions, one for half2 type and one for more generic +// vectorization where N is 4 or another even number. + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \ + template \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + MIGRAPHX_RETURNS(migraphx::vec{fname(x, xs...)}); \ + template 2))> \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + { \ + return vec_packed_transform<2>(x, xs...)( \ + [](auto... ys) -> migraphx::vec { return fname(ys...); }); \ + } + +MIGRAPHX_DEVICE_MATH(abs, ::abs) +MIGRAPHX_DEVICE_MATH(acos, ::acos) +MIGRAPHX_DEVICE_MATH(acosh, ::acosh) +MIGRAPHX_DEVICE_MATH(asin, ::asin) +MIGRAPHX_DEVICE_MATH(asinh, ::asinh) +MIGRAPHX_DEVICE_MATH(atan, ::atan) +MIGRAPHX_DEVICE_MATH(atanh, ::atanh) +MIGRAPHX_DEVICE_MATH(ceil, ::ceil) +MIGRAPHX_DEVICE_MATH(cos, ::cos) +MIGRAPHX_DEVICE_MATH(cosh, ::cosh) +MIGRAPHX_DEVICE_MATH(erf, ::erf) +MIGRAPHX_DEVICE_MATH(exp, ::exp) +MIGRAPHX_DEVICE_MATH(floor, ::floor) +MIGRAPHX_DEVICE_MATH(isnan, ::isnan) +MIGRAPHX_DEVICE_MATH(log, ::log) +MIGRAPHX_DEVICE_MATH(pow, ::pow) +MIGRAPHX_DEVICE_MATH(round, ::round) +MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) +MIGRAPHX_DEVICE_MATH(sin, ::sin) +MIGRAPHX_DEVICE_MATH(sinh, ::sinh) +MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) +MIGRAPHX_DEVICE_MATH(tan, ::tan) +MIGRAPHX_DEVICE_MATH(tanh, ::tanh) + +// Float overloads +MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) +MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf) +MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf) +MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf) +MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf) +MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf) +MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf) +MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf) +MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf) +MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) +MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) +MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) +MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) + +// Builtin half functions +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt) + +// Use float to compute half overload +MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos) +MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh) +MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin) +MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh) +MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan) +MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh) +MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil) +MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos) +MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) +MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) +MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) +MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan) +MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) +MIGRAPHX_DEVICE_MATH_HALF(round, ::round) +MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) +MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) +MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) +MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) + +// Map math functions to hip half2 functions +// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats +// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names +// Most but not all of these math ops have operators of the same names. Ones not yet implemented +// at this time are: exp2, exp10, log2, log10, isinf +MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2) +MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil) +MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor) +MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin) +MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos) +MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp) +MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2) +MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10) +MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2) +MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log) +MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10) +MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt) +MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt) +MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2) +MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2) + +template +constexpr auto where(bool cond, const T& a, const U& b) +{ + return cond ? a : b; +} + +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) +// Add overloads for half that calls the float version +MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf) + +template ())> +constexpr auto max(const T& a, const T& b) +{ + return where(a < b, b, a); +} + +template ())> +constexpr auto min(const T& a, const T& b) +{ + return where(a < b, a, b); +} + +template {} and not is_any_vec())> +constexpr auto max(const T& a, const U& b) +{ + return max>(a, b); +} + +template {} and not is_any_vec())> +constexpr auto min(const T& a, const U& b) +{ + return min>(a, b); +} + +MIGRAPHX_DEVICE_MATH_VEC(abs) +MIGRAPHX_DEVICE_MATH_VEC(acos) +MIGRAPHX_DEVICE_MATH_VEC(acosh) +MIGRAPHX_DEVICE_MATH_VEC(asin) +MIGRAPHX_DEVICE_MATH_VEC(asinh) +MIGRAPHX_DEVICE_MATH_VEC(atan) +MIGRAPHX_DEVICE_MATH_VEC(atanh) +MIGRAPHX_DEVICE_MATH_VEC(ceil) +MIGRAPHX_DEVICE_MATH_VEC(cos) +MIGRAPHX_DEVICE_MATH_VEC(cosh) +MIGRAPHX_DEVICE_MATH_VEC(erf) +MIGRAPHX_DEVICE_MATH_VEC(exp) +MIGRAPHX_DEVICE_MATH_VEC(floor) +MIGRAPHX_DEVICE_MATH_VEC(isnan) +MIGRAPHX_DEVICE_MATH_VEC(log) +MIGRAPHX_DEVICE_MATH_VEC(max) +MIGRAPHX_DEVICE_MATH_VEC(min) +MIGRAPHX_DEVICE_MATH_VEC(pow) +MIGRAPHX_DEVICE_MATH_VEC(round) +MIGRAPHX_DEVICE_MATH_VEC(rsqrt) +MIGRAPHX_DEVICE_MATH_VEC(sin) +MIGRAPHX_DEVICE_MATH_VEC(sinh) +MIGRAPHX_DEVICE_MATH_VEC(sqrt) +MIGRAPHX_DEVICE_MATH_VEC(tan) +MIGRAPHX_DEVICE_MATH_VEC(tanh) +MIGRAPHX_DEVICE_MATH_VEC(where) + +template +constexpr auto convert(U v) +{ + return vec_transform(v)([](auto x) -> T { return x; }); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8f382c82dc5e5d5a66dfc98e908059008c1a3be --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp @@ -0,0 +1,83 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP +#define MIGRAPHX_GUARD_KERNELS_OPS_HPP + +#include + +namespace migraphx { +namespace op { + +struct sum +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return x + y; + } +}; + +struct product +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return x * y; + } +}; + +struct id +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const + { + return x; + } +}; + +struct mean +{ + index_int item_num = 1; + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const + { + return x / static_cast(item_num); + } +}; + +struct max +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return migraphx::max(x, y); + } +}; + +struct min +{ + template + MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const + { + return migraphx::min(x, y); + } +}; +} // namespace op + +struct lowest +{ + template + constexpr operator T() const + { + return numeric_lowest(); + } +}; + +struct highest +{ + template + constexpr operator T() const + { + return numeric_max(); + } +}; +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_OPS_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60a3ec23fe56db3bfb84d200ba650462d173fc12 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp @@ -0,0 +1,55 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP +#define MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { + +template +struct implicit_conversion_op +{ + T x; + + template + constexpr operator vec() const + { + static_assert(vec_size() == N, "Vector mismatch size"); + return __builtin_convertvector(x, vec); + } + + template + constexpr operator U() const + { + return x; + } +}; + +template +constexpr implicit_conversion_op implicit_conversion(T x) +{ + return {x}; +} + +template +__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) +{ + idx.global_stride(out.get_shape().elements(), + [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); }); +} + +template +__device__ auto pointwise(index idx, Transforms... transforms) +{ + return [=](auto f, auto*... ps) { + auto t = transform_args(make_tensors(), rotate_last(), transforms...); + t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); }); + }; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp new file mode 100644 index 0000000000000000000000000000000000000000..14ab71d47ef33cc2e55d97f5da2c351cebf6e930 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp @@ -0,0 +1,174 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP +#define MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP + +#include +#include +#include +#include + +namespace migraphx { + +template +struct remove_vec_impl +{ + using type = T; +}; + +template +struct remove_vec_impl> +{ + using type = T; +}; + +template +using remove_vec = typename remove_vec_impl::type; + +template +constexpr auto traverse_preload(Shapes... ss) +{ + return [=](auto f, auto... g) { + index_int offset = 0; + auto each = [&](auto x) { + using type = remove_vec; + constexpr auto s = decltype(x.get_shape()){}; + constexpr auto size = s.element_space(); + if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or + not is_same{}) + return f(x, offset, false_type{}); + else + { + auto pre_offset = offset; + offset += size; + offset += offset % 4; + return f(x, pre_offset, true_type{}); + } + }; + return by(each, g...)(ss...); + }; +} + +template +constexpr index_int compute_preload_size_c(Shapes...) +{ + index_int size = 0; + traverse_preload(Shapes{}...)( + [&](auto s, auto offset, auto) { size = offset + s.element_space(); }); + return size; +} + +template +constexpr auto compute_preload_size(Shapes...) +{ + return _c(Shapes{}...)>; +} + +template +__device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) +{ + auto invoke = [&](auto... ys) { + __syncthreads(); + f(ys...); + }; + traverse_preload(xs...)( + [&](auto x, auto offset, auto copy) { + if constexpr(copy) + { + if constexpr(decltype(tensor_vec_size(x)){} == 0) + { + auto v = auto_vectorize(x); + auto b = as_vec(tensor_vec_size(v), buffer + offset); + idx.local_stride(v.get_shape().element_space(), + [&](auto i) { b[i] = v.data()[i]; }); + return x.with(buffer + offset); + } + else + { + auto b = as_vec(tensor_vec_size(x), buffer + offset); + idx.local_stride(x.get_shape().element_space(), + [&](auto i) { b[i] = x.data()[i]; }); + return x.with(b); + } + } + else + { + return x; + } + }, + invoke); +} + +template +struct shape_type : Shape +{ + using type = T; +}; + +template +constexpr auto make_shape_type(T) +{ + return shape_type{}; +} + +template +__device__ auto preload(index idx, Ts... xs) +{ + using type = remove_vec; + constexpr auto size = decltype(compute_preload_size(make_shape_type(xs)...)){}; + const index_int max_size = 512 * sizeof(type); + return [=](auto f) { + if constexpr(size > 0 and size < max_size) + { + __shared__ type buffer[size]; + preload_copy(idx, f, buffer, xs...); + } + else + { + f(xs...); + } + }; +} + +inline __device__ auto auto_preload(index idx) +{ + return make_transform([=](auto f, auto out, auto... xs) { + preload(idx, xs...)([&](auto... ys) { f(out, ys...); }); + }); +} + +template +__device__ auto preload_copy(index idx, T x) +{ + return [=](auto f) { + if constexpr(B) + { + using type = typename T::type; + constexpr auto size = get_shape_c{}.element_space(); + __shared__ type buffer[size]; + // TODO: Always vecotrize when size > 4, and then use a second loop for remainder + constexpr auto n = find_vectorize_size([&](auto i) { return (size % i) == 0; }); + auto input = as_vec(remove_bool(x.data())); + auto b = as_vec(remove_bool(buffer)); + idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; }); + return f(x.with(buffer)); + } + else + { + return f(x); + } + }; +} + +template +__device__ auto auto_preload(index idx) +{ + return make_transform([=](auto f, auto... xs) { + auto invoke = [=](auto... ys) { + __syncthreads(); + f(ys...); + }; + join(invoke, preload_copy(idx, xs)...); + }); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4a22688050a09a63b0256b25193c842812e24f0 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/print.hpp @@ -0,0 +1,234 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP +#define MIGRAPHX_GUARD_KERNELS_PRINT_HPP + +#include +#include +#include +#include + +namespace migraphx { + +template +struct on_exit +{ + F f; + G g; + template + __host__ __device__ auto operator()(T x) const + { + return f(x); + } + + __host__ __device__ ~on_exit() { f(g); } +}; + +template +constexpr auto print_type_name_probe() +{ + constexpr auto name = __PRETTY_FUNCTION__; + constexpr auto size = sizeof(__PRETTY_FUNCTION__); + constexpr auto parameter_name = "PrivateMIGraphXTypeNameProbe = "; + constexpr auto parameter_name_size = sizeof("PrivateMIGraphXTypeNameProbe = ") - 1; + constexpr auto begin = + search(name, name + size, parameter_name, parameter_name + parameter_name_size); + static_assert(begin < name + size, "Type probe not found."); + constexpr auto start = begin + parameter_name_size; + constexpr auto last = find_if(start, name + size, [](auto c) { return c == ']' or c == ';'; }); + return [=](const auto& s) { s.print_string(start, last - start); }; +} + +template +struct type_printer +{ + template + friend constexpr const Stream& operator<<(const Stream& s, type_printer) + { + print_type_name_probe()(s); + return s; + } +}; + +template +constexpr type_printer type_of() +{ + return {}; +} + +template +constexpr type_printer type_of(T) +{ + return {}; +} + +template +constexpr type_printer sub_type_of() +{ + return {}; +} + +template +constexpr type_printer sub_type_of(T) +{ + return {}; +} + +template +struct basic_printer +{ + F f; + __host__ __device__ const basic_printer& print_long(long value) const + { + f([&] { printf("%li", value); }); + return *this; + } + __host__ __device__ const basic_printer& print_ulong(unsigned long value) const + { + f([&] { printf("%lu", value); }); + return *this; + } + __host__ __device__ const basic_printer& print_char(char value) const + { + f([&] { printf("%c", value); }); + return *this; + } + __host__ __device__ const basic_printer& print_string(const char* value) const + { + f([&] { printf("%s", value); }); + return *this; + } + __host__ __device__ const basic_printer& print_string(const char* value, int size) const + { + f([&] { printf("%.*s", size, value); }); + return *this; + } + __host__ __device__ const basic_printer& print_double(double value) const + { + f([&] { printf("%f", value); }); + return *this; + } + __host__ __device__ const basic_printer& print_bool(bool value) const + { + f([&] { + if(value) + printf("true"); + else + printf("false"); + }); + return *this; + } + __host__ __device__ const basic_printer& operator<<(short value) const + { + return print_long(value); + } + __host__ __device__ const basic_printer& operator<<(unsigned short value) const + { + return print_ulong(value); + } + __host__ __device__ const basic_printer& operator<<(int value) const + { + return print_long(value); + } + __host__ __device__ const basic_printer& operator<<(unsigned int value) const + { + return print_ulong(value); + } + __host__ __device__ const basic_printer& operator<<(long value) const + { + return print_long(value); + } + __host__ __device__ const basic_printer& operator<<(unsigned long value) const + { + return print_ulong(value); + } + __host__ __device__ const basic_printer& operator<<(migraphx::half value) const + { + return print_double(value); + } + __host__ __device__ const basic_printer& operator<<(float value) const + { + return print_double(value); + } + __host__ __device__ const basic_printer& operator<<(double value) const + { + return print_double(value); + } + __host__ __device__ const basic_printer& operator<<(bool value) const + { + return print_bool(value); + } + __host__ __device__ const basic_printer& operator<<(char value) const + { + return print_char(value); + } + __host__ __device__ const basic_printer& operator<<(unsigned char value) const + { + return print_char(value); + } + __host__ __device__ const basic_printer& operator<<(const char* value) const + { + return print_string(value); + } +}; + +template +constexpr basic_printer make_printer(F f) +{ + return {f}; +} + +template +constexpr basic_printer> make_printer(F f, G g) +{ + return {{f, g}}; +} + +inline __device__ auto cout() +{ + return make_printer([](auto f) { f(); }); +} + +inline __device__ auto coutln() +{ + return make_printer([](auto f) { f(); }, [] { printf("\n"); }); +} + +template +__device__ void print_each(F f, Ts... xs) +{ + each_args([&](auto x) { f() << x; }, xs...); +} + +template +__device__ void print_each_once(F f, Ts... xs) +{ + auto idx = make_index(); + if(idx.global == 0) + print_each(f, xs...); +} + +template +__device__ void print(Ts... xs) +{ + print_each(&cout, xs...); +} + +template +__device__ void print_once(Ts... xs) +{ + print_each_once(&cout, xs...); +} + +template +__device__ void println(Ts... xs) +{ + print_each(&coutln, xs...); +} + +template +__device__ void println_once(Ts... xs) +{ + print_each_once(&coutln, xs...); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_PRINT_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d5684624e096fe28b97fd0c747e74ec8ce27e6a0 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -0,0 +1,266 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_REDUCE_HPP +#define MIGRAPHX_GUARD_KERNELS_REDUCE_HPP + +#include +#include +#include +#include + +namespace migraphx { + +#if MIGRAPHX_HAS_DPP + +template +__device__ void dpp_reduce(T& in, Op op) +{ + T out{}; + out = dpp_mov(in); + in = op(in, out); + out = dpp_mov(in); + in = op(in, out); + out = dpp_mov(in); + in = op(in, out); + out = dpp_mov(in); + in = op(in, out); +#if __AMDGCN_WAVEFRONT_SIZE == 64 + out = dpp_mov(in); + in = op(in, out); + out = dpp_mov(in); + in = op(in, out); +#endif +} +#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK) +// NOLINTNEXTLINE +#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1 +#elif __AMDGCN_WAVEFRONT_SIZE == 64 +#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ + __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \ + "s_nop 1\n" \ + : "=v"(x) \ + : "0"(x)) +#else +#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ + __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ + "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \ + "s_nop 1\n" \ + "s_nop 1\n" \ + : "=v"(x) \ + : "0"(x)) +#endif + +// NOLINTNEXTLINE +#define MIGRAPHX_DPP_REDUCE(op, prefix) \ + __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \ + __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \ + __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \ + __device__ inline void dpp_reduce(int32_t& x, op) \ + { \ + MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); \ + } \ + __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); } + +MIGRAPHX_DPP_REDUCE(op::sum, v_add) +MIGRAPHX_DPP_REDUCE(op::max, v_max) +MIGRAPHX_DPP_REDUCE(op::min, v_min) +MIGRAPHX_DPP_REDUCE(op::product, v_mul) + +template +__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) +{ +#if __AMDGCN_WAVEFRONT_SIZE == 32 + constexpr index_int lanes_per_thread = 16; +#else + constexpr index_int lanes_per_thread = 64; +#endif + using type = decltype(f(0)); + __shared__ type buffer[idx.nlocal() / lanes_per_thread]; + type x = init; + idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); + dpp_reduce(x, op); + + const auto ldsidx = idx.local / lanes_per_thread; + if((idx.local % lanes_per_thread) == lanes_per_thread - 1) + { + buffer[ldsidx] = x; + } + __syncthreads(); + + type y = init; + for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++) + { + y = op(y, buffer[i]); + } + return y; +} +#else +template +__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) +{ + + using type = decltype(f(0)); + __shared__ type buffer[idx.nlocal()]; + type x = init; + idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); + buffer[idx.local] = x; + __syncthreads(); + + for(index_int s = 1; s < idx.nlocal(); s *= 2) + { + const index_int index = 2 * s * idx.local; + if(index + s < idx.nlocal()) + { + buffer[index] = op(buffer[index], buffer[index + s]); + } + __syncthreads(); + } + return buffer[0]; +} +#endif + +template +constexpr auto reduce_slice(Input input, T i) +{ + constexpr auto lens = transform(get_shape_c{}.lens, + get_shape_c{}.lens, + [](index_int x, index_int y) -> index_int { + if(x == y) + return 1; + return x; + }); + ; + constexpr auto s = make_shape(lens, get_shape_c{}.strides); + MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <= + input.get_shape().element_space()); + return make_tensor_view(&input[i], s); +} + +namespace reduce { + +template +constexpr auto sliced(Slicer slicer, F f) +{ + return [=](auto x, auto... xs) { + // TODO: assert all elements are the same + return f(slicer(x), slicer(xs)...); + }; +} + +struct block +{ + template + struct reducer + { + index idx; + Slicer slicer; + template + __device__ auto reduce(Op op, T init, Read read) const + { + return sliced(slicer, [=](auto x, auto... xs) { + return vec_reduce(block_reduce(idx, + op, + init, + x.get_shape().elements(), + [&](auto j) { return read(x[j], xs[j]...); }), + op); + }); + } + + template + __device__ void outer(F f) const + { + if(idx.local == 0) + f(); + } + }; + + template + static __device__ auto make(index idx, Slicer slicer) + { + return reducer{idx, slicer}; + } + + template + static __device__ void run(F f) + { + auto idx = make_index(); + constexpr auto nelements = get_shape_c{}.elements(); + idx.global_stride(nelements * idx.nlocal(), [&](auto i) { + const auto out_idx = get_shape_c{}.multi(i / idx.nlocal()); + f(out_idx, make(idx, [&](auto input) { return reduce_slice(input, out_idx); })); + }); + } +}; + +struct lane +{ + template + struct reducer + { + index idx; + Slicer slicer; + template + __device__ auto reduce(Op op, T init, Read read) const + { + return sliced(slicer, [=](auto x, auto... xs) { + using type = typename decltype(x)::type; + type r = init; + for(index_int j = 0; j < x.get_shape().elements(); j++) + { + r = op(r, read(x[j], xs[j]...)); + } + return r; + }); + } + + template + __device__ void outer(F f) const + { + f(); + } + }; + + template + static __device__ auto make(index idx, Slicer slicer) + { + return reducer{idx, slicer}; + } + + template + static __device__ void run(F f) + { + auto idx = make_index(); + constexpr auto nelements = get_shape_c{}.elements(); + idx.global_stride(nelements, [&](auto i) { + const auto out_idx = get_shape_c{}.multi(i); + f(out_idx, make(idx, [&](auto input) { return reduce_slice(input, out_idx); })); + }); + } +}; + +} // namespace reduce + +template +__device__ void +simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write) +{ + Algo::template run([&](auto out_idx, auto r) { + auto x = r.reduce(op, init, read)(input); + r.outer([&] { output[out_idx] = write(x); }); + }); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp new file mode 100644 index 0000000000000000000000000000000000000000..79533d1b147802cc235adac19a8141ccfa784729 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp @@ -0,0 +1,203 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP +#define MIGRAPHX_GUARD_KERNELS_ROIALIGN_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { + +struct max_pool +{ + MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest{}; } + + template + MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) + { + return max(x, y); + } + + template + MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int) + { + return (x); + } +}; + +struct avg_pool +{ + MIGRAPHX_DEVICE_CONSTEXPR auto init() { return 0.0; } + + template + MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) + { + return x + y; + } + + template + MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y) + { + return (y == 0) ? 0.0 : (x / y); + } +}; + +template +MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( + const Iterator data, const array& dims, array xy, Op pooling) +{ + array low{}; + array high{}; + for(index_int ii = 0; ii < xy.size(); ++ii) + { + if(xy[ii] < -1.0f or xy[ii] > dims[ii]) + { + return 0; + } + + xy[ii] = migraphx::max(xy[ii], 0.0f); + low[ii] = xy[ii]; + high[ii] = low[ii] + 1; + if(low[ii] >= dims[ii] - 1) + { + xy[ii] = high[ii] = low[ii] = dims[ii] - 1; + } + } + array locs = {low[0] * dims[1] + low[1], + low[0] * dims[1] + high[1], + high[0] * dims[1] + low[1], + high[0] * dims[1] + high[1]}; + + float ly = xy[0] - low[0]; + float lx = xy[1] - low[1]; + float hy = 1.0f - ly; + float hx = 1.0f - lx; + array ws = {hy * hx, hy * lx, ly * hx, ly * lx}; + + auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); + auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); + return pooling(v01, v23); +} + +template +MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, + const array& roi_starts, + const array& bin_size, + const array& idx, + const array& bin_grid_size, + const array& dims, + float roi_offset, + Op op) +{ + typename Iterator::value_type output_val = op.init(); + const int64_t count = bin_grid_size[0] * bin_grid_size[1]; + dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { + array id = {iy, ix}; + array locs = + roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset; + + auto val = bilinear_interpolate(data, dims, locs, op); + output_val = op(output_val, val); + }); + return op.final(output_val, count); +} + +template +struct roalign_settings +{ + T1 roi_offset{}; + T2 is_avg_pooling{}; + T3 sampling_ratio{}; + T4 spatial_scale{}; +}; + +template +constexpr roalign_settings make_roalign_settings(Ts... xs) +{ + return {xs...}; +} + +template +__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, Settings s) +{ + auto index = make_index(); + const auto x = x_t.begin(); + const auto rois = rois_t.begin(); + const auto ind = ind_t.begin(); + + // input shape + auto x_lens = x_t.get_shape().lens; + auto channel_num = x_lens[1]; + // input dims of height and width, in all 2-dim arrays, the first dim + // is for height and second dim is for width + array in_dims = {x_lens[2], x_lens[3]}; + + const auto stride = index.nglobal(); + auto out_s = y_t.get_shape(); + auto roi_column_num = rois_t.get_shape().lens[1]; + + // output dims of height and width, in all 2-dim arrays, the first dim + // is for height and second dim is for width + const auto& out_lens = out_s.lens; + array out_dims = {out_lens[2], out_lens[3]}; + + for(index_int i = index.global; i < out_s.elements(); i += stride) + { + auto idx = out_s.multi(i); + int n = idx[0]; + int c = idx[1]; + int ph = idx[2]; + int pw = idx[3]; + + const auto offset_rois = rois + (n * roi_column_num); + const int batch_ind = ind[n]; + + array roi_starts = {offset_rois[1] * s.spatial_scale, + offset_rois[0] * s.spatial_scale}; + array roi_ends = {offset_rois[3] * s.spatial_scale, + offset_rois[2] * s.spatial_scale}; + + array roi_size{}; + array bin_size{}; + array bin_grid_size{}; + + for(index_int ii = 0; ii < roi_size.size(); ++ii) + { + roi_size[ii] = roi_ends[ii] - roi_starts[ii]; + roi_size[ii] = migraphx::max(roi_size[ii], 1.0f); + + bin_size[ii] = roi_size[ii] / out_dims[ii]; + bin_grid_size[ii] = (s.sampling_ratio > 0) + ? s.sampling_ratio + : migraphx::ceil(roi_size[ii] / out_dims[ii]); + } + + const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); + if constexpr(s.is_avg_pooling) + { + y_t[i] = calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + bin_grid_size, + in_dims, + s.roi_offset, + avg_pool{}); + } + else + { + y_t[i] = calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + bin_grid_size, + in_dims, + s.roi_offset, + max_pool{}); + } + } +} + +} // namespace migraphx +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c89c382ec95fae4031f0e149bc2d5d19fc8d43f8 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatternd.hpp @@ -0,0 +1,64 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP +#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP + +#include +#include + +namespace migraphx { + +struct assign_none +{ + template + MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const + { + x = y; + } +}; + +struct assign_add +{ + template + MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const + { + x += y; + } +}; + +struct assign_mul +{ + template + MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const + { + x *= y; + } +}; + +template +__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f) +{ + auto index = make_index(); + auto updates_shape = updates_t.get_shape(); + + index.global_stride(updates_shape.elements(), [&](auto i) { + auto output_shape = output_t.get_shape(); + + auto indices_shape = indices_t.get_shape(); + auto k = indices_shape.lens.back(); + auto q = indices_shape.lens.size(); + + auto updates_idx = updates_shape.multi(i); + auto indices_idx = indices_shape.multi(0); + copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin()); + + auto index_start = indices_t.begin() + indices_shape.index(indices_idx); + auto index_end = index_start + k; + auto out_idx = output_shape.multi(0); + copy(index_start, index_end, out_idx.begin()); + copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); + + f(output_t[out_idx], updates_t[i]); + }); +} + +} // namespace migraphx +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp new file mode 100644 index 0000000000000000000000000000000000000000..50cff7c6ef32ec19578640a69e7368fd63600c27 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp @@ -0,0 +1,125 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP + +#include +#include + +namespace migraphx { + +template +struct shape +{ + using index_array = typename Lens::base_array; + Lens lens = {}; + Strides strides = {}; + + constexpr shape() = default; + + constexpr shape(Lens l, Strides s) : lens(l), strides(s) {} + + constexpr auto elements() const { return _c; } + + constexpr auto element_space() const { return _c; } + + constexpr auto packed() const { return elements() == element_space(); } + constexpr auto broadcasted() const { return _c; } + constexpr auto transposed() const + { + return return_c([] { + auto lstrides = Strides{}; + if(shape{}.broadcasted()) + { + index_array s{}; + index_int j = 0; + for(index_int i = 0; i < s.size(); i++) + { + if(lstrides[i] != 0) + { + s[j] = lstrides[i]; + j++; + } + } + return not is_sorted(s.begin(), s.begin() + j, greater{}); + } + else + { + return not is_sorted(lstrides.begin(), lstrides.end(), greater{}); + } + }); + } + + constexpr auto standard() const { return packed() and not transposed(); } + + constexpr index_int index(index_array x) const { return x.dot(strides); } + + constexpr index_int index(std::initializer_list x) const + { + index_int idx = 0; + for(index_int i = 0; i < x.size(); i++) + idx += *(x.begin() + i) * strides[i]; + return idx; + } + + constexpr index_int index(index_int i) const + { + if(this->standard()) + return i; + else + { + const auto rank = this->lens.size(); + index_int s = 1; + index_int result = 0; + for(index_int j = 0; j < rank; j++) + { + const index_int k = rank - j - 1; + const index_int stride = this->strides[k]; + const index_int len = this->lens[k]; + const index_int slen = s * len; + const index_int idx = (i % slen) / s; + result += stride * idx; + s = slen; + } + return result; + } + } + + /// Convert single index into a multi-index + constexpr index_array multi(index_int idx) const + { + index_array result; + index_int tidx = idx; + for(diff_int is = result.size() - 1; is > 0; is--) + { + result[is] = tidx % lens[is]; + tidx = tidx / lens[is]; + } + result[0] = tidx; + return result; + } + /// Convert multi-index into a single index + constexpr index_int single(index_array idx) const + { + if(idx.empty()) + return 0; + return inner_product(lens.begin() + 1, lens.end(), idx.begin(), idx.back()); + } + + constexpr shape get_shape() const { return *this; } + + template + friend constexpr const Stream& operator<<(const Stream& ss, const shape& s) + { + ss << "{" << s.lens << "}, {" << s.strides << "}"; + return ss; + } +}; + +template +constexpr shape make_shape(Lens lens, Strides strides) +{ + return {lens, strides}; +} + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fb7901e91d8e0261931eb37a22d741c29ed54261 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp @@ -0,0 +1,83 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP +#define MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP + +#include +#include +#include + +namespace migraphx { + +template +struct tensor_view_iterator_read +{ + T* view; + constexpr auto& operator()(index_int n) const + { + MIGRAPHX_ASSERT(view != nullptr); + return (*view)[n]; + } +}; + +template +struct tensor_view +{ + using type = T; + using shape_type = Shape; + using index_array = typename Shape::index_array; + using iterator = basic_iota_iterator, index_int>; + + constexpr Shape get_shape() const { return Shape{}; } + constexpr auto size() const { return get_shape().elements(); } + + struct index_to_offset + { + index_int offset; + template + constexpr index_to_offset(U i) : offset(Shape{}.index(i)) + { + } + }; + + constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const + { + index_to_offset ito = i; + MIGRAPHX_WARN(ito.offset < get_shape().element_space(), + i, + "Out of bounds access at offset: ", + ito.offset); + return x[ito.offset]; + } + + constexpr T* data() const { return x; } + + constexpr auto begin() const { return iterator{0, {this}}; } + constexpr auto end() const { return iterator{this->size(), {this}}; } + + constexpr auto begin_at(index_array i) const + { + MIGRAPHX_ASSERT(get_shape().single(i) < get_shape().elements()); + MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space()); + return iterator{get_shape().single(i), {this}}; + } + + template + constexpr tensor_view with(U* y) const + { + static_assert(sizeof(T) == sizeof(U), "Not the same size"); + return {y}; + } + + T* x; +}; + +template +using get_shape_c = typename T::shape_type; + +template +constexpr tensor_view make_tensor_view(T* x, Shape) +{ + return {x}; +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d3d4bad527618e116dc36e8154d22c6f4916d24 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -0,0 +1,214 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPE_TRAITS_HPP + +#include +#include + +namespace migraphx { + +template +U private_declval(int); + +template +T private_declval(long); + +template +auto declval() noexcept -> decltype(private_declval(0)); + +template +struct type_identity +{ + using type = T; +}; + +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; + +template +struct conditional +{ + using type = T; +}; + +template +struct conditional +{ + using type = F; +}; + +template +using conditional_t = typename conditional::type; + +// NOLINTNEXTLINE +#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \ + template \ + struct name : bool_constant<__##name(T)> \ + { \ + } + +// NOLINTNEXTLINE +#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name) \ + template \ + struct name : bool_constant<__##name(T, U)> \ + { \ + } + +// NOLINTNEXTLINE +#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name) \ + template \ + struct name : bool_constant<__##name(Ts...)> \ + { \ + } + +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed); +// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned); +MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same); +MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable); +MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible); +MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible); +MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible); + +template +struct remove_reference +{ + using type = T; +}; +template +struct remove_reference +{ + using type = T; +}; +template +struct remove_reference +{ + using type = T; +}; + +template +using remove_reference_t = typename remove_reference::type; + +template +struct add_pointer : type_identity::type*> +{ +}; + +template +using add_pointer_t = typename add_pointer::type; + +template +struct common_type; + +template +struct common_type +{ + using type = T; +}; + +template +struct common_type +{ + using type = decltype(true ? declval() : declval()); +}; + +template +struct common_type +{ + using type = typename common_type::type, Us...>::type; +}; + +template +using common_type_t = typename common_type::type; + +constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; } + +template +constexpr T numeric_max() +{ + if constexpr(is_integral{}) + { + if constexpr(is_unsigned{}) + return int_max(sizeof(T)) * 2; + else + return int_max(sizeof(T)); + } + else if constexpr(is_same{}) + return __DBL_MAX__; + else if constexpr(is_same{}) + return __FLT_MAX__; + else if constexpr(is_same{}) + return __FLT16_MAX__; + else + return 0; +} + +template +constexpr T numeric_lowest() +{ + if constexpr(is_integral{}) + { + if constexpr(is_unsigned{}) + return 0; + else + return -numeric_max() - 1; + } + else + { + return -numeric_max(); + } +} + +#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc7da139b884305d0d87ab161375eb2963218430 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP + +#include + +namespace migraphx { + +using index_int = std::uint32_t; +using diff_int = std::int32_t; + +#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT + +template +using vec = T __attribute__((ext_vector_type(N))); + +using half = _Float16; +using half2 = migraphx::vec; + +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7837c536c91f4992f081f3e21c5e81566cbd41b --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp @@ -0,0 +1,164 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_VEC_HPP +#define MIGRAPHX_GUARD_KERNELS_VEC_HPP + +#include +#include +#include + +namespace migraphx { + +template +constexpr auto vec_size(vec) +{ + return index_constant{}; +} + +template +constexpr auto vec_size(T, ...) // NOLINT +{ + return index_constant<0>{}; +} + +template +constexpr auto vec_size() +{ + return decltype(vec_size(T{})){}; +} + +template +constexpr auto is_any_vec() +{ + if constexpr(sizeof...(Ts) == 0) + return false_type{}; + else + return bool_constant<((vec_size() + ...) > 0)>{}; +} + +template +constexpr auto vec_at(T x, I i) +{ + if constexpr(vec_size() == 0) + return x; + else + { + MIGRAPHX_ASSERT(i < vec_size()); + return x[i]; + } +} + +template +using vec_type = decltype(vec_at(T{}, 0)); + +template +constexpr auto common_vec_size() +{ + return fold([](auto x, auto y) { + if constexpr(x > y) + return x; + else + return y; + })(vec_size()...); +} + +// Bools can not be used as a vector type so convert it to uint8 +template +__device__ __host__ T* remove_bool(T* x) +{ + return x; +} + +inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast(x); } + +template +__device__ __host__ auto as_vec(T* x) +{ + if constexpr(N < 2) + return x; + else + return reinterpret_cast*>(x); +} + +template +using safe_vec = vec{}, uint8_t, T>, N>; + +template +constexpr auto vec_transform(Ts... xs) +{ + return [=](auto f) { + if constexpr(is_any_vec()) + { + using type = decltype(f(vec_at(xs, 0)...)); + constexpr auto size = common_vec_size(); + safe_vec result = {0}; + for(int i = 0; i < size; i++) + result[i] = f(vec_at(xs, i)...); + return result; + } + else + { + return f(xs...); + } + }; +} + +// Return a vector type of N from index i in another larger vector +// N will be 2 for half2 packing +template +constexpr vec, N> vec_packed_at(T x, I i) +{ + if constexpr(vec_size() == 0) + return vec{x}; + else + { + MIGRAPHX_ASSERT((i + N) < vec_size()); + vec, N> result = {0}; + for(int j = 0; j < N; j++) + { + result[j] = x[i + j]; + } + return result; + } +} + +template +constexpr auto vec_packed_transform(Ts... xs) +{ + return [=](auto f) { + if constexpr(is_any_vec()) + { + using type = vec_type(xs, 0)...))>; + constexpr auto size = common_vec_size(); + safe_vec result = {0}; + for(int i = 0; i < size / N; i++) + { + // Call the function with packed vectors + safe_vec r = f(vec_packed_at(xs, i * N)...); + // Copy the packed vectors to the result + for(int j = 0; j < N; j++) + result[i * N + j] = r[j]; + } + return result; + } + else + { + return f(xs...); + } + }; +} + +template +constexpr auto vec_reduce(T x, Op op) +{ + if constexpr(vec_size() < 2) + return x; + else + { + vec_type result = x[0]; + for(int i = 1; i < vec_size(); i++) + result = op(result, x[i]); + return result; + } +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54152f28ef4112d54a394604c1a1361524bed0d6 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp @@ -0,0 +1,240 @@ +#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP +#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP + +#include +#include + +namespace migraphx { + +template +constexpr auto tensor_vec_size() +{ + return vec_size(); +} + +template +constexpr auto tensor_vec_size(T) +{ + return tensor_vec_size(); +} + +template +constexpr auto shape_step(Shape s, Axis) +{ + static_assert(N > 0, "Vector size must be non-zero"); + return sequence(s.lens.size(), [&](auto... is) { + auto lens = transform(s.lens, index_ints{}, [&](auto i, auto j) { + constexpr auto axis = Axis::to(); + MIGRAPHX_ASSERT(i != 0); + MIGRAPHX_ASSERT(j != axis or i % N == 0); + if(j == axis) + return i / N; + else + return i; + }); + auto strides = transform(s.strides, index_ints{}, [&](auto i, auto j) { + constexpr auto axis = Axis::to(); + // If stride of the axis is zero then we dont need to adjust the other strides + if(Shape{}.strides[axis] == 0) + return i; + MIGRAPHX_ASSERT(j == axis or i % N == 0); + if(j == axis) + return i; + else + return i / N; + }); + MIGRAPHX_ASSERT(make_shape(lens, strides).elements() * N == s.elements()); + MIGRAPHX_ASSERT(strides[Axis{}] == 0 or + make_shape(lens, strides).element_space() * N == s.element_space()); + return make_shape(lens, strides); + }); +} + +template +__device__ __host__ auto as_vec(T x, Axis axis) +{ + if constexpr(N < 2) + return x; + else + return make_tensor_view(as_vec(remove_bool(x.data())), + shape_step(x.get_shape(), axis)); +} + +template +constexpr auto tensor_step(T x, Axis axis) +{ + if constexpr(N < 2) + { + return x; + } + else + { + constexpr auto s = decltype(x.get_shape()){}; + MIGRAPHX_ASSERT(s.strides[axis] == 0); + return make_tensor_view(x.data(), shape_step(s, axis)); + } +} + +template +__device__ __host__ auto as_vec(IntegralConstant ic, T&& x) +{ + return as_vec(x); +} + +template +constexpr index_int find_vector_axis_c(Shape s) +{ + // Find the fastest axis that is not broadcasted + index_int axis = 0; + for(index_int i = 1; i < s.lens.size(); i++) + { + if(s.strides[i] == 0) + continue; + if(s.strides[axis] == 0 or + pack_compare(less{}, pack(s.strides[i], s.lens[i]), pack(s.strides[axis], s.lens[axis]))) + axis = i; + } + return axis; +} + +template +constexpr index_int find_vector_axis_c(Shapes... ss) +{ + const bool all_broadcasted = (ss.broadcasted() and ...); + index_int axis = 0; + bool b = false; + by([&](auto s) { + if(b) + return; + // Skip broadcasted shapes if there are shapes not broadcasted + if(not all_broadcasted and s.broadcasted()) + return; + axis = find_vector_axis_c(s); + if(s.strides[axis] == 1) + b = true; + })(ss...); + if(not b) + return -1; + return axis; +} + +template +constexpr auto find_vector_axis(Shapes...) +{ + return _c; +} + +template +constexpr auto is_vectorizable_c(Axis axis, Shapes... ss) +{ + return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and + // Only vectorize broadcasted types with stride 0, since this causes issues in the + // preloader + ((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and + ...); +} + +template +constexpr auto is_vectorizable(Axis, Shapes...) +{ + return _c(Axis::to(), Shapes{}...)>; +} + +template +constexpr auto find_vectorize_size(P pred) +{ + if constexpr(decltype(pred(_c<4>)){}) + return _c<4>; + else if constexpr(decltype(pred(_c<2>)){}) + return _c<2>; + else + return _c<1>; +} + +template +__host__ __device__ auto auto_vectorize(T x) +{ + if constexpr(tensor_vec_size() == 0) + { + constexpr auto axis = find_vector_axis(x.get_shape()); + constexpr auto n = + find_vectorize_size([&](auto i) { return is_vectorizable(axis, x.get_shape()); }); + return as_vec(x, axis); + } + else + { + return x; + } +} + +template +inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs) +{ + // TODO: Just check there a single axis of 1 + constexpr bool packed_or_broadcasted = + ((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...); + if constexpr(packed_or_broadcasted) + { + constexpr auto axis = decltype(find_vector_axis(xs.get_shape()...)){}; + constexpr auto n = find_vectorize_size( + [&](auto i) { return is_vectorizable(axis, xs.get_shape()...); }); + by( + [&](auto x) { + constexpr auto s = decltype(x.get_shape()){}; + if constexpr(axis < s.strides.size()) + { + MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1); + MIGRAPHX_ASSERT(s.lens[axis] > 0); + MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0); + if constexpr(s.strides[axis] == 0) + return tensor_step(x, axis); + else + return as_vec(x, axis); + } + else + { + return x; + } + }, + f)(xs...); + } + else + { + f(xs...); + } +} + +inline __device__ __host__ auto auto_vectorize() +{ + return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); }); +} + +template +__device__ __host__ auto vectorize_tensor(T x) +{ + constexpr auto shape = get_shape_c{}; + if constexpr(shape.lens[Axis] == 1) + return x; + else if constexpr(shape.strides[Axis] == 0) + return tensor_step(x, _c); + else + return as_vec(x, _c); +} + +template +__device__ __host__ auto vectorize() +{ + return make_transform([](auto f, auto... xs) { + if constexpr(N < 2) + { + f(xs...); + } + else + { + f(vectorize_tensor(xs)...); + } + }); +} + +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP diff --git a/src/targets/gpu/leaky_relu.cpp b/src/targets/gpu/leaky_relu.cpp index 5e05a1721776cebf87512bad3a8cf804fc3f0bdd..b335cf761332833fd1d672331b0b0c3cfeee1d27 100644 --- a/src/targets/gpu/leaky_relu.cpp +++ b/src/targets/gpu/leaky_relu.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -31,6 +32,11 @@ argument miopen_leaky_relu::compute(context& ctx, return args[1]; } +void miopen_leaky_relu::finalize(context&, const shape&, const std::vector&) +{ + ad = make_leaky_relu(op.alpha); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/logsoftmax.cpp b/src/targets/gpu/logsoftmax.cpp index 56f12fb61f806c4d6b96052da1055fc498057338..2b5041fd1fb96e8aaf04cec64c1537651c17f4c3 100644 --- a/src/targets/gpu/logsoftmax.cpp +++ b/src/targets/gpu/logsoftmax.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace migraphx { @@ -12,13 +13,15 @@ namespace gpu { shape hip_logsoftmax::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(2).standard(); - return op.compute_shape({inputs.at(0)}); + return op.normalize_compute_shape({inputs.at(0)}); } argument hip_logsoftmax::compute(context& ctx, const shape&, const std::vector& args) const { - device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis); + auto n_dim = args.front().get_shape().lens().size(); + auto tuned_axis = tune_axis(n_dim, op.axis, op.name()); + device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); return args.back(); } diff --git a/src/targets/gpu/loop.cpp b/src/targets/gpu/loop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1bc85bdbd1b3b2f082593df93d51ce872c753f02 --- /dev/null +++ b/src/targets/gpu/loop.cpp @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_loop::compute_shape(std::vector inputs, std::vector mods) const +{ + auto input_num = (inputs.size() - 2) / 2; + inputs.erase(inputs.begin() + input_num, inputs.end()); + return op.compute_shape(inputs, std::move(mods)); +} + +struct gpu_loop +{ + int64_t max_iterations = 0; + + template + void copy(context& ctx, const argument& src, T& dst) const + { + argument arg_dst{src.get_shape(), &dst}; + copy_from_gpu(ctx, src, arg_dst); + } + + template + void copy(context& ctx, T src, const argument& dst) const + { + argument arg_src{dst.get_shape(), &src}; + copy_to_gpu(ctx, arg_src, dst); + } + + void append(const std::vector&, const std::vector&, int) const {} + + void set_zero(context& ctx, const std::vector& concatenated_outputs, int iter) const + { + if(iter >= max_iterations) + return; + + auto elem_num = max_iterations - iter; + for(const auto& out : concatenated_outputs) + { + auto s = out.get_shape(); + auto size = s.bytes() / max_iterations; + auto lens = s.lens(); + lens[0] = elem_num; + shape ss{s.type(), lens}; + assert(ss.bytes() + iter * size <= out.get_shape().bytes()); + device::fill(ctx.get_stream().get(), argument(ss, out.data() + iter * size), 0); + } + } + + std::unordered_map get_output_params(const module& m) const + { + auto get_output_index = [](const std::string& name) { + std::string out_prefix = "#output_"; + auto loc = name.find(out_prefix); + if(loc != std::string::npos) + { + int index = std::stoi(name.substr(loc + out_prefix.size())); + return index; + } + + return -1; + }; + + const auto& param_names = m.get_parameter_names(); + std::unordered_map result; + for(const auto& name : param_names) + { + auto index = get_output_index(name); + if(index == -1) + continue; + result[name] = index; + } + + return result; + } +}; + +argument +hip_loop::compute(context& ctx, + const shape&, + const std::vector& args, + const std::vector& mods, + const std::function( + module_ref&, const std::unordered_map&)>& run) const +{ + return run_loop(gpu_loop{op.max_iterations}, ctx, args, mods, run); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 2ed3342a844b7f61ab144e17c1de5a01edaa38e5..2e9d1f781a238908cb043537d5367c51a88ea5bc 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -1,73 +1,52 @@ -#include +#include #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include #include #include -#include -#include -#include -#include -#include -#include +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -75,12 +54,15 @@ namespace gpu { struct miopen_apply { - program* prog = nullptr; + module* mod = nullptr; const lowering* pass = nullptr; std::unordered_map> apply_map{}; instruction_ref last{}; + bool offload_copy = false; + bool int8_x4_format = true; + bool compute_fp32 = false; - context& get_context() + context& get_context() const { assert(pass != nullptr); assert(pass->ctx != nullptr); @@ -94,112 +76,184 @@ struct miopen_apply (void)i; } + const std::unordered_set& get_rocblas_fp32_archs() + { + static std::unordered_set supported_archs{"gfx908", "gfx90a"}; + return supported_archs; + } + void init() { - assert(prog != nullptr); + assert(mod != nullptr); assert(pass != nullptr); - this->last = instruction::get_output_alias(std::prev(prog->end())); - - add_miopen_simple_op("abs", make_abs); - - add_miopen_extend_op("leaky_relu", make_leaky_relu); - add_miopen_extend_op("elu", make_elu); - - add_generic_op("add"); - add_generic_op("sub"); - add_generic_op("exp"); - add_generic_op("erf"); - add_generic_op("log"); - add_generic_op("sin"); - add_generic_op("cos"); - add_generic_op("tan"); - add_generic_op("sinh"); - add_generic_op("cosh"); - add_generic_op("tanh"); - add_generic_op("asin"); - add_generic_op("acos"); - add_generic_op("atan"); - add_generic_op("sqrt"); - add_generic_op("mul"); - add_generic_op("div"); - add_generic_op("max"); - add_generic_op("min"); - add_generic_op("rsqrt"); - add_generic_op("round"); - add_generic_op("pow"); - add_generic_op("sqdiff"); - add_generic_op("relu"); - add_generic_op("sign"); - add_generic_op("sigmoid"); - add_generic_op("ceil"); - add_generic_op("floor"); - - add_extend_op("contiguous"); - add_extend_op("concat"); - add_extend_op("softmax"); - add_extend_op("logsoftmax"); - add_extend_op("argmax"); - add_extend_op("argmin"); - add_extend_op("gather"); - add_extend_op("pad"); - add_extend_op("convert"); - add_extend_op("clip"); - add_extend_op("reduce_sum"); - add_extend_op("reduce_mean"); - add_extend_op("reduce_min"); - add_extend_op("reduce_max"); - add_gemm_op("dot"); - add_gemm_op("quant_dot"); - add_lrn_op(); +#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 + auto& ctx = get_context(); + const auto device_name = trim(split_string(get_device_name(), ':').front()); + if(contains(get_rocblas_fp32_archs(), device_name)) + compute_fp32 = true; + rocblas_gemm_flags flag; + rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); + int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4); +#endif + + offload_copy = (mod->name() == "main") ? pass->offload_copy : false; + + add_generic_op("acos"); + add_generic_op("acosh"); + add_generic_op("add"); + add_generic_op("asin"); + add_generic_op("asinh"); + add_generic_op("atan"); + add_generic_op("atanh"); + add_generic_op("ceil"); + add_generic_op("contiguous"); + add_generic_op("cos"); + add_generic_op("cosh"); + add_generic_op("div"); + add_generic_op("equal"); + add_generic_op("erf"); + add_generic_op("exp"); + add_generic_op("floor"); + add_generic_op("greater"); + add_generic_op("less"); + add_generic_op("log"); + add_generic_op("logical_and"); + add_generic_op("logical_or"); + add_generic_op("logical_xor"); + add_generic_op("max"); + add_generic_op("min"); + add_generic_op("mul"); + add_generic_op("not"); + add_generic_op("pow"); + add_generic_op("prelu"); + add_generic_op("recip"); + add_generic_op("relu"); + add_generic_op("round"); + add_generic_op("rsqrt"); + add_generic_op("sigmoid"); + add_generic_op("sign"); + add_generic_op("sin"); + add_generic_op("sinh"); + add_generic_op("sqdiff"); + add_generic_op("sqrt"); + add_generic_op("sub"); + add_generic_op("tan"); + add_generic_op("tanh"); + add_generic_op("where"); + + add_extend_op("abs"); + add_extend_op("argmax"); + add_extend_op("argmin"); + add_extend_op("clip"); + add_extend_op("concat"); + add_extend_op("convert"); + add_extend_op("elu"); + add_extend_op("gather"); + add_extend_op("leaky_relu"); + add_extend_op("logsoftmax"); + add_extend_op("lrn"); + add_extend_op("multinomial"); + add_extend_op("nonzero"); + add_extend_op("pad"); + add_extend_op("pooling"); + add_extend_op("prefix_scan_sum"); + add_extend_op("reverse"); + add_extend_op("rnn_var_sl_last_output"); + add_extend_op("rnn_var_sl_shift_output"); + add_extend_op("rnn_var_sl_shift_sequence"); + add_extend_op("scatter_none"); + add_extend_op("softmax"); + add_extend_op("topk"); + + add_batch_norm_inference_op(); add_convolution_op(); + add_deconvolution_op(); + add_gemm_op("dot"); + add_gemm_op("quant_dot"); + add_if_op(); + add_loop_op(); + add_neg_op(); + add_nms_op(); add_quant_convolution_op(); - add_pooling_op(); - add_batch_norm_inference_op(); } - void copy_params() + void copy_params() const { - if(not pass->offload_copy) + if(not offload_copy) return; - for(auto ins : iterator_for(*prog)) + + for(auto ins : iterator_for(*mod)) { if(ins->name() != "@param") continue; + + // parameter no outputs, no need to insert copy to gpu + if(ins->outputs().empty()) + continue; + auto pos = std::next(ins); auto a = insert_allocation(pos, ins->get_shape()); - auto c = prog->insert_instruction(pos, hip_copy_to_gpu{}, ins, a); - prog->replace_instruction(ins, c); + auto c = mod->insert_instruction(pos, make_op("hip::copy_to_gpu"), ins, a); + mod->replace_instruction(ins, c); + } + + // return instruction + auto ret = std::prev(mod->end()); + if(ret->name() == "@return") + { + const auto& inputs = ret->inputs(); + + // each input of ret need to be copied from gpu to host, and replace + // output with copy output + for(const auto& in : inputs) + { + auto p_output = mod->insert_instruction(ret, make_op("hip::copy_from_gpu"), in); + instruction::replace_argument(ret, in, p_output); + } + } + // else branch to handle legacy program without the return instruction + else + { + mod->add_instruction(make_op("hip::copy_from_gpu"), ret); } - auto end = std::prev(prog->end()); - prog->add_instruction(hip_copy_from_gpu{}, end); } void apply() { init(); - for(auto it = prog->begin(); it != prog->end(); it++) + for(auto it = mod->begin(); it != mod->end(); it++) { auto s = it->get_shape(); if(apply_map.count(it->name()) > 0) { check_shape(s, apply_map.at(it->name())(it)); } + else if(has_compiler_for(it->name())) + { + check_shape(s, insert_precompile_op(it)); + } } + copy_params(); } - instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "") + instruction_ref insert_precompile_op(instruction_ref ins) const { - if(not pass->offload_copy and ins == last and tag.empty()) - { - return prog->add_parameter("output", s); - } - else - { - auto result = prog->insert_instruction(ins, hip_allocate{s, std::move(tag)}); - return result; - } + auto output = insert_allocation(ins, ins->get_shape()); + std::vector refs = ins->inputs(); + refs.push_back(output); + + return mod->replace_instruction( + ins, + make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}), + refs, + ins->module_inputs()); + } + + instruction_ref insert_allocation(instruction_ref ins, const shape& s) const + { + return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); } void add_convolution_op() @@ -208,46 +262,42 @@ struct miopen_apply auto&& op = any_cast(ins->get_operator()); auto conv = miopen_convolution{op, make_conv(op)}; + auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs())); + + auto workspace = insert_allocation(ins, ws); + auto output = insert_allocation(ins, ins->get_shape()); + + return mod->replace_instruction( + ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output); + }); + } + + void add_deconvolution_op() + { + apply_map.emplace("deconvolution", [=](instruction_ref ins) { + auto&& op = any_cast(ins->get_operator()); + + auto conv = miopen_deconvolution{op, make_deconv(op)}; auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); - auto workspace = insert_allocation(ins, ws, "workspace"); + auto workspace = insert_allocation(ins, ws); auto output = insert_allocation(ins, ins->get_shape()); - return prog->replace_instruction( + return mod->replace_instruction( ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output); }); } - template - void add_gemm_op(std::string name) + template + void add_gemm_op(const std::string& name) { apply_map.emplace(name, [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto beta = op.beta; std::vector refs = ins->inputs(); - if(refs.size() == 2) - { - auto output = insert_allocation(ins, ins->get_shape()); - beta = 0; - refs.push_back(output); - } - else - { - auto c_alias = instruction::get_output_alias(refs.back()); - if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty()) - { - auto output = insert_allocation(ins, ins->get_shape()); - auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output); - refs.back() = copy_out; - refs.push_back(copy_out); - } - else - { - refs.push_back(refs.back()); - } - } - - return prog->replace_instruction(ins, rocblas_gemm{Op{op.alpha, beta}}, refs); + assert(refs.size() == 2); + auto output = insert_allocation(ins, ins->get_shape()); + refs.push_back(output); + return mod->replace_instruction( + ins, rocblas_gemm{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs); }); } @@ -255,113 +305,183 @@ struct miopen_apply { apply_map.emplace("quant_convolution", [=](instruction_ref ins) { auto&& op = any_cast(ins->get_operator()); - auto conv = miopen_quant_convolution{op, make_conv(op)}; - auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); + shape ws; + miopen_quant_convolution conv; + auto compile_quant_conv_with_format = [&](bool format) { + conv = miopen_quant_convolution{op, format, make_conv(op)}; + ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); + }; + + try + { + compile_quant_conv_with_format(int8_x4_format); + } + catch(migraphx::exception&) + { + // In case no solver supports the default format, retry using the other format. + compile_quant_conv_with_format(!int8_x4_format); + } auto args = ins->inputs(); - auto workspace = insert_allocation(ins, ws, "workspace"); + auto workspace = insert_allocation(ins, ws); auto output = insert_allocation(ins, ins->get_shape()); - return prog->replace_instruction(ins, conv, args[0], args[1], workspace, output); + return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output); }); } - void add_pooling_op() - { - apply_map.emplace("pooling", [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto pd = make_pooling(op); - auto output = insert_allocation(ins, ins->get_shape()); + // add_generic_op just constructs the operator with no fields whereas add_extend_op copies over + // the fields Since it doesn't have fields its default constructed - return prog->replace_instruction( - ins, miopen_pooling{op, std::move(pd)}, ins->inputs().at(0), output); - }); - } + void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); } - void add_lrn_op() + void add_generic_op(const std::string& op_name, const std::string& gpu_name) { - apply_map.emplace("lrn", [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto ldesc = make_lrn(op); - auto output = insert_allocation(ins, ins->get_shape()); - return prog->replace_instruction( - ins, miopen_lrn{std::move(ldesc)}, ins->inputs().at(0), output); + apply_map.emplace(op_name, [=](instruction_ref ins) { + auto output = insert_allocation(ins, ins->get_shape()); + std::vector refs = ins->inputs(); + refs.push_back(output); + + return mod->replace_instruction(ins, make_op(gpu_name), refs); }); } - template - void add_generic_op(std::string name) + void add_extend_op(const std::string& name) { add_extend_op(name, "gpu::" + name); } + + void add_extend_op(const std::string& op_name, const std::string& gpu_name) { - apply_map.emplace(name, [=](instruction_ref ins) { + apply_map.emplace(op_name, [=](instruction_ref ins) { + auto&& op = ins->get_operator(); auto output = insert_allocation(ins, ins->get_shape()); std::vector refs = ins->inputs(); refs.push_back(output); - return prog->replace_instruction(ins, T{}, refs); + return mod->replace_instruction(ins, make_op(gpu_name, op.to_value()), refs); }); } - template - void add_extend_op(std::string name) + void add_batch_norm_inference_op() { - apply_map.emplace(name, [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto output = insert_allocation(ins, ins->get_shape()); - std::vector refs = ins->inputs(); - refs.push_back(output); + apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) { + auto&& op = any_cast(ins->get_operator()); + auto output = insert_allocation(ins, ins->get_shape()); + shape old_shape = ins->inputs().at(1)->get_shape(); + auto input = ins->inputs()[0]; + auto input_lens = input->get_shape().lens(); + std::vector rsp_lens(input_lens.size(), 1); + // for per_activation case, also need to reshape input + if(op.bn_mode == op::batch_norm_inference::per_activation) + { + std::copy(input_lens.begin() + 1, input_lens.end(), rsp_lens.begin() + 1); + } + else + { + rsp_lens[1] = static_cast(old_shape.elements()); + } - return prog->replace_instruction(ins, T{op}, refs); + auto reshape_op = op::reshape{rsp_lens}; + std::vector reshapes; + std::transform(ins->inputs().begin() + 1, + ins->inputs().end(), + std::back_inserter(reshapes), + [&](auto i) { return mod->insert_instruction(ins, reshape_op, i); }); + + return mod->replace_instruction(ins, + miopen_batch_norm_inference{op}, + input, + reshapes[0], + reshapes[1], + reshapes[2], + reshapes[3], + output); }); } - template - void add_miopen_extend_op(std::string name, F f) + // use 0 - input to represent neg + void add_neg_op() { - apply_map.emplace(name, [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto ad = f(op.alpha); + apply_map.emplace("neg", [=](instruction_ref ins) { + auto s = ins->get_shape(); + std::vector zeros(s.elements(), 0.0f); + auto l0 = mod->add_literal(literal(s, zeros)); + auto output = insert_allocation(ins, s); + return mod->replace_instruction( + ins, make_op("gpu::sub"), l0, ins->inputs().front(), output); + }); + } - auto output = insert_allocation(ins, ins->get_shape()); - return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output); + // add input and output argument for the if operator + void add_if_op() + { + apply_map.emplace("if", [=](instruction_ref ins) { + std::vector inputs = ins->inputs(); + auto cpu_cond = + mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs.front()); + auto sync_cond = mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_cond); + inputs.front() = sync_cond; + + return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs()); }); } - template - void add_miopen_simple_op(std::string name, F f) + // replace the loop operator with gpu_loop operator + void add_loop_op() { - apply_map.emplace(name, [=](instruction_ref ins) { - auto ad = f(); - auto output = insert_allocation(ins, ins->get_shape()); - return prog->replace_instruction(ins, T{std::move(ad)}, ins->inputs().at(0), output); + apply_map.emplace("loop", [=](instruction_ref ins) { + std::vector inputs = ins->inputs(); + // copy max_iter from gpu to cpu + auto cpu_max_iter = + mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs.at(0)); + auto cpu_cond = + mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), inputs.at(1)); + auto synced_max_iter = + mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_max_iter, cpu_cond); + inputs.at(0) = synced_max_iter; + inputs.at(1) = cpu_cond; + auto copy_inputs = inputs; + std::transform(copy_inputs.begin(), + copy_inputs.end(), + std::back_inserter(inputs), + [&](auto in) { return insert_allocation(ins, in->get_shape()); }); + + auto mod_args = ins->module_inputs(); + auto output = insert_allocation(ins, ins->get_shape()); + + const auto* sub_mod = mod_args.front(); + auto cond_out = insert_allocation(ins, sub_mod->get_output_shapes().front()); + + // add cond and mod outputs to the argument list + inputs.push_back(cond_out); + inputs.push_back(output); + + return mod->replace_instruction( + ins, make_op("gpu::loop", ins->get_operator().to_value()), inputs, mod_args); }); } - void add_batch_norm_inference_op() + void add_nms_op() { - apply_map.emplace("batch_norm_inference", [=](instruction_ref ins) { - auto&& op = any_cast(ins->get_operator()); - auto output = insert_allocation(ins, ins->get_shape()); - shape old_shape = ins->inputs().at(1)->get_shape(); - std::vector new_shape{1, static_cast(old_shape.elements()), 1, 1}; - auto reshape_op = op::reshape{new_shape}; - std::vector reshapes; - std::transform(ins->inputs().begin() + 1, - ins->inputs().end(), - std::back_inserter(reshapes), - [&](auto i) { return prog->insert_instruction(ins, reshape_op, i); }); - return prog->replace_instruction(ins, - miopen_batch_norm_inference{op}, - ins->inputs().at(0), - reshapes[0], - reshapes[1], - reshapes[2], - reshapes[3], - output); + apply_map.emplace("nonmaxsuppression", [=](instruction_ref ins) { + auto s = ins->get_shape(); + auto output = insert_allocation(ins, s); + std::vector cpu_inputs; + auto inputs = ins->inputs(); + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) { + return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in); + }); + cpu_inputs.front() = + mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs); + auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs); + auto gpu_out = + mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output); + return mod->replace_instruction(ins, gpu_out); }); } }; -void lowering::apply(program& p) const { miopen_apply{&p, this}.apply(); } +void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/lrn.cpp b/src/targets/gpu/lrn.cpp index 96427b91362adf431aa1d8795e05440e1c7040fe..c5505a5cef0b4590b968965cf29ca61773e8e43c 100644 --- a/src/targets/gpu/lrn.cpp +++ b/src/targets/gpu/lrn.cpp @@ -33,6 +33,11 @@ argument miopen_lrn::compute(context& ctx, return args[1]; } +void miopen_lrn::finalize(context&, const shape&, const std::vector&) +{ + ldesc = make_lrn(op); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/mlir_conv.cpp b/src/targets/gpu/mlir_conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6183976781a2344bf593dbe4fd02669c4ab4601e --- /dev/null +++ b/src/targets/gpu/mlir_conv.cpp @@ -0,0 +1,292 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifdef MIGRAPHX_MLIR_MIOPEN_SUPPORT +#include +#endif // MIGRAPHX_MLIR_MIOPEN_SUPPORT + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct mlir_apply +{ + module* mod = nullptr; + const mlir_conv* pass = nullptr; + + const char* mlir_kernel_name = "migraphx_conv2d"; + + std::unordered_map literal_map{}; + + struct execution_spec + { + migraphx::value::binary binary; + size_t global_size; + size_t local_size; + execution_spec(migraphx::value::binary&& binary_m, size_t global_s, size_t local_s) + : binary(std::move(binary_m)), global_size(global_s), local_size(local_s) + { + } + }; + + std::unordered_map> binary_map{}; + + context& get_context() const + { + assert(pass != nullptr); + assert(pass->ctx != nullptr); + return *pass->ctx; + } + + void init() const + { + assert(mod != nullptr); + assert(pass != nullptr); + } + + std::shared_ptr make_mlir_binary(instruction_ref op_r) + { + std::shared_ptr result; + +#ifdef MIGRAPHX_MLIR_MIOPEN_SUPPORT + auto conv = any_cast(op_r->get_operator()); + auto inp_t = op_r->inputs().at(0)->get_shape(); + auto flt_t = op_r->inputs().at(1)->get_shape(); + auto out_t = op_r->get_shape(); + + auto get_type_str = [](const shape& s) -> const char* { + switch(s.type()) + { + case shape::float_type: return "f32"; + case shape::half_type: return "f16"; + case shape::bool_type: + case shape::double_type: + case shape::uint8_type: + case shape::int8_type: + case shape::uint16_type: + case shape::int16_type: + case shape::int32_type: + case shape::int64_type: + case shape::uint32_type: + case shape::uint64_type: + case shape::tuple_type: break; + } + return nullptr; + }; + + const auto* inp_t_s = get_type_str(inp_t); + const auto* flt_t_s = get_type_str(flt_t); + const auto* out_t_s = get_type_str(out_t); + + if(out_t_s == nullptr || inp_t_s == nullptr || flt_t_s == nullptr) + return result; + + std::string mlir_options = "--kernel_name " + std::string(mlir_kernel_name); + + // platform spec + auto& device = get_context().get_current_device(); + char dev_name[64]; + sprintf(dev_name, "gfx%lu%02lu", device.get_device_major(), device.get_device_minor()); + mlir_options += " --arch " + std::string(dev_name) + " --num_cu " + + std::to_string(device.get_cu_count()); // ??? + + // Conv spec + mlir_options += + " --operation " + "conv2d" + " --batchsize " + + std::to_string(conv.group) + " --groupsize " + std::to_string(1) + " --padding_h " + + std::to_string(conv.padding[0]) + " --padding_w " + std::to_string(conv.padding[1]) + + " --conv_stride_h " + std::to_string(conv.stride[0]) + " --conv_stride_w " + + std::to_string(conv.stride[1]) + " --dilation_h " + std::to_string(conv.dilation[0]) + + " --dilation_w " + std::to_string(conv.dilation[1]); + + // Input spec + mlir_options += " --in_layout " + "NCHWG" + " --in_type " + + std::string(inp_t_s) + " --in_channels " + std::to_string(inp_t.lens()[1]) + + " --in_h " + std::to_string(inp_t.lens()[2]) + " --in_w " + + std::to_string(inp_t.lens()[3]); + + // Filter spec + mlir_options += " --fil_layout " + "NCHWG" + " --fil_type " + + std::string(flt_t_s) + " --fil_h " + std::to_string(flt_t.lens()[2]) + + " --fil_w " + std::to_string(flt_t.lens()[3]); + + // Output spec + mlir_options += " --out_layout " + "NCHWG" + " --out_type " + + std::string(out_t_s) + " --out_channels " + + std::to_string(out_t.lens()[1]) + " --out_h " + + std::to_string(out_t.lens()[2]) + " --out_w " + + std::to_string(out_t.lens()[3]); + + auto bin_i = binary_map.find(mlir_options); + if(bin_i == binary_map.end()) + { + size_t bin_size = 0; + + using mlir_handle = MIGRAPHX_MANAGE_PTR(MiirHandle, miirDestroyHandle); + auto handle = mlir_handle(miirCreateHandle(mlir_options.c_str())); + + if(miirLowerBin(handle.get()) == MIIR_SUCCESS && + miirBufferGet(handle.get(), nullptr, &bin_size) == MIIR_SUCCESS) + { + migraphx::value::binary bin(bin_size); + if(miirBufferGet(handle.get(), reinterpret_cast(bin.data()), &bin_size) == + MIIR_SUCCESS) + { + size_t global_size; + size_t block_size; + if(miirGetExecutionDims(handle.get(), &global_size, &block_size) == + MIIR_SUCCESS) + { + result = std::make_shared( + std::move(bin), global_size, block_size); + } + } + } + + binary_map[mlir_options] = result; + } + else + { + result = bin_i->second; + } +#else // MIGRAPHX_MLIR_MIOPEN_SUPPORT + (void)op_r; +#endif // MIGRAPHX_MLIR_MIOPEN_SUPPORT + return result; + } + + instruction_ref get_literal(uint64_t value) + { + auto fi = literal_map.find(value); + if(fi != literal_map.end()) + return fi->second; + auto lit = mod->add_literal(value); + literal_map.emplace(value, lit); + return lit; + } + + operation make_code_object_op(instruction_ref op_r, const std::shared_ptr& spec) + { + // each pointer is expanded out to a MemRefDescriptor + auto inp_t = op_r->inputs().at(0)->get_shape(); + auto flt_t = op_r->inputs().at(1)->get_shape(); + auto out_t = op_r->get_shape(); + + auto i64 = shape(shape::uint64_type); + + std::vector expected_inputs = { + flt_t, flt_t, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, inp_t, + inp_t, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, out_t, out_t, + i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, out_t}; + + return migraphx::make_op("gpu::code_object", + { + {"code_object", spec->binary}, + {"symbol_name", mlir_kernel_name}, + {"global", spec->global_size}, + {"local", spec->local_size}, + {"expected_inputs", migraphx::to_value(expected_inputs)}, + {"output", migraphx::to_value(out_t)}, + }); + } + + void add_memref_descriptor(std::vector& refs, instruction_ref inst) + { + const size_t offset = 0; + auto inst_t = inst->get_shape(); + refs.push_back(inst); + refs.push_back(inst); + refs.push_back(get_literal(offset)); // offset + + // dim sizes + std::transform(inst_t.lens().begin(), + inst_t.lens().end(), + std::back_inserter(refs), + [&](const auto& lval) { return get_literal(lval); }); + refs.push_back(get_literal(1)); // G + + // dim strides + std::transform(inst_t.strides().begin(), + inst_t.strides().end(), + std::back_inserter(refs), + [&](const auto& lval) { return get_literal(lval); }); + refs.push_back(get_literal(1)); // G + } + + instruction_ref insert_allocation(instruction_ref ins, const shape& s) const + { + return mod->insert_instruction(ins, hip_allocate{s}); + } + + void replace_conv_op(instruction_ref ins) + { + auto conv_bin = make_mlir_binary(ins); + if(conv_bin) + { + auto conv = make_code_object_op(ins, conv_bin); + + auto inp = ins->inputs().at(0); + auto flt = ins->inputs().at(1); + auto out = insert_allocation(ins, ins->get_shape()); + + std::vector refs; + refs.reserve(3 * 13 + 1); + add_memref_descriptor(refs, flt); + add_memref_descriptor(refs, inp); + add_memref_descriptor(refs, out); + refs.push_back(out); + + mod->replace_instruction(ins, conv, refs); + } + } + + void apply() + { + init(); + for(auto it : iterator_for(*mod)) + { + if(it->name() == "convolution") + { + replace_conv_op(it); + } + } + } +}; + +void mlir_conv::apply(module& m) const { mlir_apply{&m, this}.apply(); } + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/multinomial.cpp b/src/targets/gpu/multinomial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2734201faba2247da015939332ebeffa8fbef129 --- /dev/null +++ b/src/targets/gpu/multinomial.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_multinomial::compute_shape(std::vector inputs) const +{ + check_shapes{inputs, *this}.has(3).only_dims(2).standard(); + inputs.pop_back(); + return op.compute_shape(inputs); +} + +argument +hip_multinomial::compute(context& ctx, const shape&, const std::vector& args) const +{ + device::multinomial(ctx.get_stream().get(), args.back(), args.front(), args[1]); + return args.back(); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/nonzero.cpp b/src/targets/gpu/nonzero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4ca0f51e99545c5ee45145e1cf9e367e78f57d1 --- /dev/null +++ b/src/targets/gpu/nonzero.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_nonzero::compute_shape(std::vector inputs) const +{ + return op.compute_shape({inputs.front()}); +} + +argument hip_nonzero::compute(context& ctx, const shape&, const std::vector& args) const +{ + return device::nonzero(ctx.get_stream().get(), args.back(), args.front()); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/pack_args.cpp b/src/targets/gpu/pack_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..68e6860221e7efbffe4f4ee526de59fc35d9fd7c --- /dev/null +++ b/src/targets/gpu/pack_args.cpp @@ -0,0 +1,25 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +std::vector pack_args(const std::vector& args) +{ + std::vector kernargs; + for(auto&& arg : args) + { + std::size_t n = arg.size; + const auto* p = static_cast(arg.data); + // Insert padding + std::size_t padding = (arg.align - (kernargs.size() % arg.align)) % arg.align; + kernargs.insert(kernargs.end(), padding, 0); + kernargs.insert(kernargs.end(), p, p + n); + } + return kernargs; +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/pack_int8_args.cpp b/src/targets/gpu/pack_int8_args.cpp index 27782dc681f210703b6bfa1dba1503bc2cbecd1e..610963ea77c455b7fead8ba0b0b72042e3d24df4 100644 --- a/src/targets/gpu/pack_int8_args.cpp +++ b/src/targets/gpu/pack_int8_args.cpp @@ -1,54 +1,182 @@ +#include #include #include #include #include #include +#include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { -void pack_int8_args::apply(program& p) const +static instruction_ref pad_ins(module& m, instruction_ref ins, int offset) { - for(auto ins : iterator_for(p)) + auto s = ins->get_shape(); + auto lens = s.lens(); + auto k = lens[lens.size() + offset]; + auto pad_k = (k + 3) / 4 * 4; + auto pad_lens = lens; + pad_lens[lens.size() + offset] = pad_k; + auto ret_ins = ins; + if(pad_k != k) + { + std::vector pad_dims(lens.size() * 2, 0); + pad_dims[lens.size() + offset] = pad_k - k; + shape ps{s.type(), pad_lens}; + auto ins_out = + m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(ps)}})); + auto pad = make_op("pad", {{"pads", pad_dims}}); + ret_ins = + m.insert_instruction(std::next(ins), make_op("gpu::pad", pad.to_value()), ins, ins_out); + } + + return ret_ins; +} + +static std::vector pad_inputs(module& m, instruction_ref ins) +{ + std::vector ret_inputs; + auto inputs = ins->inputs(); + auto in0 = inputs.at(0); + auto sa = in0->get_shape(); + bool transa = sa.transposed(); + if(transa) + { + auto perm = find_permutation(sa); + auto val = in0->get_operator().to_value(); + if(val.contains("dims")) + { + int offset = static_cast(perm.back()) - static_cast(perm.size()); + auto t_in = in0->inputs().front(); + auto p_in = pad_ins(m, t_in, offset); + auto dims = val.at("dims").to_vector(); + auto r_in = + m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in); + ret_inputs.push_back(r_in); + } + else + { + shape cs{in0->get_shape().type(), in0->get_shape().lens()}; + auto con_out = + m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}})); + auto cin0 = m.insert_instruction(ins, make_op("gpu::contiguous"), in0, con_out); + ret_inputs.push_back(pad_ins(m, cin0, -1)); + } + } + else + { + ret_inputs.push_back(pad_ins(m, in0, -1)); + } + + auto in1 = inputs.at(1); + auto sb = in1->get_shape(); + bool transb = sb.transposed(); + if(transb) + { + auto perm = find_permutation(sb); + auto val = in1->get_operator().to_value(); + if(val.contains("dims")) + { + int offset = static_cast(perm[perm.size() - 2]) - static_cast(perm.size()); + auto t_in = in1->inputs().front(); + auto p_in = pad_ins(m, t_in, offset); + auto dims = val.at("dims").to_vector(); + auto r_in = + m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in); + ret_inputs.push_back(r_in); + } + else + { + shape cs{in1->get_shape().type(), in1->get_shape().lens()}; + auto con_out = + m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}})); + auto cin1 = m.insert_instruction(ins, make_op("gpu::contiguous"), in1, con_out); + ret_inputs.push_back(pad_ins(m, cin1, -2)); + } + } + else + { + ret_inputs.push_back(pad_ins(m, in1, -2)); + } + std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(ret_inputs)); + + return ret_inputs; +} + +void pack_int8_args::apply(module& m) const +{ + for(auto ins : iterator_for(m)) { if(ins->name() == "gpu::quant_gemm") { + auto val = ins->get_operator().to_value(); + assert(val.contains("int8_x4_format")); + if(not val.at("int8_x4_format").to()) + { + continue; + } auto inputs = ins->inputs(); + auto lens = inputs.at(0)->get_shape().lens(); + // gemm need the k to be multiple of 4, so need packing that dimension + auto old_inputs = inputs; + if((lens.back() % 4) != 0) + { + inputs = pad_inputs(m, ins); + } + bool transa = inputs[0]->get_shape().transposed(); bool transb = inputs[1]->get_shape().transposed(); - if(!transb) { - auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()}); - auto output_b = - p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b}); - instruction::replace_argument(ins, inputs[1], output_b); + auto packed_b = m.insert_instruction( + ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}})); + auto output_b = m.insert_instruction( + ins, make_op("gpu::int8_gemm_pack_a"), {inputs[1], packed_b}); + inputs[1] = output_b; } if(transa) { - auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()}); - auto output_a = - p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a}); - instruction::replace_argument(ins, inputs[0], output_a); + auto packed_a = m.insert_instruction( + ins, make_op("hip::allocate", {{"shape", to_value(inputs[0]->get_shape())}})); + auto output_a = m.insert_instruction( + ins, make_op("gpu::int8_gemm_pack_b"), {inputs[0], packed_a}); + inputs[0] = output_a; + } + + if(inputs != old_inputs) + { + m.replace_instruction(ins, ins->get_operator(), inputs); } } else if(ins->name() == "gpu::quant_convolution") { - auto inputs = ins->inputs(); - auto packed_x = - p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())}); + auto val = ins->get_operator().to_value(); + if(not val.at("int8_x4_format").to()) + { + continue; + } + + auto inputs = ins->inputs(); + auto packed_x = m.insert_instruction( + ins, + make_op("hip::allocate", + {{"shape", to_value(pack_int8_shape(inputs[0]->get_shape()))}})); auto output_x = - p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x}); + m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[0], packed_x}); instruction::replace_argument(ins, inputs[0], output_x); - auto packed_w = - p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())}); + auto packed_w = m.insert_instruction( + ins, + make_op("hip::allocate", + {{"shape", to_value(pack_int8_shape(inputs[1]->get_shape()))}})); auto output_w = - p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w}); + m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[1], packed_w}); instruction::replace_argument(ins, inputs[1], output_w); } } diff --git a/src/targets/gpu/pad.cpp b/src/targets/gpu/pad.cpp index 8448f936458b6cc0eb32f12742fa0217b22e0f3e..bb4892b09e1f3de7d4233532fd6dc17d4ba0021c 100644 --- a/src/targets/gpu/pad.cpp +++ b/src/targets/gpu/pad.cpp @@ -9,6 +9,7 @@ namespace gpu { shape hip_pad::compute_shape(std::vector inputs) const { inputs.pop_back(); + check_shapes{inputs, *this}.has(1).standard(); return op.compute_shape(inputs); } diff --git a/src/targets/gpu/pooling.cpp b/src/targets/gpu/pooling.cpp index bf2e1d7fefbad8500ded6c4c6d62b4b40e876838..b643b01bd5de3ea9d2e17f1613231906c68825fe 100644 --- a/src/targets/gpu/pooling.cpp +++ b/src/targets/gpu/pooling.cpp @@ -8,14 +8,35 @@ namespace gpu { shape miopen_pooling::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(2).standard(); - return op.compute_shape({inputs.at(0)}); + std::vector pooling_input = {inputs.at(0)}; + check_shapes{pooling_input, *this}.max_ndims(5); + return op.normalize_compute_shape(pooling_input); } + +inline void reshape_if_1d(shape& input) +{ + auto dims = input.lens(); + + if(dims.size() == 3) + { + std::vector new_dims = dims; + new_dims.insert(new_dims.begin() + 2, 1); + input = shape{input.type(), new_dims}; + } +} + argument miopen_pooling::compute(context& ctx, const shape& output_shape, const std::vector& args) const { - auto x_desc = make_tensor(args[0].get_shape()); - auto y_desc = make_tensor(output_shape); + shape x_shape = args[0].get_shape(); + shape y_shape = output_shape; + + reshape_if_1d(x_shape); + reshape_if_1d(y_shape); + + auto x_desc = make_tensor(x_shape); + auto y_desc = make_tensor(y_shape); float alpha = 1; float beta = 0; @@ -35,6 +56,12 @@ argument miopen_pooling::compute(context& ctx, return args[1]; } +void miopen_pooling::finalize(context&, const shape&, const std::vector&) +{ + if(pd == nullptr) + pd = make_pooling(op); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/preallocate_param.cpp b/src/targets/gpu/preallocate_param.cpp deleted file mode 100644 index 7927d126213467e7aa4a69bf8e0cd1312625d527..0000000000000000000000000000000000000000 --- a/src/targets/gpu/preallocate_param.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { - -void preallocate_param::apply(program& p) const -{ - for(auto ins : iterator_for(p)) - { - if(ins->name() != "@param") - continue; - std::string id = any_cast(ins->get_operator()).parameter; - if(id != param) - continue; - argument a = allocate_gpu(ins->get_shape()); - ctx->get_current_device().preallocations[id] = a; - auto r = p.insert_instruction(ins, hip_load_memory{a.get_shape(), id}); - p.replace_instruction(ins, r); - } -} - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3962787a3b149425b8acddbc5a272bc8a3504a3 --- /dev/null +++ b/src/targets/gpu/prefuse_ops.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +namespace { +struct find_layernorm +{ + auto matcher() const { return match::layernorm(); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["x"]; + + if(not x_ins->get_shape().standard()) + x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins); + + auto relements = x_ins->get_shape().lens().back(); + + if(relements > 1024 or (relements % 4 != 0 and relements > 256)) + return; + + auto a = m.insert_instruction( + ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}})); + m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a); + } +}; + +struct find_triaddlayernorm +{ + auto matcher() const + { + auto add1 = + match::name("add")(match::none_of(match::is_constant()), + match::args(match::any().bind("z1"), match::any().bind("z2"))); + auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3"))); + return match::layernorm()(match::var("x")(add2)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto x_ins = r.instructions["z1"]; + auto y_ins = r.instructions["z2"]; + auto z_ins = r.instructions["z3"]; + + for(auto* pins : {&x_ins, &y_ins, &z_ins}) + { + if(not(*pins)->get_shape().standard()) + *pins = m.insert_instruction(ins, make_op("contiguous"), *pins); + } + + auto relements = x_ins->get_shape().lens().back(); + + if(relements > 1024 or (relements % 4 != 0 and relements > 256)) + return; + + auto a = m.insert_instruction( + ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}})); + m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a); + } +}; +} // namespace + +void prefuse_ops::apply(module& m) const +{ + match::find_matches(m, find_triaddlayernorm{}, find_layernorm{}); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/quant_convolution.cpp b/src/targets/gpu/quant_convolution.cpp old mode 100644 new mode 100755 index 38f7ce7063d7b965edb775ee7495a5bbbe9b5ee9..82787e9aca28c50e77e9f21ba35dd2683b580645 --- a/src/targets/gpu/quant_convolution.cpp +++ b/src/targets/gpu/quant_convolution.cpp @@ -10,14 +10,14 @@ namespace gpu { shape miopen_quant_convolution::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(4).standard(); - return op.compute_shape({inputs.at(0), inputs.at(1)}); + return op.normalize_compute_shape({inputs.at(0), inputs.at(1)}); } argument miopen_quant_convolution::compute(context& ctx, const shape& output_shape, const std::vector& args) const { - auto x_desc = make_tensor(args[0].get_shape(), true); - auto w_desc = make_tensor(args[1].get_shape(), true); + auto x_desc = make_tensor(args[0].get_shape(), int8_x4_format); + auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format); auto y_desc = make_tensor(output_shape); float alpha = 1; @@ -49,8 +49,8 @@ shape miopen_quant_convolution::compile(context& ctx, std::vector inputs) { shape workspace_shape{}; - auto x_desc = make_tensor(inputs[0], true); - auto w_desc = make_tensor(inputs[1], true); + auto x_desc = make_tensor(inputs[0], int8_x4_format); + auto w_desc = make_tensor(inputs[1], int8_x4_format); auto y_desc = make_tensor(output_shape); std::size_t workspace_size = 0; @@ -62,8 +62,15 @@ shape miopen_quant_convolution::compile(context& ctx, &workspace_size); workspace_shape = shape{shape::int8_type, {workspace_size}}; - auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0]))); - auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1]))); + auto x_shape = inputs[0]; + auto w_shape = inputs[1]; + if(int8_x4_format) + { + x_shape = pack_int8_shape(x_shape); + w_shape = pack_int8_shape(w_shape); + } + auto arg_vec4_x = to_gpu(generate_argument(x_shape)); + auto arg_vec4_w = to_gpu(generate_argument(w_shape)); auto y = allocate_gpu(output_shape); auto workspace = allocate_gpu(workspace_shape); diff --git a/src/targets/gpu/reverse.cpp b/src/targets/gpu/reverse.cpp new file mode 100755 index 0000000000000000000000000000000000000000..ad3dcb30454db6b856819d885e7c57fffa71eb2b --- /dev/null +++ b/src/targets/gpu/reverse.cpp @@ -0,0 +1,22 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_reverse::compute_shape(std::vector inputs) const +{ + inputs.pop_back(); + return op.normalize_compute_shape(inputs); +} + +argument hip_reverse::compute(context& ctx, const shape&, const std::vector& args) const +{ + return device::reverse(ctx.get_stream().get(), args.back(), args[0], op.axes); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/rnn_variable_seq_lens.cpp b/src/targets/gpu/rnn_variable_seq_lens.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76deabff8bd9e8aca35966c1b6f79f65a5f75af4 --- /dev/null +++ b/src/targets/gpu/rnn_variable_seq_lens.cpp @@ -0,0 +1,61 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_rnn_var_sl_shift_output::compute_shape(std::vector inputs) const +{ + inputs.pop_back(); + return op.compute_shape(inputs); +} + +argument hip_rnn_var_sl_shift_output::compute(context& ctx, + const shape&, + const std::vector& args) const +{ + device::rnn_var_sl_shift_output(ctx.get_stream().get(), + args.back(), + args.at(0), + args.at(1), + (op.direction == op::rnn_direction::reverse)); + return args.back(); +} + +shape hip_rnn_var_sl_shift_sequence::compute_shape(std::vector inputs) const +{ + inputs.pop_back(); + return op.compute_shape(inputs); +} + +argument hip_rnn_var_sl_shift_sequence::compute(context& ctx, + const shape&, + const std::vector& args) const +{ + device::rnn_var_sl_shift_sequence(ctx.get_stream().get(), args.back(), args.at(0), args.at(1)); + return args.back(); +} + +shape hip_rnn_var_sl_last_output::compute_shape(std::vector inputs) const +{ + inputs.pop_back(); + return op.compute_shape(inputs); +} + +argument hip_rnn_var_sl_last_output::compute(context& ctx, + const shape&, + const std::vector& args) const +{ + device::rnn_var_sl_last_output(ctx.get_stream().get(), + args.back(), + args.at(0), + args.at(1), + (op.direction == op::rnn_direction::reverse)); + return args.back(); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/scatter.cpp b/src/targets/gpu/scatter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc85eea36a555a7815d7d524b1477656233fe389 --- /dev/null +++ b/src/targets/gpu/scatter.cpp @@ -0,0 +1,22 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_scatter::compute_shape(std::vector inputs) const +{ + inputs.pop_back(); + return op.normalize_compute_shape(inputs); +} + +argument hip_scatter::compute(context& ctx, const shape&, const std::vector& args) const +{ + return device::scatter(ctx.get_stream().get(), args.back(), args[0], args[1], args[2], op.axis); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/schedule_model.cpp b/src/targets/gpu/schedule_model.cpp index 7011d6342554102ec6de25e83059c6d916a225b3..7fe0a58c85a89eab0ecb0e4f535eaeac7a44abde 100644 --- a/src/targets/gpu/schedule_model.cpp +++ b/src/targets/gpu/schedule_model.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -25,7 +26,7 @@ struct record_event return {}; } - void finalize(context& ctx, const shape&, const std::vector&) + void finalize(context& ctx, const shape&, const std::vector&) const { ctx.create_events(event); } @@ -65,37 +66,45 @@ struct set_stream ctx.set_stream(stream); return {}; } - void finalize(context& ctx, const shape&, const std::vector&) { ctx.set_stream(stream); } + void finalize(context& ctx, const shape&, const std::vector&) const + { + ctx.set_stream(stream); + } }; +MIGRAPHX_REGISTER_OP(record_event) +MIGRAPHX_REGISTER_OP(wait_event) +MIGRAPHX_REGISTER_OP(set_stream) + std::size_t schedule_model::concurrency() const { return streams; } -void schedule_model::sched(program& p, instruction_ref ins, std::size_t n) const +void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const { auto last_stream = std::find_if(std::make_reverse_iterator(ins), - std::make_reverse_iterator(p.begin()), + std::make_reverse_iterator(m.begin()), [&](auto&& i) { return i.name() == "gpu::set_stream"; }); - if(last_stream != std::make_reverse_iterator(p.begin())) + if(last_stream != std::make_reverse_iterator(m.begin())) { auto&& op = any_cast(last_stream->get_operator()); // If the same stream was set earlier then skip if(op.stream == n) return; } - p.insert_instruction(ins, set_stream{n}); + m.insert_instruction(ins, set_stream{n}); } -void schedule_model::wait(program& p, instruction_ref ins, std::size_t wait_id) const +void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const { - p.insert_instruction(ins, wait_event{wait_id}); + m.insert_instruction(ins, wait_event{wait_id}); } -void schedule_model::record(program& p, instruction_ref ins, std::size_t wait_id) const +void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const { - p.insert_instruction(std::next(ins), record_event{wait_id}); + m.insert_instruction(std::next(ins), record_event{wait_id}); } static std::unordered_map create_weight_map() { return {{"hip::load_literal", 0}, + {"hip::hip_allocate_memory", 0}, {"hip::hip_load_memory", 0}, {"hip::allocate", 0}, {"gpu::convolution", 8}, @@ -106,7 +115,7 @@ static std::unordered_map create_weight_map() static const std::unordered_map& weight_map() { - static std::unordered_map m = create_weight_map(); + static const std::unordered_map m = create_weight_map(); return m; } diff --git a/src/targets/gpu/softmax.cpp b/src/targets/gpu/softmax.cpp index b6be67121301f64834fe3b854c446d975d830a1d..ac1f28fd003e26b2a85295b9285c40dc91af4254 100644 --- a/src/targets/gpu/softmax.cpp +++ b/src/targets/gpu/softmax.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -9,12 +10,14 @@ namespace gpu { shape hip_softmax::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(2).standard(); - return op.compute_shape({inputs.at(0)}); + return op.normalize_compute_shape({inputs.at(0)}); } argument hip_softmax::compute(context& ctx, const shape&, const std::vector& args) const { - device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis); + auto n_dim = args.front().get_shape().lens().size(); + auto tuned_axis = tune_axis(n_dim, op.axis, op.name()); + device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); return args.back(); } diff --git a/src/targets/gpu/sync_device.cpp b/src/targets/gpu/sync_device.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1dd4e6affec1862ed20ac8a22ce5aa50e356cc7 --- /dev/null +++ b/src/targets/gpu/sync_device.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +void sync_device::apply(module& m) const +{ + auto last = std::prev(m.end()); + if(last->name() == "@return") + { + auto inputs = last->inputs(); + if(std::any_of(inputs.begin(), inputs.end(), [](auto i) { + return (i->name() == "hip::copy_from_gpu"); + })) + { + auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs); + if(not inputs.empty()) + { + m.replace_instruction(inputs.front(), sync_in); + } + } + } +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3a47316571a3ca304868cc7c509ec36da2900b80..1d3a6b549ad16bdc83c0cd402a6f5af62841cfe3 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -1,90 +1,144 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include +#include #include -#include -#include -#include -#include +#include #include -#include -#include -#include #include +#include +#include #include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION) + +struct id_pass +{ + std::string name() const { return "id"; } + void apple(const module&) const {} +}; + +pass enable_pass(bool enabled, pass p) +{ + if(enabled) + return p; + return id_pass{}; +} std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { auto& ctx = any_cast(gctx); + std::set unsupported_types(shape::types().begin(), shape::types().end()); + unsupported_types.erase(shape::type_t::float_type); + unsupported_types.erase(shape::type_t::half_type); + unsupported_types.erase(shape::type_t::bool_type); + unsupported_types.erase(shape::type_t::int8_type); + unsupported_types.erase(shape::type_t::uint8_type); + unsupported_types.erase(shape::type_t::tuple_type); // clang-format off return { + normalize_ops{}, dead_code_elimination{}, - simplify_reshapes{}, + simplify_qdq{}, + rewrite_quantization{}, dead_code_elimination{}, + eliminate_data_type{unsupported_types, shape::type_t::float_type}, + simplify_reshapes{}, eliminate_identity{}, eliminate_pad{}, dead_code_elimination{}, + insert_pad{}, + dead_code_elimination{}, rewrite_batchnorm{}, dead_code_elimination{}, rewrite_rnn{}, + dead_code_elimination{}, + inline_module{}, rewrite_pooling{}, dead_code_elimination{}, eliminate_common_subexpression{}, dead_code_elimination{}, simplify_algebra{}, + simplify_reshapes{}, + simplify_algebra{}, + prefuse_ops{}, dead_code_elimination{}, auto_contiguous{}, simplify_reshapes{}, - dead_code_elimination{}, propagate_constant{}, dead_code_elimination{}, + enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}), + dead_code_elimination{}, + mlir_conv{&ctx}, lowering{&ctx, options.offload_copy}, - eliminate_contiguous{}, + eliminate_contiguous{"gpu::contiguous"}, dead_code_elimination{}, - eliminate_concat{concat_gpu_optimization{}}, + replace_allocate{gpu_allocation_model{}, options.offload_copy}, dead_code_elimination{}, - adjust_allocation{}, + eliminate_concat{concat_gpu_optimization{}}, dead_code_elimination{}, pack_int8_args{}, dead_code_elimination{}, - fuse_ops{&ctx}, + adjust_allocation{gpu_allocation_model{}}, + dead_code_elimination{}, + fuse_ops{&ctx, options.fast_math}, + dead_code_elimination{}, + compile_ops{&ctx}, dead_code_elimination{}, write_literals{&ctx}, schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})}, memory_coloring{"hip::allocate"}, - preallocate_param{"scratch", &ctx}, + sync_device{}, + preallocate_param{"scratch", gpu_allocation_model{}}, dead_code_elimination{}, eliminate_workspace{}, eliminate_allocation{"hip::allocate"}, check_context{}, + normalize_ops{}, dead_code_elimination{}, eliminate_identity{} }; // clang-format on } -std::string target::name() const { return "miopen"; } +std::string target::name() const { return "gpu"; } migraphx::context target::get_context() const { return context{}; } @@ -94,6 +148,8 @@ argument target::copy_from(const argument& arg) const { return gpu::from_gpu(arg argument target::allocate(const shape& s) const { return gpu::allocate_gpu(s); } +MIGRAPHX_REGISTER_TARGET(target); + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/topk.cpp b/src/targets/gpu/topk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d9eaaac34f3bef0bd6b00a57cee81a9d1c1d1a3 --- /dev/null +++ b/src/targets/gpu/topk.cpp @@ -0,0 +1,33 @@ +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +shape hip_topk::compute_shape(std::vector inputs) const +{ + return op.normalize_compute_shape({inputs.front()}); +} + +argument hip_topk::compute(context& ctx, const shape&, const std::vector& args) const +{ + auto outputs = args.back().get_sub_objects(); + return op.largest ? device::topk_largest(ctx.get_stream().get(), + outputs.front(), + outputs.back(), + args[0], + op.k, + op.axis) + : device::topk_smallest(ctx.get_stream().get(), + outputs.front(), + outputs.back(), + args[0], + op.k, + op.axis); +} + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/write_literals.cpp b/src/targets/gpu/write_literals.cpp index d24989888cfe8f9576ecb16f37eeb1b99ed691ba..39a5475100ed54ce8450b3a40073c644fe47bad0 100644 --- a/src/targets/gpu/write_literals.cpp +++ b/src/targets/gpu/write_literals.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace migraphx { @@ -10,27 +11,25 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS) -void write_literals::apply(program& p) const +void write_literals::apply(module& m) const { assert(ctx != nullptr); std::size_t n = 0; - for(auto ins : iterator_for(p)) + for(auto ins : iterator_for(m)) { if(ins->name() == "@literal") { if(enabled(MIGRAPHX_COPY_LITERALS{})) { literal l = ins->get_literal(); - auto pre = p.add_literal(l); - auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()}); - p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc); + auto pre = m.add_literal(l); + auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()}); + m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc); } else { - std::string id = "@literal:" + std::to_string(n); - argument a = to_gpu(ins->get_literal().get_argument()); - ctx->get_current_device().preallocations[id] = a; - p.replace_instruction(ins, hip_load_memory{a.get_shape(), id}); + std::string id = m.name() + ":@literal:" + std::to_string(n); + m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id}); n++; } } diff --git a/src/targets/ref/CMakeLists.txt b/src/targets/ref/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b59c3a983baa467bc505db77d2724eb9d89066a --- /dev/null +++ b/src/targets/ref/CMakeLists.txt @@ -0,0 +1,23 @@ + +add_library(migraphx_ref + target.cpp + lowering.cpp + gemm.cpp +) +set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref) +rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION}) + +find_path(BLAZE_INCLUDE blaze/Blaze.h) +find_package(Threads) + +rocm_clang_tidy_check(migraphx_ref) +target_link_libraries(migraphx_ref migraphx Threads::Threads) +target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) +target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) + +rocm_install_targets( + TARGETS migraphx_ref + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include +) + diff --git a/src/targets/ref/gemm.cpp b/src/targets/ref/gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0716292425b4fbc7d34f59e39f2a7bbea295fb7b --- /dev/null +++ b/src/targets/ref/gemm.cpp @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace ref { + +template +using matrix = blaze::CustomMatrix; // NOLINT + +template +static auto make_mat(tensor_view x) +{ + const auto& s = x.get_shape(); + // assert(s.lens().size() == 2); + std::size_t n_dims = s.lens().size(); + std::size_t dim_0 = n_dims - 2; + std::size_t dim_1 = n_dims - 1; + if(s.transposed()) + return matrix{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; + return matrix{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; +} + +template +static void visit_mat(tensor_view x, F f) +{ + auto mat = make_mat(x); + if(x.get_shape().transposed()) + f(blaze::trans(mat)); + else + f(mat); +} + +template +struct is_fast_gemm_type : std::false_type +{ +}; + +template <> +struct is_fast_gemm_type : std::true_type +{ +}; + +template +void migemm_impl( + tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::true_type) +{ + visit_mat(amat, [&](const auto& a) { + visit_mat(bmat, [&](const auto& b) { + auto c = make_mat(cmat); + c = beta * c; + // This is a simple optimization to avoid + // compute A * B if alpha is 0.0 + if(alpha != 0.0) + { + c = c + alpha * a * b; + } + }); + }); +} + +template +void migemm_impl( + tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::false_type) +{ + std::size_t n_dims = cmat.get_shape().lens().size(); + std::size_t dim_0 = n_dims - 2; + std::size_t dim_1 = n_dims - 1; + auto k = amat.get_shape().lens()[dim_1]; + + assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); + auto cs = cmat.get_shape(); + + par_for(cs.elements(), [&](auto i) { + auto c_idx = cs.multi(i); + auto a_idx = c_idx; + auto b_idx = c_idx; + double s = 0.0; + dfor(k)([&](auto kk) { + a_idx[dim_1] = b_idx[dim_0] = kk; + s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); + }); + cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; + }); +} + +template +void migemm_impl(tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta) +{ + auto lens = amat.get_shape().lens(); + bool batch_mul = + std::accumulate( + lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies()) == 1; + if(batch_mul) + { + migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type{}); + } + else + { + migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{}); + } +} + +template +void migemm_tpl( + const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta) +{ + visit_all(c_arg, a_arg, b_arg)( + [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); }); +} + +void migemm( + const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) +{ + migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); +} + +void migemm(const argument& c_arg, + const argument& a_arg, + const argument& b_arg, + int32_t alpha, + int32_t beta) +{ + migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); +} + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/ref/include/migraphx/ref/context.hpp b/src/targets/ref/include/migraphx/ref/context.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f279c3b9302084ef88dd41ee56065f5aab8d3b56 --- /dev/null +++ b/src/targets/ref/include/migraphx/ref/context.hpp @@ -0,0 +1,19 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP +#define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace ref { + +struct context +{ + void finish() const {} +}; + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/cpu/include/migraphx/cpu/gemm.hpp b/src/targets/ref/include/migraphx/ref/gemm.hpp similarity index 94% rename from src/targets/cpu/include/migraphx/cpu/gemm.hpp rename to src/targets/ref/include/migraphx/ref/gemm.hpp index 156aa52982a51dec798f36843335d51633c64994..c045f21722aa16bddc27250b2c2e438a6842375d 100644 --- a/src/targets/cpu/include/migraphx/cpu/gemm.hpp +++ b/src/targets/ref/include/migraphx/ref/gemm.hpp @@ -6,7 +6,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -namespace cpu { +namespace ref { void migemm( const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta); @@ -16,7 +16,7 @@ void migemm(const argument& c_arg, int32_t alpha, int32_t beta); -} // namespace cpu +} // namespace ref } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/ref/include/migraphx/ref/lowering.hpp b/src/targets/ref/include/migraphx/ref/lowering.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26c5bc844f81aad6c41a9048fde96e7d524864be --- /dev/null +++ b/src/targets/ref/include/migraphx/ref/lowering.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP +#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace ref { + +struct lowering +{ + std::string name() const { return "ref::lowering"; } + void apply(module& m) const; +}; + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/ref/include/migraphx/ref/target.hpp b/src/targets/ref/include/migraphx/ref/target.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea99b56994e11c6706ea8572def70e4a148d2b30 --- /dev/null +++ b/src/targets/ref/include/migraphx/ref/target.hpp @@ -0,0 +1,32 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_CPU_TARGET_HPP +#define MIGRAPHX_GUARD_MIGRAPHLIB_CPU_TARGET_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +struct pass; +namespace ref { + +struct target +{ + std::string name() const; + std::vector get_passes(migraphx::context& ctx, const compile_options&) const; + migraphx::context get_context() const { return context{}; } + + argument copy_to(const argument& arg) const { return arg; } + argument copy_from(const argument& arg) const { return arg; } + argument allocate(const shape& s) const; +}; + +MIGRAPHX_REGISTER_TARGET(target); + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/ref/lowering.cpp b/src/targets/ref/lowering.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5aa367ceb7bb908dba02155a47b31c762ccf9db1 --- /dev/null +++ b/src/targets/ref/lowering.cpp @@ -0,0 +1,716 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace ref { + +template +T zero(const T&) +{ + return T(0); +} + +template +typename std::conditional_t{}, std::make_signed, std::enable_if>:: + type + make_signed(T x) +{ + return x; +} + +// +// ref implemenataion of batch norm for inference +// +// inputs are: +// args[0] -> input data buffer +// args[1] -> mini batch mean +// args[2] -> mini batch variance +// args[3] -> gamma +// args[4] -> bias +// +// The equation to compute batch norm for inference is: +// +// output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon) +// +// the input data format should be nchw +// +struct ref_batch_norm_inference +{ + op::batch_norm_inference op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::batch_norm_inference"; } + + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument output{output_shape}; + + double epsilon = op.epsilon; + auto input = args[0]; + auto arg_gamma = args[1]; + auto arg_bias = args[2]; + auto mini_batch_mean = args[3]; + auto mini_batch_variance = args[4]; + + if(op.bn_mode == op::batch_norm_inference::spatial) + { + visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( + [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = output_shape.multi(i); + auto c = idx[1]; + assert((variance[c] + epsilon) > 0); + result[i] = + gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) + + bias[c]; + }); + }); + } + + if(op.bn_mode == op::batch_norm_inference::per_activation) + { + visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( + [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = output_shape.multi(i); + idx[0] = 0; + auto index = output_shape.index(idx); + + assert((variance[index] + epsilon) > 0); + result[i] = gamma[index] * (buffer[i] - mean[index]) / + std::sqrt(variance[index] + epsilon) + + bias[index]; + }); + }); + } + + return output; + } +}; +MIGRAPHX_REGISTER_OP(ref_batch_norm_inference) + +struct ref_lrn +{ + op::lrn op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::lrn"; } + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + argument compute(context&, shape output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[0])([&](auto output, auto input) { + int n_batch = output_shape.lens()[0]; + int channels = output_shape.lens()[1]; + int height = output_shape.lens()[2]; + int width = output_shape.lens()[3]; + float alphaoverarea = op.alpha / float(op.size); + int radius_lower = (op.size - 1) / 2; + int radius_upper = op.size / 2 + 1; + + par_dfor(n_batch, height, width)([&](int b, int h, int w) { + float scale = 0; + dfor(channels)([&](int c) { + auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower); + auto end = (c + radius_upper) > channels ? channels : (c + radius_upper); + for(auto k = start; k < end; ++k) + { + scale += std::pow(input(b, k, h, w), 2); + } + scale *= alphaoverarea; + scale += op.bias; + scale = std::pow(scale, -op.beta); + output(b, c, h, w) = input(b, c, h, w) * scale; + }); + }); + }); + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_lrn) + +template +void visit_quantize_impl(V&& v, T&& x, Ts&&... xs) +{ + x.visit([&](auto y) { visit_all(xs...)([&](auto... ys) { v(y, ys...); }); }); +} + +template +auto visit_quantize(T&& x, Ts&&... xs) +{ + return [&](auto v) { + // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100 + visit_quantize_impl(v, x, xs...); + }; +} + +template +struct ref_convolution : auto_register_op> +{ + ref_convolution() = default; + + ref_convolution(Op pop) : op(std::move(pop)) {} + + Op op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::" + op.name(); } + shape compute_shape(const std::vector& inputs) const + { + return op.normalize_compute_shape(inputs); + } + argument compute(context&, shape output_shape, std::vector args) const + { + argument result{output_shape}; + visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) { + auto in_lens = input.get_shape().lens(); + + auto wei_lens = weights.get_shape().lens(); + auto wei_n = wei_lens[0]; + auto wei_c = wei_lens[1]; + std::vector win_size(wei_lens.begin() + 1, wei_lens.end()); + + par_for(output_shape.elements(), [&](auto i) { + auto idx_o = output_shape.multi(i); + auto w = idx_o[1]; + auto n_dim = idx_o.size(); + + std::vector win_start; + for(std::size_t dim = 2; dim < n_dim; ++dim) + { + auto d_2 = dim - 2; + win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) - + std::ptrdiff_t(op.padding[d_2])); + } + const auto group_id = w / (wei_n / op.group); + + shape win_shape{output_shape.type(), win_size}; + + double acc = 0.0; + shape_for_each(win_shape, [&](auto idx_win) { + auto k = idx_win[0]; + const auto in_ch = group_id * wei_c + k; + std::vector idx(idx_o.begin(), idx_o.end()); + idx[1] = in_ch; + std::transform(idx_win.begin() + 1, + idx_win.end(), + win_start.begin(), + idx.begin() + 2, + [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; }); + std::vector idx_wei(idx_o.size()); + idx_wei[0] = w; + std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1); + if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and + std::equal(idx.begin(), + idx.end(), + in_lens.begin(), + in_lens.end(), + std::less{})) + { + acc += + input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end()); + } + }); + + output[i] = acc; + }); + }); + return result; + } +}; + +struct ref_im2col +{ + op::im2col op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + static std::string name() { return "ref::im2col"; } + shape compute_shape(const std::vector& inputs) const + { + return op.normalize_compute_shape(inputs); + } + + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto input_shape = args[0].get_shape(); + auto weights_shape = args[1].get_shape(); + visit_all(result, args[0])([&](auto col, auto input) { + const std::size_t& height = input_shape.lens()[2]; + const std::size_t& width = input_shape.lens()[3]; + const std::size_t& channels = weights_shape.lens()[1]; + const std::size_t& kernel_h = weights_shape.lens()[2]; + const std::size_t& kernel_w = weights_shape.lens()[3]; + const std::size_t& pad_h = op.padding[0]; + const std::size_t& pad_w = op.padding[1]; + const std::size_t& stride_h = op.stride[0]; + const std::size_t& stride_w = op.stride[1]; + + long kdiv2_h = long(kernel_h) / 2; + long kdiv2_w = long(kernel_w) / 2; + // calculate output sizes + const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1; + const std::size_t col_width = (width - kernel_w + 2 * pad_w) / stride_w + 1; + // account for padding for the starting position of the input pixels + long iinput = kdiv2_h - long(pad_h); + // loop over output pixels (ioutput, joutput) + for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h) + { + long jinput = kdiv2_w - long(pad_w); + for(std::size_t joutput = 0; joutput < col_width; joutput++, jinput += stride_w) + { + // compute linear index for output + std::size_t ldx = ioutput * col_width + joutput; + std::size_t p = 0; + dfor(channels, + kernel_h, + kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) { + auto idx = iinput + long(koffset) - kdiv2_h; + auto jdx = jinput + long(loffset) - kdiv2_w; + col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width)) + ? input(0, c, idx, jdx) + : 0; + p++; + }); + } + } + }); + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_im2col) + +struct ref_op +{ + operation op = op::identity{}; + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + std::string name() const { return "ref::op"; } + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + argument compute(context&, const shape& output_shape, const std::vector& args) const + { + return op.compute(output_shape, args); + } + value to_value() const + { + value v; + v["name"] = op.name(); + v["operator"] = op.to_value(); + return v; + } + void from_value(const value& v) + { + op = make_op(v.at("name").to(), v.at("operator")); + } + friend std::ostream& operator<<(std::ostream& os, const ref_op& x) + { + os << "ref::" << x.op; + return os; + } +}; +MIGRAPHX_REGISTER_OP(ref_op) + +struct ref_pad +{ + op::pad op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::pad"; } + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + argument compute(context&, const shape& output_shape, std::vector args) const + { + assert(output_shape.standard()); + argument result{output_shape}; + result.visit([&](auto output) { + using type = typename decltype(output)::value_type; + std::fill(output.begin(), output.end(), pad_clamp(op.value)); + }); + + visit_all(result, args[0])([&](auto output, auto input) { + shape_for_each(input.get_shape(), [&](const auto& idx) { + std::vector new_idx(idx.size()); + std::transform( + idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) { + return i + j; + }); + output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end()); + }); + }); + + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_pad) + +struct ref_gemm +{ + op::dot op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + std::string name() const { return "ref::dot"; } + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + migemm(result, args[0], args[1], 1.0f, 0.0f); + + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_gemm) + +struct ref_quant_gemm +{ + op::quant_dot op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::quant_dot"; } + shape compute_shape(const std::vector& inputs) const { return op.compute_shape(inputs); } + + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + // first, convert the args[0] and args[1] from int8_t to int32_t + argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; + argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; + arg_0.visit([&](auto output) { + args.at(0).visit( + [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); + }); + + arg_1.visit([&](auto output) { + args.at(1).visit( + [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); + }); + + migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); + + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_gemm) + +struct leaky_relu_op +{ + op::leaky_relu op; + std::string name() const { return "ref::leaky_relu"; } + auto fcn() const + { + auto a = op.alpha; + return [a](auto x) { return x > 0 ? x : x * a; }; + } +}; + +struct elu_op +{ + op::elu op; + std::string name() const { return "ref::elu"; } + auto fcn() const + { + auto a = op.alpha; + return [a](auto x) { return x > 0 ? x : a * std::expm1(x); }; + } +}; + +template +struct ref_unary : auto_register_op> +{ + ref_unary() = default; + + template + ref_unary(T pop) : op(Op{std::move(pop)}) + { + } + + Op op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op.op, f); + } + std::string name() const { return op.name(); } + shape compute_shape(const std::vector& inputs) const + { + check_shapes{inputs, *this}.has(1); + const auto& s = inputs.at(0); + return {s.type(), s.lens()}; + } + + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + visit_all(result, args[0])([&](auto output, auto input) { + assert(input.get_shape().standard()); + std::transform(input.begin(), input.end(), output.begin(), op.fcn()); + }); + + return result; + } +}; + +template +struct ref_softmax : auto_register_op> +{ + ref_softmax() = default; + + ref_softmax(Op pop) : op(std::move(pop)) {} + + Op op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::" + op.name(); } + shape compute_shape(const std::vector& inputs) const + { + return op.normalize_compute_shape(inputs); + } + argument compute(context&, const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto batch_lens = output_shape.lens(); + int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name()); + std::size_t n_dims = batch_lens[tuned_axis]; + batch_lens[tuned_axis] = 1; + shape batch_shape{shape::int32_type, batch_lens}; + + visit_all(result, args[0])([&](auto output, auto input) { + using value_type = accumulator_type; + std::vector batch_max(batch_shape.elements(), + std::numeric_limits::lowest()); + std::vector batch_sum(batch_shape.elements(), value_type(0)); + par_for(batch_shape.elements(), [&](auto i) { + auto idx = batch_shape.multi(i); + for(std::size_t j = 0; j < n_dims; ++j) + { + idx[tuned_axis] = j; + batch_max[i] = + std::max(batch_max[i], input(idx.begin(), idx.end())); + } + + for(std::size_t j = 0; j < n_dims; ++j) + { + idx[tuned_axis] = j; + std::size_t index = output_shape.index(idx); + output[index] = std::exp(input[index] - batch_max[i]); + } + + for(std::size_t j = 0; j < n_dims; ++j) + { + idx[tuned_axis] = j; + batch_sum[i] += output(idx.begin(), idx.end()); + } + + for(std::size_t j = 0; j < n_dims; ++j) + { + idx[tuned_axis] = j; + output(idx.begin(), idx.end()) = + op.output()(output(idx.begin(), idx.end()), batch_sum[i]); + } + }); + }); + + return result; + } +}; + +struct ref_rnn_var_sl_last_output +{ + op::rnn_var_sl_last_output op; + + template + static auto reflect(Self& self, F f) + { + return migraphx::reflect(self.op, f); + } + + std::string name() const { return "ref::rnn_var_sl_last_output"; } + + shape compute_shape(std::vector inputs) const + { + return op.compute_shape(std::move(inputs)); + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto out_comp_lens = args[0].get_shape().lens(); + out_comp_lens[0] = 1; + shape out_comp_s{output_shape.type(), out_comp_lens}; + + visit_all(result, args[0])([&](auto output, auto input) { + args[1].visit([&](auto seq_lens) { + par_for(output_shape.elements(), [&](auto i) { + auto idx = out_comp_s.multi(i); + auto b = idx[2]; + if(op.direction == op::rnn_direction::reverse or idx[1] == 1) + { + idx[0] = 0; + } + else + { + idx[0] = seq_lens[b] - 1; + } + output[i] = input(idx.begin(), idx.end()); + }); + }); + }); + + return result; + } +}; +MIGRAPHX_REGISTER_OP(ref_rnn_var_sl_last_output) + +struct ref_apply +{ + module* mod; + std::unordered_map> apply_map{}; + + template + auto simple_op() + { + return [this](instruction_ref ins) { apply_simple_op(ins); }; + } + + template + auto extend_op() + { + return [this](instruction_ref ins) { apply_extend_op(ins); }; + } + + void init() + { + apply_map["batch_norm_inference"] = + extend_op(); + apply_map["convolution"] = extend_op, op::convolution>(); + apply_map["dot"] = extend_op(); + apply_map["quant_dot"] = extend_op(); + apply_map["quant_convolution"] = + extend_op, op::quant_convolution>(); + apply_map["elu"] = extend_op, op::elu>(); + apply_map["im2col"] = extend_op(); + apply_map["leaky_relu"] = extend_op, op::leaky_relu>(); + apply_map["logsoftmax"] = extend_op, op::logsoftmax>(); + apply_map["lrn"] = extend_op(); + apply_map["pad"] = extend_op(); + apply_map["softmax"] = extend_op, op::softmax>(); + apply_map["rnn_var_sl_last_output"] = + extend_op(); + } + + void apply() + { + init(); + for(auto it : iterator_for(*mod)) + { + if(apply_map.count(it->name()) > 0) + { + apply_map.at(it->name())(it); + } + else if(is_context_free(it->get_operator())) + { + apply_ref_op(it); + } + } + } + + void apply_ref_op(instruction_ref ins) const + { + mod->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs()); + } + + template + void apply_simple_op(instruction_ref ins) + { + mod->replace_instruction(ins, T{}, ins->inputs()); + } + + template + void apply_extend_op(instruction_ref ins) + { + auto&& op = any_cast(ins->get_operator()); + mod->replace_instruction(ins, T{op}, ins->inputs()); + } +}; + +void lowering::apply(module& m) const { ref_apply{&m}.apply(); } + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/ref/target.cpp b/src/targets/ref/target.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ecdfb429a96e2542890b89557000e1e40497060 --- /dev/null +++ b/src/targets/ref/target.cpp @@ -0,0 +1,41 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace ref { + +std::string target::name() const { return "ref"; } + +std::vector target::get_passes(migraphx::context&, const compile_options&) const +{ + return {normalize_ops{}, + eliminate_pad{}, + dead_code_elimination{}, + insert_pad{}, + dead_code_elimination{}, + rewrite_rnn{}, + dead_code_elimination{}, + auto_contiguous{}, + dead_code_elimination{}, + lowering{}, + dead_code_elimination{}}; +} + +argument target::allocate(const shape& s) const { return fill_argument(s, 0); } + +MIGRAPHX_REGISTER_TARGET(target); + +} // namespace ref +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/CMakeLists.txt b/src/tf/CMakeLists.txt index 7be5a9fa9189497a025f76b44166b1f0f891c06d..6d8f577dc78e67ef1bf9e1abf42eacbd29232741 100644 --- a/src/tf/CMakeLists.txt +++ b/src/tf/CMakeLists.txt @@ -19,11 +19,13 @@ target_compile_options(tf-proto PRIVATE -w) target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY}) set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On) -add_library(migraphx_tf tf.cpp) +file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp) +add_library(migraphx_tf ${TF_SRCS}) +target_include_directories(migraphx_tf PRIVATE include) set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf) rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION}) rocm_clang_tidy_check(migraphx_tf) -target_link_libraries(migraphx_tf PRIVATE tf-proto) +target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_tf PUBLIC migraphx) rocm_install_targets( diff --git a/src/tf/include/migraphx/tf/op_parser.hpp b/src/tf/include/migraphx/tf/op_parser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9202e9016be7c513ed661188939923a2dded40ca --- /dev/null +++ b/src/tf/include/migraphx/tf/op_parser.hpp @@ -0,0 +1,79 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_TF_REGISTER_OP_PARSER_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_TF_REGISTER_OP_PARSER_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct op_desc +{ + std::string tf_name = ""; + std::string op_name = ""; +}; + +void register_op_parser(const std::string& name, tf_parser::op_func f); +tf_parser::op_func get_op_parser(const std::string& name); +std::vector get_op_parsers(); + +inline std::vector implicit_multi_op(std::vector inss) +{ + return inss; +} + +inline std::vector implicit_multi_op(instruction_ref ins) { return {ins}; } + +template +void register_op_parser() +{ + T parser; + for(auto&& opd : parser.operators()) + register_op_parser(opd.tf_name, + [opd, parser](auto&&... xs) { return parser.base_parse(opd, xs...); }); +} + +struct register_op_parser_action +{ + template + static void apply() + { + register_op_parser(); + } +}; + +template +struct op_parser : auto_register +{ + bool transpose() const { return false; } + std::vector base_parse(const op_desc& opd, + const tf_parser& parser, + tf_parser::node_info info, + const std::vector& args) const + { + std::vector result; + auto& self = static_cast(*this); + if(self.transpose()) + { + result = implicit_multi_op(self.parse(opd, parser, info, parser.to_nchw(args))); + std::transform(result.begin(), result.end(), result.begin(), [&](auto ins) { + return parser.to_nhwc(ins); + }); + } + else + { + result = implicit_multi_op(self.parse(opd, parser, info, args)); + } + return result; + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/tf/include/migraphx/tf/tf_parser.hpp b/src/tf/include/migraphx/tf/tf_parser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d1324f2f7547e755803bc040e0021e35f3621467 --- /dev/null +++ b/src/tf/include/migraphx/tf/tf_parser.hpp @@ -0,0 +1,118 @@ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_TF_PARSER_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_TF_PARSER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +// namespace tf = tf_for_migraphx; + +struct tf_parser +{ + std::string filename; + std::string path = "."; + using attribute_map = std::unordered_map; + struct node_info + { + attribute_map attributes{}; + std::string name = ""; + module* mm = nullptr; + + instruction_ref make_contiguous(instruction_ref ins) const; + + instruction_ref add_broadcastable_binary_op(const std::string& op_name, + instruction_ref arg0, + instruction_ref arg1) const; + + instruction_ref add_common_op(const std::string& op_name, + std::vector inputs) const; + + template + instruction_ref add_common_op(const std::string& op_name, Ts... xs) const + { + return add_common_op(op_name, {xs...}); + } + + instruction_ref add_instruction(const operation& op, + const std::vector& args) const; + + template + instruction_ref add_instruction(const operation& op, Ts... xs) const + { + return add_instruction(op, {xs...}); + } + instruction_ref add_literal(literal l) const; + template + instruction_ref add_literal(Ts&&... xs) const + { + return add_literal(literal{std::forward(xs)...}); + } + }; + + using node_map = std::map; + using op_func = std::function( + const tf_parser&, const node_info&, std::vector)>; + node_map nodes; + std::vector input_nodes; + std::vector output_node_names; + std::unordered_map instructions; + program prog = program(); + module* mm = prog.get_main_module(); + bool is_nhwc = true; + unsigned int batch_size = 1; + std::size_t default_dim_value = 1; + std::unordered_map> map_input_dims; + + std::unordered_map ops; + + tf_parser(); + operation load(const std::string& name, const node_info& info) const; + bool should_transpose(instruction_ref ins) const; + instruction_ref to_nhwc(instruction_ref ins) const; + instruction_ref to_nchw(instruction_ref ins) const; + instruction_ref to_kcxy(instruction_ref ins) const; + std::vector to_nchw(const std::vector& args) const; + std::vector to_nhwc(const std::vector& args) const; + int64_t parse_axis(int64_t dim, size_t num_dims) const; + // tf stores certain attributes such as strides, dilations, as a 4D input. + // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3. + // This helper function reorders the data to store for the respective operator member variables. + template + void reorder_data(std::vector& prev_data) const + { + std::vector new_data(prev_data.size()); + for(size_t i = 0; i < new_data.size(); i++) + { + auto new_idx = parse_axis(i, new_data.size()); + new_data.at(new_idx) = prev_data.at(i); + } + prev_data = new_data; + } + + void parse_undefined(module* mm, const std::string& name); + void parse_from(std::istream& is); + void parse_from(const void* data, std::size_t size); + void parse_graph(const tensorflow::GraphDef& graph); + void parse_node(const std::string& name); + literal parse_tensor(const tensorflow::TensorProto& t) const; + shape::type_t parse_type(tensorflow::DataType t) const; + std::vector find_outputs() const; +}; + +std::vector get_axes_from_mask(size_t num_axes, uint32_t mask); + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/tf/op_parser.cpp b/src/tf/op_parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95ce6db893fb04d7b0fa16c9ce05785361ceba36 --- /dev/null +++ b/src/tf/op_parser.cpp @@ -0,0 +1,31 @@ +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +std::unordered_map& op_parser_map() +{ + static std::unordered_map m; // NOLINT + return m; +} + +void register_op_parser(const std::string& name, tf_parser::op_func f) +{ + op_parser_map()[name] = std::move(f); +} +tf_parser::op_func get_op_parser(const std::string& name) { return op_parser_map().at(name); } +std::vector get_op_parsers() +{ + std::vector result; + std::transform(op_parser_map().begin(), + op_parser_map().end(), + std::back_inserter(result), + [&](auto&& p) { return p.first; }); + return result; +} + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_arg_op.cpp b/src/tf/parse_arg_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b8f3f10eb9f177a1ee2e37cf32df3b87223f43c --- /dev/null +++ b/src/tf/parse_arg_op.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_arg_op : op_parser +{ + std::vector operators() const { return {{"ArgMax", "argmax"}, {"ArgMin", "argmin"}}; } + + instruction_ref parse(const op_desc& opd, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + const std::vector& args) const + { + int64_t axis = 0; + axis = args[1]->eval().at(); + auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args.front()); + return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_batchnorm.cpp b/src/tf/parse_batchnorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a4b3fd236ea2d537ec4d8fce8fa78b53c694a54 --- /dev/null +++ b/src/tf/parse_batchnorm.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_batchnorm : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"FusedBatchNorm"}, {"FusedBatchNormV3"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + const std::vector& args) const + { + float epsilon = 1e-5f; + float momentum = 0.9f; + if(contains(info.attributes, "epsilon")) + { + epsilon = info.attributes.at("epsilon").f(); + } + auto op = make_op("batch_norm_inference", {{"epsilon", epsilon}, {"momentum", momentum}}); + return info.add_instruction(op, args); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_biasadd.cpp b/src/tf/parse_biasadd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c0af38090784ece156fc6f64c455f5a63147795 --- /dev/null +++ b/src/tf/parse_biasadd.cpp @@ -0,0 +1,31 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_biasadd : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"BiasAdd"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) + + auto l0 = info.add_instruction( + make_op("broadcast", {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}), + args[1]); + return info.add_instruction(make_op("add"), args[0], l0); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_binary_op.cpp b/src/tf/parse_binary_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a750950c5c5b2da88be06229d04b38a68fa8e88a --- /dev/null +++ b/src/tf/parse_binary_op.cpp @@ -0,0 +1,36 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_binary_op : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const + { + return {{"Add", "add"}, + {"AddV2", "add"}, + {"Mul", "mul"}, + {"Pow", "pow"}, + {"SquaredDifference", "sqdiff"}, + {"Sub", "sub"}}; + } + + instruction_ref parse(const op_desc& opd, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + if(args.size() != 2) + MIGRAPHX_THROW("binary operators should have 2 operands"); + return info.add_broadcastable_binary_op(opd.op_name, args[0], args[1]); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_cast.cpp b/src/tf/parse_cast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9cf9ed3269ad73f76e78ddd809e90a9f3d08bdb6 --- /dev/null +++ b/src/tf/parse_cast.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_cast : op_parser +{ + std::vector operators() const { return {{"Cast"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + tf_parser::node_info info, + const std::vector& args) const + { + shape::type_t type = parser.parse_type(info.attributes.at("DstT").type()); + return info.add_instruction(make_op("convert", {{"target_type", type}}), args); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_concat.cpp b/src/tf/parse_concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7caa92e5e941560f5cd3e762c3ba3bf49438987c --- /dev/null +++ b/src/tf/parse_concat.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_concat : op_parser +{ + std::vector operators() const { return {{"ConcatV2"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + // get index for axis within args + size_t axis_idx = info.attributes.at("N").i(); + int64_t axis = args[axis_idx]->eval().at(); + auto op = make_op("concat", {{"axis", axis}}); + // return only first N arguments (assuming last index is the axis value) + return info.add_instruction( + op, std::vector(args.begin(), args.begin() + args.size() - 1)); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_constant.cpp b/src/tf/parse_constant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35f31d1fb38680669f723cbc8a9a0e76421b3bda --- /dev/null +++ b/src/tf/parse_constant.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_constant_op : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"Const"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + tf_parser::node_info info, + const std::vector& /*args*/) const + { + literal v = parser.parse_tensor(info.attributes.at("value").tensor()); + return info.add_literal(v); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_conv.cpp b/src/tf/parse_conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..890c139c17dd6a705ff2e487cdbcb42154e23054 --- /dev/null +++ b/src/tf/parse_conv.cpp @@ -0,0 +1,94 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_conv : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"Conv2D"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + tf_parser::node_info info, + std::vector args) const + { + op::convolution op; + if(contains(info.attributes, "strides")) + { + std::vector stride; + copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); + parser.reorder_data(stride); + if(stride.size() != 4) + { + MIGRAPHX_THROW("strides should have 4 values"); + } + op.stride[0] = stride[2]; + op.stride[1] = stride[3]; + } + if(contains(info.attributes, "dilations")) + { + std::vector dilation; + copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation)); + parser.reorder_data(dilation); + if(dilation.size() != 4) + { + MIGRAPHX_THROW("dilation should have 4 values"); + } + op.dilation[0] = dilation[2]; + op.dilation[1] = dilation[3]; + } + + auto weights = parser.to_kcxy(args[1]); + auto l0 = args[0]; + if(contains(info.attributes, "padding")) + { + const std::string& pad_mode = info.attributes.at("padding").s(); + if(pad_mode.find("SAME") != std::string::npos) + { + op.padding_mode = op::padding_mode_t::same; + std::vector weight_dims = weights->get_shape().lens(); + size_t weight_h = weight_dims[2]; + size_t weight_w = weight_dims[3]; + + auto input_dims = l0->get_shape().lens(); + std::vector pads(input_dims.size()); + calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); + calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); + + op.padding = std::vector(pads.begin(), pads.end()); + } + else if(pad_mode.find("VALID") != std::string::npos) + { + op.padding_mode = op::padding_mode_t::valid; + } + else if(pad_mode.find("EXPLICIT") != std::string::npos) + { + std::vector padding; + copy(info.attributes.at("explicit_paddings").list().i(), + std::back_inserter(padding)); + if(padding.size() != 4) + { + MIGRAPHX_THROW("padding should have 4 values"); + } + if(padding[0] != padding[2] || padding[1] != padding[3]) + { + MIGRAPHX_THROW("migraphx does not support asymetric padding"); + } + op.padding[0] = padding[0]; + op.padding[1] = padding[1]; + } + } + return info.add_instruction(op, {l0, weights}); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_depthwiseconv.cpp b/src/tf/parse_depthwiseconv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24cb1bfd4696d297804734707ff8221d89f74591 --- /dev/null +++ b/src/tf/parse_depthwiseconv.cpp @@ -0,0 +1,107 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_depthwiseconv : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"DepthwiseConv2dNative"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + tf_parser::node_info info, + std::vector args) const + { + op::convolution op; + size_t num_channels = args[0]->get_shape().lens()[1]; + op.group = num_channels; + + if(contains(info.attributes, "strides")) + { + std::vector stride; + copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); + parser.reorder_data(stride); + if(stride.size() != 4) + { + MIGRAPHX_THROW("strides should have 4 values"); + } + op.stride[0] = stride[2]; + op.stride[1] = stride[3]; + } + + auto weights = parser.to_kcxy(args[1]); + if(contains(info.attributes, "dilations")) + { + std::vector dilation; + copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation)); + parser.reorder_data(dilation); + if(dilation.size() != 4) + { + MIGRAPHX_THROW("dilation should have 4 values"); + } + op.dilation[0] = dilation[2]; + op.dilation[1] = dilation[3]; + } + + auto l0 = args[0]; + if(contains(info.attributes, "padding")) + { + const std::string& pad_mode = info.attributes.at("padding").s(); + + if(pad_mode.find("SAME") != std::string::npos) + { + op.padding_mode = op::padding_mode_t::same; + std::vector weight_dims = weights->get_shape().lens(); + size_t weight_h = weight_dims[2]; + size_t weight_w = weight_dims[3]; + + auto input_dims = l0->get_shape().lens(); + std::vector pads(input_dims.size()); + calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); + calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); + + if(pads[0] != pads[2] || pads[1] != pads[3]) + { + std::vector padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; + l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0); + } + else + { + op.padding[0] = pads[0]; + op.padding[1] = pads[1]; + } + } + else if(pad_mode.find("VALID") != std::string::npos) + { + op.padding_mode = op::padding_mode_t::valid; + } + } + + std::vector new_weights_shape; + copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape)); + + // weight format is (out_channels, in_channels, h, w), but in depthwise_conv, + // out_channels is equal to the multiplier. Adjust by inserting a reshape and + // setting in_channels to 1 + int64_t multiplier = new_weights_shape[0]; + int64_t out_channels = num_channels * multiplier; + new_weights_shape[0] = out_channels; + new_weights_shape[1] = 1; + // Make sure weights are contiguous before doing reshape + auto new_weights = info.add_instruction(make_op("reshape", {{"dims", new_weights_shape}}), + info.make_contiguous(weights)); + + return info.add_instruction(op, {l0, new_weights}); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_expanddims.cpp b/src/tf/parse_expanddims.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f22459570cd783679c284d8f9fe9a92fbe78b53 --- /dev/null +++ b/src/tf/parse_expanddims.cpp @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_expanddims : op_parser +{ + std::vector operators() const { return {{"ExpandDims"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + std::vector input_dims = args[0]->get_shape().lens(); + std::vector new_dims(input_dims.begin(), input_dims.end()); + size_t num_dims = input_dims.size(); + int32_t dim = args[1]->eval().at(); + + if(dim < 0) + { + new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); + } + else + { + new_dims.insert(new_dims.begin() + dim, 1); + } + return info.add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_gather.cpp b/src/tf/parse_gather.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2766a9b272dcb552f673ed38d057b095e8a60b8d --- /dev/null +++ b/src/tf/parse_gather.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_gather : op_parser +{ + std::vector operators() const { return {{"GatherV2"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + int axis = args[2]->eval().at(); + return info.add_instruction(make_op("gather", {{"axis", axis}}), {args[0], args[1]}); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_generic_op.cpp b/src/tf/parse_generic_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12c3278713ea7f9d9ac645658fb28f193f36e509 --- /dev/null +++ b/src/tf/parse_generic_op.cpp @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_generic_op : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const + { + return {{"All", "identity"}, + {"Identity", "identity"}, + {"LessEqual", "identity"}, + {"Relu", "relu"}, + {"Rsqrt", "rsqrt"}, + {"Tanh", "tanh"}, + {"StopGradient", "identity"}}; + } + + instruction_ref parse(const op_desc& opd, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + const std::vector& args) const + { + return info.add_instruction(make_op(opd.op_name), args); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_matmul.cpp b/src/tf/parse_matmul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..959f4dc20a53ded7842964c27e94c6a986f09273 --- /dev/null +++ b/src/tf/parse_matmul.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_matmul : op_parser +{ + std::vector operators() const + { + return {{"BatchMatMul"}, {"BatchMatMulV2"}, {"MatMul"}}; + } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + bool transa = false; + bool transb = false; + + if(contains(info.attributes, "transpose_a")) + { + transa = info.attributes.at("transpose_a").b(); + } + if(contains(info.attributes, "transpose_b")) + { + transb = info.attributes.at("transpose_b").b(); + } + + if(contains(info.attributes, "adj_x")) + { + transa = info.attributes.at("adj_x").b(); + } + if(contains(info.attributes, "adj_y")) + { + transb = info.attributes.at("adj_y").b(); + } + + std::vector perm(args[0]->get_shape().lens().size()); + std::iota(perm.begin(), perm.end(), int64_t{0}); + // swap the last two elements + std::iter_swap(perm.end() - 1, perm.end() - 2); + + auto l1 = (transa) + ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]) + : args[0]; + auto l2 = (transb) + ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) + : args[1]; + + return info.add_instruction(make_op("dot"), l1, l2); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_mean.cpp b/src/tf/parse_mean.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88842cf9647a3bdc6a9ce57e6000e279d5b70e4e --- /dev/null +++ b/src/tf/parse_mean.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_mean : op_parser +{ + std::vector operators() const { return {{"Mean"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + bool keep_dims = info.attributes.at("keep_dims").b(); + auto axes = args[1]->eval().get().to_vector(); + + auto ins = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]); + if(not keep_dims) + ins = info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins); + return ins; + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_onehot.cpp b/src/tf/parse_onehot.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5805d4b2f48aaf0ed41a03bf4bf131f81f88bd30 --- /dev/null +++ b/src/tf/parse_onehot.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_onehot : op_parser +{ + std::vector operators() const { return {{"OneHot"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + size_t depth = static_cast(args[1]->eval().at()); + + int64_t axis = -1; + float on_value = args[2]->eval().at(); + float off_value = args[3]->eval().at(); + + std::vector depth_input(depth * depth, off_value); + for(int i = 0; i < depth; i++) + { + depth_input[depth * i + i] = on_value; + } + + if(contains(info.attributes, "axis")) + axis = info.attributes.at("axis").i(); + if(axis == -1) + { + shape s{shape::float_type, {depth, depth}}; + auto l0 = info.add_literal({s, depth_input}); + return info.add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]}); + } + MIGRAPHX_THROW("MIGraphX does not support axis != -1"); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_pack.cpp b/src/tf/parse_pack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0722ddd0b688595a55034dbdf7a073fe68d47f4 --- /dev/null +++ b/src/tf/parse_pack.cpp @@ -0,0 +1,47 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_pack : op_parser +{ + std::vector operators() const { return {{"Pack"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + tf_parser::node_info info, + std::vector args) const + { + // reinterpret as unsqueeze with concat + std::vector unsqueezed_args; + int64_t axis = 0; + if(contains(info.attributes, "axis")) + axis = info.attributes.at("axis").i(); + size_t input_size = args.front()->get_shape().lens().size(); + if(axis > input_size) + { + MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) + + " must be smaller than input size " + to_string(input_size)); + } + + std::transform( + args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](instruction_ref arg) { + return info.add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg); + }); + return parser.to_nhwc( + info.add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args)); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_pad.cpp b/src/tf/parse_pad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3e88f7f00b73163818680d614179f096f3a6bf0 --- /dev/null +++ b/src/tf/parse_pad.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_pad : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"Pad"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& parser, + const tf_parser::node_info& info, + std::vector args) const + { + size_t ndims = args.front()->get_shape().lens().size(); + + // in tf, the paddings are arranged as a 2d shape (ndims, 2), + // the last dim contains the left padding and right padding respectively + std::vector> pad_per_dim(ndims); + auto tf_padding = args[1]->eval().get().to_vector(); + for(size_t i = 0; i < 2 * ndims; i += 2) + { + pad_per_dim[i / 2].first = tf_padding[i]; + pad_per_dim[i / 2].second = tf_padding[i + 1]; + } + parser.reorder_data(pad_per_dim); + + std::vector pads(ndims * 2); + for(size_t i = 0; i < ndims; i++) + { + pads[i] = pad_per_dim[i].first; + pads[i + ndims] = pad_per_dim[i].second; + } + return info.add_instruction(make_op("pad", {{"pads", pads}}), args.front()); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_pooling.cpp b/src/tf/parse_pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd454ee9663584fdf1bf28530c8786e6233ee07e --- /dev/null +++ b/src/tf/parse_pooling.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_pooling : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"AvgPool"}, {"MaxPool"}}; } + + instruction_ref parse(const op_desc& opd, + const tf_parser& parser, + tf_parser::node_info info, + std::vector args) const + { + if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av")) + { + MIGRAPHX_THROW("tf pooling mode must be Max or Average"); + } + op::pooling op{starts_with(opd.tf_name, "Max") ? op::pooling_mode::max + : op::pooling_mode::average}; + + if(contains(info.attributes, "strides")) + { + std::vector stride; + copy(info.attributes.at("strides").list().i(), std::back_inserter(stride)); + parser.reorder_data(stride); + if(stride.size() != 4) + { + MIGRAPHX_THROW("strides should have 4 values"); + } + op.stride[0] = stride[2]; + op.stride[1] = stride[3]; + } + if(contains(info.attributes, "ksize")) + { + std::vector ksize; + copy(info.attributes.at("ksize").list().i(), std::back_inserter(ksize)); + parser.reorder_data(ksize); + if(ksize.size() != 4) + { + MIGRAPHX_THROW("ksize should have 4 values"); + } + op.lengths[0] = ksize[2]; + op.lengths[1] = ksize[3]; + } + + auto l0 = args[0]; + if(contains(info.attributes, "padding")) + { + const std::string& pad_mode = info.attributes.at("padding").s(); + if(pad_mode.find("SAME") != std::string::npos) + { + auto input_dims = l0->get_shape().lens(); + std::vector pads(input_dims.size()); + calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]); + calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]); + + op.padding = std::vector(pads.begin(), pads.end()); + } + } + return info.add_instruction(op, l0); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_relu6.cpp b/src/tf/parse_relu6.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54038603757b4adcc9b00465e130cc860708810e --- /dev/null +++ b/src/tf/parse_relu6.cpp @@ -0,0 +1,30 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_relu6 : op_parser +{ + bool transpose() const { return true; } + std::vector operators() const { return {{"Relu6"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + auto min_val = info.add_literal(0.0f); + auto max_val = info.add_literal(6.0f); + + return info.add_common_op("clip", args[0], min_val, max_val); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_reshape.cpp b/src/tf/parse_reshape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29fa89ea4005dd89866ba8bfff376b01ae97f888 --- /dev/null +++ b/src/tf/parse_reshape.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_reshape : op_parser +{ + std::vector operators() const { return {{"Reshape"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + if(args.size() != 2) + MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)"); + auto s = args[1]->eval(); + std::vector dims; + s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); + return info.add_instruction(make_op("reshape", {{"dims", dims}}), + info.make_contiguous(args[0])); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_shape.cpp b/src/tf/parse_shape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8f4deb24702b5b1098d353fd6b07c6bfabf59d1 --- /dev/null +++ b/src/tf/parse_shape.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_shape : op_parser +{ + std::vector operators() const { return {{"Shape"}}; } + + // Use a literal instruction to replace the shape since output of + // shape operator are literals in migraphx + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + std::vector arg_shape = args[0]->get_shape().lens(); + std::vector vec_shape(arg_shape.size()); + migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()}); + std::transform( + arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; }); + return info.add_literal(migraphx::literal{s, vec_shape}); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_slice.cpp b/src/tf/parse_slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad2ca0c2d062c97b4a72b5bc5c0326cf6699184a --- /dev/null +++ b/src/tf/parse_slice.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_slice : op_parser +{ + std::vector operators() const { return {{"Slice"}}; } + + // Use a literal instruction to replace the shape since output of + // shape operator are literals in migraphx + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + auto starts = args[1]->eval().get().to_vector(); + auto size = args[2]->eval().get().to_vector(); + auto axes = args[0]->get_shape().lens(); + size_t num_axes = axes.size(); + + std::vector axes_int64(axes.begin(), axes.end()); + std::vector starts_int64(starts.begin(), starts.end()); + std::vector ends(num_axes); + std::vector op_axes(num_axes); + std::iota(op_axes.begin(), op_axes.end(), 0); + for(size_t i = 0; i < num_axes; i++) + { + if(size[i] == -1) + ends[i] = axes_int64[i]; + else + ends[i] = starts_int64[i] + size[i]; + } + auto op = make_op("slice", {{"starts", starts_int64}, {"ends", ends}, {"axes", op_axes}}); + return info.add_instruction(op, info.make_contiguous(args[0])); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_softmax.cpp b/src/tf/parse_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf9aae605b94f6de00e3cc529ee23d81a87219fa --- /dev/null +++ b/src/tf/parse_softmax.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_softmax : op_parser +{ + std::vector operators() const { return {{"Softmax"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + int axis = -1; + auto num_dims = args[0]->get_shape().lens().size(); + if(contains(info.attributes, "axis")) + { + axis = static_cast(info.attributes.at("axis").i()); + } + + axis = tune_axis(num_dims, axis, "tf_parse_softmax"); + + return info.add_instruction(make_op("softmax", {{"axis", axis}}), + info.make_contiguous(args[0])); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_split.cpp b/src/tf/parse_split.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55c9e0c5796f8d7a4e735abdb5c052f2894dc190 --- /dev/null +++ b/src/tf/parse_split.cpp @@ -0,0 +1,98 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_split : op_parser +{ + std::vector operators() const { return {{"Split"}, {"SplitV"}}; } + + std::vector parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + bool vector_as_input = args.size() == 3; + int num_outputs = 1; + auto axis_arg = args[0]; + auto input_arg = args[1]; + if(vector_as_input) + { + input_arg = args[0]; + axis_arg = args[2]; + } + + if(contains(info.attributes, "num_split")) + num_outputs = info.attributes.at("num_split").i(); + + std::vector splits(num_outputs); + std::vector slice_pos{0}; + if(vector_as_input) + { + splits = args[1]->eval().get().to_vector(); + num_outputs = splits.size(); + } + + assert(num_outputs > 0); + + if(num_outputs == 1) + return std::vector{ + info.add_instruction(make_op("identity"), input_arg)}; + + auto lens = input_arg->get_shape().lens(); + auto num_dims = lens.size(); + int axis = axis_arg->eval().at(); + + // ensure split is made evenly if "num_split" is used + assert(vector_as_input or lens[axis] % num_outputs == 0); + + auto split_size = lens[axis] / num_outputs; + + // push back first end point of slice + if(vector_as_input) + { + slice_pos.push_back(splits[0]); + } + else + { + slice_pos.push_back(split_size); + } + + // calculate remaining end points for each slice + for(auto i = 1; i < num_outputs; i++) + { + if(vector_as_input) + { + splits[i] += splits[i - 1]; + slice_pos.push_back(splits[i]); + } + else + { + slice_pos.push_back((i + 1) * split_size); + } + } + std::vector result; + for(auto i = 0; i < num_outputs; i++) + { + std::vector axes(num_dims); + std::iota(axes.begin(), axes.end(), 0); + std::vector starts(num_dims, 0); + std::vector ends(lens.begin(), lens.end()); + + starts[axis] = slice_pos[i]; + ends[axis] = slice_pos[i + 1]; + auto op = make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}); + result.push_back(info.add_instruction(op, input_arg)); + } + return result; + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_squeeze.cpp b/src/tf/parse_squeeze.cpp new file mode 100644 index 0000000000000000000000000000000000000000..213e4642ac97122d163b27da3e5e2dc3ca14c628 --- /dev/null +++ b/src/tf/parse_squeeze.cpp @@ -0,0 +1,41 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_squeeze : op_parser +{ + std::vector operators() const { return {{"Squeeze"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + auto input_dims = args[0]->get_shape().lens(); + auto axes = info.attributes.at("squeeze_dims").list().i(); + std::vector op_axes(axes.begin(), axes.end()); + + if(op_axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 + { + for(size_t i = 0; i < input_dims.size(); i++) + { + if(input_dims.at(i) == 1) + { + op_axes.push_back(i); + } + } + } + return info.add_instruction(make_op("squeeze", {{"axes", op_axes}}), + info.make_contiguous(args[0])); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_stridedslice.cpp b/src/tf/parse_stridedslice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eac324d519c66ad8b28882b9554e763bece96fbc --- /dev/null +++ b/src/tf/parse_stridedslice.cpp @@ -0,0 +1,78 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_strideslice : op_parser +{ + std::vector operators() const { return {{"StridedSlice"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + tf_parser::node_info info, + std::vector args) const + { + auto starts = args[1]->eval().get().to_vector(); + auto ends = args[2]->eval().get().to_vector(); + auto l0 = args[0]; + size_t num_axes = l0->get_shape().lens().size(); + std::vector axes = l0->get_shape().lens(); + + std::vector op_starts(starts.begin(), starts.end()); + std::vector op_ends(ends.begin(), ends.end()); + std::vector op_axes(num_axes); + std::iota(op_axes.begin(), op_axes.end(), 0); + uint32_t begin_mask = 0; + uint32_t end_mask = 0; + uint32_t shrink_axis_mask = 0; + uint32_t bitwise_compare = 1; + std::vector squeeze_axes; + + if(contains(info.attributes, "begin_mask")) + begin_mask = static_cast(info.attributes.at("begin_mask").i()); + + if(contains(info.attributes, "end_mask")) + end_mask = static_cast(info.attributes.at("end_mask").i()); + + if(contains(info.attributes, "shrink_axis_mask")) + shrink_axis_mask = static_cast(info.attributes.at("shrink_axis_mask").i()); + + std::vector begin_axes = get_axes_from_mask(num_axes, begin_mask); + std::vector end_axes = get_axes_from_mask(num_axes, end_mask); + + for(size_t i = 0; i < num_axes; i++) + { + if(begin_axes.at(i) == 1) + { + op_starts.at(i) = 0; + } + if(end_axes.at(i) == 1) + { + op_ends.at(i) = axes.at(i); + } + } + + auto op = make_op("slice", {{"starts", op_starts}, {"ends", op_ends}, {"axes", op_axes}}); + auto l1 = info.add_instruction(op, l0); + if(shrink_axis_mask == 0) + return l1; + + for(size_t i = 0; i < num_axes; i++) + { + // the LSB corresponds to axis 0 when determining which axes to squeeze + if(((shrink_axis_mask >> i) & bitwise_compare) == 1) + squeeze_axes.push_back(i); + } + + return info.add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/parse_transpose.cpp b/src/tf/parse_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..886966ee660f3196778a5210a0d320898416bfc2 --- /dev/null +++ b/src/tf/parse_transpose.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +struct parse_transpose : op_parser +{ + std::vector operators() const { return {{"Transpose"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const tf_parser& /*parser*/, + const tf_parser::node_info& info, + std::vector args) const + { + auto perm = args[1]->eval().get().to_vector(); + std::vector dims(perm.begin(), perm.end()); + + return info.add_instruction(make_op("transpose", {{"permutation", dims}}), args.front()); + } +}; + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tf/tf.cpp b/src/tf/tf.cpp index 38477cc21ee974cadc25632ed5f68382f226d429..8ee5c900d512ad980bebed1b7e1d418431e80c67 100644 --- a/src/tf/tf.cpp +++ b/src/tf/tf.cpp @@ -1,1353 +1,26 @@ -#include -#include -#include +#include #include #include #include -#include #include #include #include #include -#include #include -#include -#include -#include -#include #include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct tf_parser -{ - using attribute_map = std::unordered_map; - using node_map = std::map; - using op_func = - std::function(attribute_map, std::vector)>; - - node_map nodes; - std::vector input_nodes; - std::unordered_map instructions; - program prog = program(); - bool is_nhwc = true; - - std::unordered_map ops; - - bool should_transpose(instruction_ref ins) const - { - return is_nhwc and ins->get_shape().lens().size() == 4; - } - - instruction_ref to_nhwc(instruction_ref ins) - { - if(should_transpose(ins)) - return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins); - return ins; - } - - instruction_ref to_nchw(instruction_ref ins) - { - if(should_transpose(ins)) - return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins); - return ins; - } - - instruction_ref to_kcxy(instruction_ref ins) - { - if(should_transpose(ins)) - return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins); - return ins; - } - - instruction_ref make_contiguous(instruction_ref ins) - { - if(ins->get_shape().standard()) - return ins; - else - return prog.add_instruction(op::contiguous{}, ins); - } - - std::vector to_nchw(const std::vector& args) - { - std::vector result(args.size()); - std::transform( - args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); }); - return result; - } - - std::vector to_nhwc(const std::vector& args) - { - std::vector result(args.size()); - std::transform( - args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); }); - return result; - } - - std::vector - parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const - { - auto attrs = attributes.at(s).list().i(); - std::vector axes; - copy(attrs.begin(), attrs.end(), std::back_inserter(axes)); - if(is_nhwc) - { - std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) { - return parse_axis(axis, num_dims); - }); - } - return axes; - } - - template - std::vector parse_axes(std::vector axes, const size_t num_dims) const - { - if(is_nhwc) - { - std::vector new_axes; - std::transform(axes.begin(), - axes.end(), - std::back_inserter(new_axes), - [&](size_t axis) { return parse_axis(axis, num_dims); }); - return new_axes; - } - return axes; - } - - // tf stores certain attributes such as strides, dilations, as a 4D input. - // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3. - // This helper function reorders the data to store for the respective operator member variables. - template - void reorder_data(std::vector& prev_data) const - { - std::vector new_data(prev_data.size()); - for(size_t i = 0; i < new_data.size(); i++) - { - auto new_idx = parse_axis(i, new_data.size()); - new_data.at(new_idx) = prev_data.at(i); - } - prev_data = new_data; - } - - template - T parse_axis(const T& dim, const size_t num_dims) const - { - T new_dim = dim; - if(is_nhwc and num_dims >= 4) - { - switch(dim) - { - case 0: new_dim = 0; break; - case 1: new_dim = 2; break; - case 2: new_dim = 3; break; - case 3: new_dim = 1; break; - default: break; - } - } - return new_dim; - } - - std::vector get_axes(size_t num_axes) const - { - std::vector axes(num_axes); - std::iota(axes.begin(), axes.end(), 0); - return axes; - } - - std::vector get_axes_from_mask(const size_t num_axes, const uint32_t mask) - { - uint32_t bitwise_compare = 1; - std::vector axes; - for(size_t i = 0; i < num_axes; i++) - { - // the LSB corresponds to axis 0 when determining which axes to begin - if(((mask >> i) & bitwise_compare) == 1) - axes.push_back(1); - else - axes.push_back(0); - } - return axes; - } - - tf_parser() - { - add_generic_op("All", op::identity{}); - add_generic_op("Identity", op::identity{}); - add_generic_op("LessEqual", op::identity{}); - add_generic_op("Relu", op::relu{}); - add_generic_op("Relu6", op::clip{6.0, 0.0}); - add_generic_op("Rsqrt", op::rsqrt{}); - add_generic_op("Tanh", op::tanh{}); - add_generic_op("StopGradient", op::identity{}); - - add_binary_op("Add", op::add{}); - add_binary_op("Mul", op::mul{}); - add_binary_op("Pow", op::pow{}); - add_binary_op("SquaredDifference", op::sqdiff{}); - add_binary_op("Sub", op::sub{}); - - add_mem_op("AvgPool", &tf_parser::parse_pooling); - add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false); - add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false); - add_mem_op("BiasAdd", &tf_parser::parse_biasadd); - add_mem_op("Cast", &tf_parser::parse_cast, false); - add_mem_op("ConcatV2", &tf_parser::parse_concat, false); - add_mem_op("Const", &tf_parser::parse_constant); - add_mem_op("Conv2D", &tf_parser::parse_conv); - add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); - add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false); - add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); - add_mem_op("GatherV2", &tf_parser::parse_gather, false); - add_mem_op("MatMul", &tf_parser::parse_matmul, false); - add_mem_op("MaxPool", &tf_parser::parse_pooling); - add_mem_op("Mean", &tf_parser::parse_mean, false); - add_mem_op("OneHot", &tf_parser::parse_onehot, false); - add_mem_op("Pack", &tf_parser::parse_pack, false); - add_mem_op("Pad", &tf_parser::parse_pad); - add_mem_op("Reshape", &tf_parser::parse_reshape, false); - add_mem_op("Slice", &tf_parser::parse_slice, false); - add_mem_op("Split", &tf_parser::parse_split, false); - add_mem_op("SplitV", &tf_parser::parse_split, false); - add_mem_op("Softmax", &tf_parser::parse_softmax, false); - add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); - add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false); - add_mem_op("Transpose", &tf_parser::parse_transpose, false); - } - - template - void add_op(const std::string& name, F f, bool transpose = true) - { - if(transpose) - { - ops.emplace( - name, - op_func{ - [=](const attribute_map& attributes, const std::vector& args) { - return std::vector{to_nhwc(f(attributes, to_nchw(args)))}; - }}); - } - else - { - ops.emplace(name, - op_func{[=](const attribute_map& attributes, - const std::vector& args) { - return std::vector{f(attributes, args)}; - }}); - } - } - - template - void add_mem_op(std::string name, F f, bool transpose = true) - { - add_op(name, - [=](auto&&... xs) { - return std::mem_fn(f)(*this, name, std::forward(xs)...); - }, - transpose); - } - - template - void add_binary_op(std::string name, T x) - { - add_op(name, - [this, x](const attribute_map&, std::vector args) { - if(args.size() != 2) - MIGRAPHX_THROW("binary operators should have 2 operands"); - // TODO - // if(contains(attributes, "data_format")) - // { - // if(is_nhwc) - // { - // l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]); - // } - // } - return add_broadcastable_binary_op(args[0], args[1], x); - }, - false); - } - - template - instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) - { - if(arg0->get_shape().lens() != arg1->get_shape().lens()) - { - // Example: - // s0 = (3,2,4,5) and s1 = (2,1,1) - // - // In this case we need to broadcast (:,1,1) portion of - // s1 plus broadcast the 1st dimension of s1 - // giving output_lens = (3,2,4,5) - // - // Another example: - // s0 = (3,2,1,5) and s1 = (2,7,5) - // In this case we need to broadcast the (:,:,1:,:) axis - // of s0 plus the 1st dimension of s1 giving - // output_lens = (3,2,7,5) - // - // Get lengths for both arguments - const std::vector* s0 = &arg0->get_shape().lens(); - const std::vector* s1 = &arg1->get_shape().lens(); - - // Make sure s0 is the smaller size - if(s0->size() > s1->size()) - std::swap(s0, s1); - - std::vector output_lens(*s1); - auto offset = s1->size() - s0->size(); - std::transform(s0->begin(), - s0->end(), - s1->begin() + offset, - output_lens.begin() + offset, - [](auto a, auto b) { return std::max(a, b); }); - - auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0); - auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1); - return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1))); - } - else - { - return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)})); - } - } - - template - void add_generic_op(std::string name, T x, bool transpose = true) - { - add_op(name, - [this, x](const attribute_map&, std::vector args) { - return prog.add_instruction(x, args); - }, - transpose); - } - - instruction_ref - parse_batchnorm(const std::string&, attribute_map attributes, std::vector args) - { - float epsilon = 1e-5f; - float momentum = 0.9f; - op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; - if(contains(attributes, "epsilon")) - { - epsilon = attributes.at("epsilon").f(); - } - op::batch_norm_inference op{epsilon, momentum, bn_mode}; - return prog.add_instruction(op, std::move(args)); - } - - instruction_ref - parse_biasadd(const std::string&, const attribute_map&, std::vector args) - { - uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) - auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]); - return prog.add_instruction(op::add{}, args[0], l0); - } - - instruction_ref - parse_cast(const std::string&, attribute_map attributes, std::vector args) - { - shape::type_t type = parse_type(attributes.at("DstT").type()); - return prog.add_instruction(op::convert{type}, std::move(args)); - } - - instruction_ref - parse_concat(const std::string&, attribute_map attributes, std::vector args) - { - // get index for axis within args - size_t axis_idx = attributes.at("N").i(); - int64_t axis = args[axis_idx]->eval().at(); - op::concat op{axis}; - // return only first N arguments (assuming last index is the axis value) - return prog.add_instruction( - op, std::vector(args.begin(), args.begin() + args.size() - 1)); - } - - instruction_ref parse_constant(const std::string&, - attribute_map attributes, - const std::vector&) - { - literal v = parse_tensor(attributes.at("value").tensor()); - return prog.add_literal(v); - } - - instruction_ref - parse_conv(const std::string&, attribute_map attributes, std::vector args) - { - op::convolution op; - if(contains(attributes, "strides")) - { - std::vector stride; - copy(attributes.at("strides").list().i(), std::back_inserter(stride)); - reorder_data(stride); - if(stride.size() != 4) - { - MIGRAPHX_THROW("strides should have 4 values"); - } - op.stride[0] = stride[2]; - op.stride[1] = stride[3]; - } - if(contains(attributes, "dilations")) - { - std::vector dilation; - copy(attributes.at("dilations").list().i(), std::back_inserter(dilation)); - reorder_data(dilation); - if(dilation.size() != 4) - { - MIGRAPHX_THROW("dilation should have 4 values"); - } - op.dilation[0] = dilation[2]; - op.dilation[1] = dilation[3]; - } - - auto weights = to_kcxy(args[1]); - auto l0 = args[0]; - if(contains(attributes, "padding")) - { - const std::string& pad_mode = attributes.at("padding").s(); - if(pad_mode.find("SAME") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::same; - std::vector weight_dims = weights->get_shape().lens(); - size_t weight_h = weight_dims[2]; - size_t weight_w = weight_dims[3]; - - auto input_dims = l0->get_shape().lens(); - std::vector pads(input_dims.size()); - calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); - calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); - - if(pads[0] != pads[2] || pads[1] != pads[3]) - { - std::vector padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; - l0 = prog.add_instruction(migraphx::op::pad{padding}, l0); - } - else - { - op.padding[0] = pads[0]; - op.padding[1] = pads[1]; - } - } - else if(pad_mode.find("VALID") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::valid; - } - else if(pad_mode.find("EXPLICIT") != std::string::npos) - { - std::vector padding; - copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding)); - if(padding.size() != 4) - { - MIGRAPHX_THROW("padding should have 4 values"); - } - if(padding[0] != padding[2] || padding[1] != padding[3]) - { - MIGRAPHX_THROW("migraphx does not support asymetric padding"); - } - op.padding[0] = padding[0]; - op.padding[1] = padding[1]; - } - } - return prog.add_instruction(op, {l0, to_kcxy(args[1])}); - } - - instruction_ref parse_depthwiseconv(const std::string&, - attribute_map attributes, - std::vector args) - { - op::convolution op; - size_t num_channels = args[0]->get_shape().lens()[1]; - op.group = num_channels; - - if(contains(attributes, "strides")) - { - std::vector stride; - copy(attributes.at("strides").list().i(), std::back_inserter(stride)); - reorder_data(stride); - if(stride.size() != 4) - { - MIGRAPHX_THROW("strides should have 4 values"); - } - op.stride[0] = stride[2]; - op.stride[1] = stride[3]; - } - - auto weights = to_kcxy(args[1]); - if(contains(attributes, "dilations")) - { - std::vector dilation; - copy(attributes.at("dilations").list().i(), std::back_inserter(dilation)); - reorder_data(dilation); - if(dilation.size() != 4) - { - MIGRAPHX_THROW("dilation should have 4 values"); - } - op.dilation[0] = dilation[2]; - op.dilation[1] = dilation[3]; - } - - auto l0 = args[0]; - if(contains(attributes, "padding")) - { - const std::string& pad_mode = attributes.at("padding").s(); - - if(pad_mode.find("SAME") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::same; - std::vector weight_dims = weights->get_shape().lens(); - size_t weight_h = weight_dims[2]; - size_t weight_w = weight_dims[3]; - - auto input_dims = l0->get_shape().lens(); - std::vector pads(input_dims.size()); - calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); - calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); - - if(pads[0] != pads[2] || pads[1] != pads[3]) - { - std::vector padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; - l0 = prog.add_instruction(migraphx::op::pad{padding}, l0); - } - else - { - op.padding[0] = pads[0]; - op.padding[1] = pads[1]; - } - } - else if(pad_mode.find("VALID") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::valid; - } - } - - std::vector new_weights_shape; - copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape)); - - // weight format is (out_channels, in_channels, h, w), but in depthwise_conv, - // out_channels is equal to the multiplier. Adjust by inserting a reshape and - // setting in_channels to 1 - int64_t multiplier = new_weights_shape[0]; - int64_t out_channels = num_channels * multiplier; - new_weights_shape[0] = out_channels; - new_weights_shape[1] = 1; - // Make sure weights are contiguous before doing reshape - auto new_weights = - prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights)); - - return prog.add_instruction(op, {l0, new_weights}); - } - - instruction_ref - parse_expanddims(const std::string&, const attribute_map&, std::vector args) - { - std::vector input_dims = args[0]->get_shape().lens(); - std::vector new_dims(input_dims.begin(), input_dims.end()); - size_t num_dims = input_dims.size(); - int32_t dim = args[1]->eval().at(); - - if(dim < 0) - { - new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); - } - else - { - new_dims.insert(new_dims.begin() + dim, 1); - } - return prog.add_instruction(op::reshape{new_dims}, args[0]); - } - - instruction_ref - parse_gather(const std::string&, const attribute_map&, std::vector args) - { - int axis = args[2]->eval().at(); - op::gather op{axis}; - return prog.add_instruction(op, {args[0], args[1]}); - } - - instruction_ref - parse_matmul(const std::string&, attribute_map attributes, std::vector args) - { - bool transa = false; - bool transb = false; - - if(contains(attributes, "transpose_a")) - { - transa = attributes.at("transpose_a").b(); - } - if(contains(attributes, "transpose_b")) - { - transb = attributes.at("transpose_b").b(); - } - - if(contains(attributes, "adj_x")) - { - transa = attributes.at("adj_x").b(); - } - if(contains(attributes, "adj_y")) - { - transb = attributes.at("adj_y").b(); - } - - std::vector perm(args[0]->get_shape().lens().size()); - std::iota(perm.begin(), perm.end(), int64_t{0}); - // swap the last two elements - std::iter_swap(perm.end() - 1, perm.end() - 2); - - auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0]; - auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; - - return prog.add_instruction(op::dot{}, l1, l2); - } - - instruction_ref - parse_mean(const std::string&, attribute_map attributes, std::vector args) - { - bool keep_dims = attributes.at("keep_dims").b(); - auto axes = args[1]->eval().get().to_vector(); - - if(keep_dims) - { - return prog.add_instruction(op::reduce_mean{axes}, args[0]); - } - else - { - auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]); - return prog.add_instruction(op::squeeze{axes}, ins); - } - } - - instruction_ref - parse_onehot(const std::string&, attribute_map attributes, std::vector args) - { - size_t depth = static_cast(args[1]->eval().at()); - - int64_t axis = -1; - float on_value = args[2]->eval().at(); - float off_value = args[3]->eval().at(); - - std::vector depth_input(depth * depth, off_value); - for(int i = 0; i < depth; i++) - { - depth_input[depth * i + i] = on_value; - } - - if(contains(attributes, "axis")) - axis = attributes.at("axis").i(); - if(axis == -1) - { - shape s{shape::float_type, {depth, depth}}; - auto l0 = prog.add_literal({s, depth_input}); - return prog.add_instruction(op::gather{0}, {l0, args[0]}); - } - MIGRAPHX_THROW("MIGraphX does not support axis != -1"); - } - - instruction_ref parse_pack(const std::string&, - const attribute_map& attributes, - std::vector args) - { - // reinterpret as unsqueeze with concat - std::vector unsqueezed_args; - int64_t axis = 0; - if(contains(attributes, "axis")) - axis = attributes.at("axis").i(); - size_t input_size = args.front()->get_shape().lens().size(); - if(axis > input_size) - { - MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) + - " must be smaller than input size " + to_string(input_size)); - } - - std::transform( - args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); }); - return to_nhwc(prog.add_instruction(op::concat{axis}, unsqueezed_args)); - } - - instruction_ref - parse_pad(const std::string&, const attribute_map&, std::vector args) - { - size_t ndims = args.front()->get_shape().lens().size(); - - // in tf, the paddings are arranged as a 2d shape (ndims, 2), - // the last dim contains the left padding and right padding respectively - std::vector> pad_per_dim(ndims); - auto tf_padding = args[1]->eval().get().to_vector(); - for(size_t i = 0; i < 2 * ndims; i += 2) - { - pad_per_dim[i / 2].first = tf_padding[i]; - pad_per_dim[i / 2].second = tf_padding[i + 1]; - } - reorder_data(pad_per_dim); - - op::pad op; - std::vector pads(ndims * 2); - for(size_t i = 0; i < ndims; i++) - { - pads[i] = pad_per_dim[i].first; - pads[i + ndims] = pad_per_dim[i].second; - } - op.pads = pads; - return prog.add_instruction(op, args.front()); - } - - instruction_ref parse_pooling(const std::string& name, - attribute_map attributes, - std::vector args) - { - op::pooling op{starts_with(name, "Max") ? "max" : "average"}; - - if(contains(attributes, "strides")) - { - std::vector stride; - copy(attributes.at("strides").list().i(), std::back_inserter(stride)); - reorder_data(stride); - if(stride.size() != 4) - { - MIGRAPHX_THROW("strides should have 4 values"); - } - op.stride[0] = stride[2]; - op.stride[1] = stride[3]; - } - if(contains(attributes, "ksize")) - { - std::vector ksize; - copy(attributes.at("ksize").list().i(), std::back_inserter(ksize)); - reorder_data(ksize); - if(ksize.size() != 4) - { - MIGRAPHX_THROW("ksize should have 4 values"); - } - op.lengths[0] = ksize[2]; - op.lengths[1] = ksize[3]; - } - - auto l0 = args[0]; - if(contains(attributes, "padding")) - { - const std::string& pad_mode = attributes.at("padding").s(); - if(pad_mode.find("SAME") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::same; - auto input_dims = l0->get_shape().lens(); - std::vector pads(input_dims.size()); - calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]); - calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]); - - if(pads[0] != pads[2] || pads[1] != pads[3]) - { - std::vector padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; - l0 = prog.add_instruction( - migraphx::op::pad{padding, std::numeric_limits::lowest()}, l0); - } - else - { - op.padding[0] = pads[0]; - op.padding[1] = pads[1]; - } - } - else if(pad_mode.find("VALID") != std::string::npos) - { - op.padding_mode = op::padding_mode_t::valid; - } - } - return prog.add_instruction(op, l0); - } - - instruction_ref - parse_reshape(const std::string&, const attribute_map&, std::vector args) - { - op::reshape op; - if(args.size() != 2) - MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)"); - auto s = args[1]->eval(); - s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); - return prog.add_instruction(op, make_contiguous(args[0])); - } - - void parse_from(std::istream& is) - { - tensorflow::GraphDef graph; - if(graph.ParseFromIstream(&is)) - { - this->parse_graph(graph); - } - else - { - throw std::runtime_error("Failed reading tf file"); - } - } - - instruction_ref - parse_slice(const std::string&, const attribute_map&, std::vector args) - { - op::slice op; - auto starts = args[1]->eval().get().to_vector(); - auto size = args[2]->eval().get().to_vector(); - auto axes = args[0]->get_shape().lens(); - size_t num_axes = axes.size(); - - op.starts = std::vector(starts.begin(), starts.end()); - op.ends = std::vector(num_axes); - op.axes = std::vector(num_axes); - std::iota(op.axes.begin(), op.axes.end(), 0); - for(size_t i = 0; i < num_axes; i++) - { - if(size[i] == -1) - op.ends[i] = axes[i]; - else - op.ends[i] = starts[i] + size[i]; - } - return prog.add_instruction(op, make_contiguous(args[0])); - } - - // template to facilitate the logsoftmax later - template - instruction_ref parse_softmax(const std::string&, - const attribute_map& attributes, - std::vector args) - { - int axis = -1; - auto num_dims = args[0]->get_shape().lens().size(); - if(contains(attributes, "axis")) - { - axis = static_cast(attributes.at("axis").i()); - } - if(axis < 0) - { - axis += num_dims; - } - - return prog.add_instruction(Op{axis}, make_contiguous(args[0])); - } - - std::vector parse_split(const std::string&, - const attribute_map& attributes, - std::vector args) - { - bool vector_as_input = args.size() == 3; - int num_outputs = 1; - auto axis_arg = args[0]; - auto input_arg = args[1]; - if(vector_as_input) - { - input_arg = args[0]; - axis_arg = args[2]; - } - - if(contains(attributes, "num_split")) - num_outputs = attributes.at("num_split").i(); - - std::vector splits(num_outputs); - std::vector slice_pos{0}; - if(vector_as_input) - { - splits = args[1]->eval().get().to_vector(); - num_outputs = splits.size(); - } - - assert(num_outputs > 0); - - if(num_outputs == 1) - return std::vector{prog.add_instruction(op::identity{}, input_arg)}; - - auto lens = input_arg->get_shape().lens(); - auto num_dims = lens.size(); - int axis = axis_arg->eval().at(); - - // ensure split is made evenly if "num_split" is used - assert(vector_as_input or lens[axis] % num_outputs == 0); - - auto split_size = lens[axis] / num_outputs; - - // push back first end point of slice - if(vector_as_input) - { - slice_pos.push_back(splits[0]); - } - else - { - slice_pos.push_back(split_size); - } - - // calculate remaining end points for each slice - for(auto i = 1; i < num_outputs; i++) - { - if(vector_as_input) - { - splits[i] += splits[i - 1]; - slice_pos.push_back(splits[i]); - } - else - { - slice_pos.push_back((i + 1) * split_size); - } - } - std::vector result; - for(auto i = 0; i < num_outputs; i++) - { - op::slice op; - op.axes = std::vector(num_dims); - std::iota(op.axes.begin(), op.axes.end(), 0); - op.starts = std::vector(num_dims, 0); - op.ends = std::vector(lens.begin(), lens.end()); - - op.starts[axis] = slice_pos[i]; - op.ends[axis] = slice_pos[i + 1]; - result.push_back(prog.add_instruction(op, input_arg)); - } - return result; - } - - instruction_ref parse_squeeze(const std::string&, - const attribute_map& attributes, - std::vector args) - { - op::squeeze op; - auto input_dims = args[0]->get_shape().lens(); - auto axes = attributes.at("squeeze_dims").list().i(); - copy(axes, std::back_inserter(op.axes)); - - if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 - { - for(size_t i = 0; i < input_dims.size(); i++) - { - if(input_dims.at(i) == 1) - { - op.axes.push_back(i); - } - } - } - return prog.add_instruction(op, make_contiguous(args[0])); - } - - instruction_ref parse_stridedslice(const std::string&, - const attribute_map& attributes, - std::vector args) - { - op::slice op; - auto starts = args[1]->eval().get().to_vector(); - auto ends = args[2]->eval().get().to_vector(); - auto l0 = args[0]; - size_t num_axes = l0->get_shape().lens().size(); - std::vector axes = l0->get_shape().lens(); - - op.starts = std::vector(starts.begin(), starts.end()); - op.ends = std::vector(ends.begin(), ends.end()); - op.axes = std::vector(num_axes); - std::iota(op.axes.begin(), op.axes.end(), 0); - uint32_t begin_mask = 0; - uint32_t end_mask = 0; - uint32_t shrink_axis_mask = 0; - uint32_t bitwise_compare = 1; - std::vector squeeze_axes; - - if(contains(attributes, "begin_mask")) - begin_mask = static_cast(attributes.at("begin_mask").i()); - - if(contains(attributes, "end_mask")) - end_mask = static_cast(attributes.at("end_mask").i()); - - if(contains(attributes, "shrink_axis_mask")) - shrink_axis_mask = static_cast(attributes.at("shrink_axis_mask").i()); - - std::vector begin_axes = get_axes_from_mask(num_axes, begin_mask); - std::vector end_axes = get_axes_from_mask(num_axes, end_mask); - - for(size_t i = 0; i < num_axes; i++) - { - if(begin_axes.at(i) == 1) - { - op.starts.at(i) = 0; - } - if(end_axes.at(i) == 1) - { - op.ends.at(i) = axes.at(i); - } - } - - auto l1 = prog.add_instruction(op, l0); - if(shrink_axis_mask == 0) - return l1; - - for(size_t i = 0; i < num_axes; i++) - { - // the LSB corresponds to axis 0 when determining which axes to squeeze - if(((shrink_axis_mask >> i) & bitwise_compare) == 1) - squeeze_axes.push_back(i); - } - - return prog.add_instruction(op::squeeze{squeeze_axes}, l1); - } - - instruction_ref - parse_transpose(const std::string&, const attribute_map&, std::vector args) - { - auto perm = args[1]->eval().get().to_vector(); - op::transpose op; - op.dims = std::vector(perm.begin(), perm.end()); - - return prog.add_instruction(op, args.front()); - } - - void parse_graph(const tensorflow::GraphDef& graph) - { - nodes = get_nodes(graph, input_nodes); - for(auto&& input : input_nodes) - { - const std::string& name = input.name(); - attribute_map input_attrs = get_attributes(input); - shape::type_t shape_type = parse_type(input_attrs.at("dtype").type()); - std::vector dims = parse_dims(input_attrs.at("shape").shape()); - if(is_nhwc and dims.size() >= 4) - { - reorder_data(dims); - } - shape s = shape{shape_type, dims}; - instructions[name] = to_nhwc(prog.add_parameter(name, s)); - } - for(auto&& p : nodes) - { - this->parse_node(p.first); - } - } - - void parse_node(const std::string& name) - { - if(instructions.count(name) == 0) - { - auto&& node = nodes.at(name); - // assert ops ignored - if(node.op() == "Assert" or contains(name, "Assert")) - return; - std::vector args; - - for(auto&& input : node.input()) - { - // control dependencies (signified by ^ before the name) are ignored - if(contains(input, "^")) - continue; - if(nodes.count(input) > 0) - { - std::string iname; - // input was from a node with multiple outputs - if(contains(input, ':')) - { - iname = input.substr(0, input.find(':')); - } - else - { - iname = get_name(nodes.at(input)); - } - assert(name != iname); - this->parse_node(iname); - args.push_back(instructions.at(input)); - } - else - { - args.push_back(instructions.at(input)); - } - } - - std::vector result; - if(ops.count(node.op()) == 0) - { - result.push_back(prog.add_instruction(op::unknown{node.op()}, args)); - } - else - { - result = ops[node.op()](get_attributes(node), args); - } - - assert(!result.empty()); - // First output has no ":" delimiter - instructions[name] = result.front(); - for(size_t i = 1; i < result.size(); i++) - { - instructions[name + ":" + std::to_string(i)] = result.at(i); - } - } - } - - static attribute_map get_attributes(const tensorflow::NodeDef& node) - { - attribute_map result; - for(auto&& attr : node.attr()) - { - result[attr.first] = attr.second; - } - return result; - } - - static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); } - - static node_map get_nodes(const tensorflow::GraphDef& graph, - std::vector& input_nodes) - { - node_map result; - for(auto&& node : graph.node()) - { - auto node_name = get_name(node); - // assume each node in graph has an associated name - if(node_name.empty()) - MIGRAPHX_THROW("tf node with no name found"); - result[node_name] = node; - if(node.op() == "Placeholder") - { - input_nodes.push_back(node); - } - } - return result; - } - - static shape::type_t parse_type(const tensorflow::DataType t) - { - shape::type_t shape_type{}; - switch(t) - { - case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break; - case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break; - case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break; - case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break; - case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break; - case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break; - case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break; - case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break; - case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break; - case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break; - - case tensorflow::DataType::DT_INVALID: - case tensorflow::DataType::DT_UINT8: - case tensorflow::DataType::DT_STRING: - case tensorflow::DataType::DT_COMPLEX64: - case tensorflow::DataType::DT_BOOL: - case tensorflow::DataType::DT_QINT8: - case tensorflow::DataType::DT_QUINT8: - case tensorflow::DataType::DT_QINT32: - case tensorflow::DataType::DT_BFLOAT16: - case tensorflow::DataType::DT_QINT16: - case tensorflow::DataType::DT_QUINT16: - case tensorflow::DataType::DT_COMPLEX128: - case tensorflow::DataType::DT_RESOURCE: - case tensorflow::DataType::DT_VARIANT: - // tf pb should not use these types - case tensorflow::DataType::DT_FLOAT_REF: - case tensorflow::DataType::DT_DOUBLE_REF: - case tensorflow::DataType::DT_INT32_REF: - case tensorflow::DataType::DT_UINT8_REF: - case tensorflow::DataType::DT_INT16_REF: - case tensorflow::DataType::DT_INT8_REF: - case tensorflow::DataType::DT_STRING_REF: - case tensorflow::DataType::DT_COMPLEX64_REF: - case tensorflow::DataType::DT_INT64_REF: - case tensorflow::DataType::DT_BOOL_REF: - case tensorflow::DataType::DT_QINT8_REF: - case tensorflow::DataType::DT_QUINT8_REF: - case tensorflow::DataType::DT_QINT32_REF: - case tensorflow::DataType::DT_BFLOAT16_REF: - case tensorflow::DataType::DT_QINT16_REF: - case tensorflow::DataType::DT_QUINT16_REF: - case tensorflow::DataType::DT_UINT16_REF: - case tensorflow::DataType::DT_COMPLEX128_REF: - case tensorflow::DataType::DT_HALF_REF: - case tensorflow::DataType::DT_RESOURCE_REF: - case tensorflow::DataType::DT_VARIANT_REF: - case tensorflow::DataType::DT_UINT32_REF: - case tensorflow::DataType::DT_UINT64_REF: - case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: - case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break; - } - return shape_type; - } - - static literal parse_tensor(const tensorflow::TensorProto& t) - { - std::vector dims = parse_dims(t.tensor_shape()); - size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); - if(!t.tensor_content().empty()) // has raw data - { - const std::string& s = t.tensor_content(); - switch(t.dtype()) - { - case tensorflow::DataType::DT_FLOAT: - return literal{{shape::float_type, dims}, s.data()}; - case tensorflow::DataType::DT_BOOL: - case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()}; - case tensorflow::DataType::DT_UINT16: - case tensorflow::DataType::DT_INT16: - return literal{{shape::int16_type, dims}, s.data()}; - case tensorflow::DataType::DT_INT32: - return literal{{shape::int32_type, dims}, s.data()}; - case tensorflow::DataType::DT_INT64: - return literal{{shape::int64_type, dims}, s.data()}; - case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; - case tensorflow::DataType::DT_DOUBLE: - return literal{{shape::double_type, dims}, s.data()}; - case tensorflow::DataType::DT_INVALID: - case tensorflow::DataType::DT_UINT8: - case tensorflow::DataType::DT_STRING: - case tensorflow::DataType::DT_UINT32: - case tensorflow::DataType::DT_UINT64: - case tensorflow::DataType::DT_COMPLEX64: - case tensorflow::DataType::DT_COMPLEX128: - case tensorflow::DataType::DT_QINT8: - case tensorflow::DataType::DT_QUINT8: - case tensorflow::DataType::DT_QINT32: - case tensorflow::DataType::DT_BFLOAT16: - case tensorflow::DataType::DT_QINT16: - case tensorflow::DataType::DT_QUINT16: - case tensorflow::DataType::DT_RESOURCE: - case tensorflow::DataType::DT_VARIANT: - case tensorflow::DataType::DT_FLOAT_REF: - case tensorflow::DataType::DT_DOUBLE_REF: - case tensorflow::DataType::DT_INT32_REF: - case tensorflow::DataType::DT_UINT8_REF: - case tensorflow::DataType::DT_INT16_REF: - case tensorflow::DataType::DT_INT8_REF: - case tensorflow::DataType::DT_STRING_REF: - case tensorflow::DataType::DT_COMPLEX64_REF: - case tensorflow::DataType::DT_INT64_REF: - case tensorflow::DataType::DT_BOOL_REF: - case tensorflow::DataType::DT_QINT8_REF: - case tensorflow::DataType::DT_QUINT8_REF: - case tensorflow::DataType::DT_QINT32_REF: - case tensorflow::DataType::DT_BFLOAT16_REF: - case tensorflow::DataType::DT_QINT16_REF: - case tensorflow::DataType::DT_QUINT16_REF: - case tensorflow::DataType::DT_UINT16_REF: - case tensorflow::DataType::DT_COMPLEX128_REF: - case tensorflow::DataType::DT_HALF_REF: - case tensorflow::DataType::DT_RESOURCE_REF: - case tensorflow::DataType::DT_VARIANT_REF: - case tensorflow::DataType::DT_UINT32_REF: - case tensorflow::DataType::DT_UINT64_REF: - case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: - case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: - throw std::runtime_error(""); - } - MIGRAPHX_THROW("Invalid tensor type"); - } - switch(t.dtype()) - { - case tensorflow::DataType::DT_FLOAT: - return create_literal( - shape::float_type, dims, get_data_vals(t.float_val(), shape_size)); - case tensorflow::DataType::DT_INT8: - return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size)); - case tensorflow::DataType::DT_UINT16: - return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size)); - case tensorflow::DataType::DT_INT16: - return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size)); - case tensorflow::DataType::DT_INT32: - return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size)); - case tensorflow::DataType::DT_INT64: - return create_literal( - shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); - case tensorflow::DataType::DT_BOOL: - return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); - case tensorflow::DataType::DT_HALF: - { - std::vector data_int32 = get_data_vals(t.half_val(), shape_size); - std::vector data_uint16(data_int32.begin(), data_int32.end()); - std::vector data_half; - std::transform(data_uint16.begin(), - data_uint16.end(), - std::back_inserter(data_half), - [](uint16_t raw_val) { return *reinterpret_cast(&raw_val); }); - return create_literal(shape::half_type, dims, data_half); - } - case tensorflow::DataType::DT_DOUBLE: - return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)}; - case tensorflow::DataType::DT_INVALID: - case tensorflow::DataType::DT_UINT8: - case tensorflow::DataType::DT_STRING: - case tensorflow::DataType::DT_UINT32: - case tensorflow::DataType::DT_UINT64: - case tensorflow::DataType::DT_COMPLEX64: - case tensorflow::DataType::DT_COMPLEX128: - case tensorflow::DataType::DT_QINT8: - case tensorflow::DataType::DT_QUINT8: - case tensorflow::DataType::DT_QINT32: - case tensorflow::DataType::DT_BFLOAT16: - case tensorflow::DataType::DT_QINT16: - case tensorflow::DataType::DT_QUINT16: - case tensorflow::DataType::DT_RESOURCE: - case tensorflow::DataType::DT_VARIANT: - case tensorflow::DataType::DT_FLOAT_REF: - case tensorflow::DataType::DT_DOUBLE_REF: - case tensorflow::DataType::DT_INT32_REF: - case tensorflow::DataType::DT_UINT8_REF: - case tensorflow::DataType::DT_INT16_REF: - case tensorflow::DataType::DT_INT8_REF: - case tensorflow::DataType::DT_STRING_REF: - case tensorflow::DataType::DT_COMPLEX64_REF: - case tensorflow::DataType::DT_INT64_REF: - case tensorflow::DataType::DT_BOOL_REF: - case tensorflow::DataType::DT_QINT8_REF: - case tensorflow::DataType::DT_QUINT8_REF: - case tensorflow::DataType::DT_QINT32_REF: - case tensorflow::DataType::DT_BFLOAT16_REF: - case tensorflow::DataType::DT_QINT16_REF: - case tensorflow::DataType::DT_QUINT16_REF: - case tensorflow::DataType::DT_UINT16_REF: - case tensorflow::DataType::DT_COMPLEX128_REF: - case tensorflow::DataType::DT_HALF_REF: - case tensorflow::DataType::DT_RESOURCE_REF: - case tensorflow::DataType::DT_VARIANT_REF: - case tensorflow::DataType::DT_UINT32_REF: - case tensorflow::DataType::DT_UINT64_REF: - case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: - case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: - throw std::runtime_error(""); - } - MIGRAPHX_THROW("Invalid tensor type"); - } - - template - static std::vector get_data_vals(const google::protobuf::RepeatedField& data, - const size_t& shape_size) - { - std::vector data_vals(shape_size); - // check if shape has enough data values given existing fields - if(data.size() == 1) - { - std::fill(data_vals.begin(), data_vals.end(), data[0]); - } - else - copy(data.begin(), data.end(), std::back_inserter(data_vals)); - return data_vals; - } - - static std::vector parse_dims(const tensorflow::TensorShapeProto& s) - { - std::vector dims; - auto input_dims = s.dim(); - std::transform(input_dims.begin(), - input_dims.end(), - std::back_inserter(dims), - [](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); }); - return dims; - } - - template - static literal - create_literal(shape::type_t shape_type, const std::vector& dims, std::vector data) - { - // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar - if(dims.empty() or (dims.size() == 1 and dims.front() == 1)) - return literal{{shape_type}, data}; - return literal{{shape_type, dims}, data}; - } -}; - -program parse_tf(const std::string& name, bool is_nhwc) +program parse_tf(const std::string& name, const tf_options& options) { std::fstream input(name.c_str(), std::ios::in | std::ios::binary); - tf_parser parser; - parser.is_nhwc = is_nhwc; + tf::tf_parser parser; + parser.is_nhwc = options.is_nhwc; + parser.batch_size = options.batch_size; + parser.map_input_dims = options.map_input_dims; + parser.output_node_names = options.output_node_names; #ifndef NDEBUG // Log the program when it can't be parsed @@ -1363,7 +36,6 @@ program parse_tf(const std::string& name, bool is_nhwc) #else parser.parse_from(input); #endif - parser.to_nchw(std::prev(parser.prog.end())); return std::move(parser.prog); } diff --git a/src/tf/tf_parser.cpp b/src/tf/tf_parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8ffae2eaf867d84f0670d3908912bceda863d02 --- /dev/null +++ b/src/tf/tf_parser.cpp @@ -0,0 +1,566 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace tf { + +bool tf_parser::should_transpose(instruction_ref ins) const +{ + return is_nhwc and ins->get_shape().lens().size() == 4; +} + +instruction_ref tf_parser::to_nhwc(instruction_ref ins) const +{ + if(should_transpose(ins)) + return mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ins); + return ins; +} + +instruction_ref tf_parser::to_nchw(instruction_ref ins) const +{ + if(should_transpose(ins)) + return mm->add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), ins); + return ins; +} + +instruction_ref tf_parser::to_kcxy(instruction_ref ins) const +{ + return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins); +} + +std::vector tf_parser::to_nchw(const std::vector& args) const +{ + std::vector result(args.size()); + std::transform( + args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); }); + return result; +} + +std::vector tf_parser::to_nhwc(const std::vector& args) const +{ + std::vector result(args.size()); + std::transform( + args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); }); + return result; +} + +instruction_ref tf_parser::node_info::make_contiguous(instruction_ref ins) const +{ + if(ins->get_shape().standard()) + return ins; + else + return mm->add_instruction(make_op("contiguous"), ins); +} + +instruction_ref tf_parser::node_info::add_broadcastable_binary_op(const std::string& op_name, + instruction_ref arg0, + instruction_ref arg1) const +{ + return this->add_common_op(op_name, arg0, arg1); +} + +instruction_ref tf_parser::node_info::add_common_op(const std::string& op_name, + std::vector inputs) const +{ + return migraphx::add_common_op(*mm, make_op(op_name), std::move(inputs)); +} + +int64_t tf_parser::parse_axis(const int64_t dim, const size_t num_dims) const +{ + int64_t new_dim = dim; + if(is_nhwc and num_dims >= 4) + { + switch(dim) + { + case 0: new_dim = 0; break; + case 1: new_dim = 2; break; + case 2: new_dim = 3; break; + case 3: new_dim = 1; break; + default: break; + } + } + return new_dim; +} + +instruction_ref +tf_parser::node_info::add_instruction(const operation& op, + const std::vector& args) const +{ + return mm->add_instruction(op, args); +} + +instruction_ref tf_parser::node_info::add_literal(literal l) const +{ + return mm->add_literal(std::move(l)); +} + +std::vector get_axes_from_mask(const size_t num_axes, const uint32_t mask) +{ + uint32_t bitwise_compare = 1; + std::vector axes; + for(size_t i = 0; i < num_axes; i++) + { + // the LSB corresponds to axis 0 when determining which axes to begin + if(((mask >> i) & bitwise_compare) == 1) + axes.push_back(1); + else + axes.push_back(0); + } + return axes; +} + +tf_parser::tf_parser() +{ + // Add all registered op parsers + for(auto&& name : get_op_parsers()) + ops.emplace(name, get_op_parser(name)); +} + +static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); } + +static tf_parser::node_map get_nodes(const tensorflow::GraphDef& graph, + std::vector& input_nodes) +{ + tf_parser::node_map result; + for(auto&& node : graph.node()) + { + auto node_name = get_name(node); + // assume each node in graph has an associated name + if(node_name.empty()) + MIGRAPHX_THROW("tf node with no name found"); + result[node_name] = node; + if(node.op() == "Placeholder") + { + input_nodes.push_back(node); + } + } + return result; +} + +static tf_parser::attribute_map get_attributes(const tensorflow::NodeDef& node) +{ + tf_parser::attribute_map result; + for(auto&& attr : node.attr()) + { + result[attr.first] = attr.second; + } + + return result; +} + +static std::vector parse_dims(const tensorflow::TensorShapeProto& s) +{ + std::vector dims; + auto input_dims = s.dim(); + std::transform(input_dims.begin(), + input_dims.end(), + std::back_inserter(dims), + [](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); }); + return dims; +} + +template +static std::vector get_data_vals(const google::protobuf::RepeatedField& data, + const size_t& shape_size) +{ + std::vector data_vals(shape_size); + // check if shape has enough data values given existing fields + if(data.size() == 1) + { + std::fill(data_vals.begin(), data_vals.end(), data[0]); + } + else + copy(data.begin(), data.end(), std::back_inserter(data_vals)); + return data_vals; +} + +template +static literal +create_literal(shape::type_t shape_type, const std::vector& dims, std::vector data) +{ + // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar + if(dims.empty() or (dims.size() == 1 and dims.front() == 1)) + return literal{{shape_type}, data}; + return literal{{shape_type, dims}, data}; +} + +static bool is_valid_op(const tensorflow::NodeDef& node) +{ + std::vector ignored{"NoOp", "Assert"}; + return none_of(ignored, [&](const auto& op) { + const auto& name = get_name(node); + return node.op() == op or contains(name, op); + }); +} + +std::vector tf_parser::find_outputs() const +{ + std::unordered_set inputs; + for(auto&& p : nodes) + { + auto&& node = p.second; + std::copy(node.input().begin(), node.input().end(), std::inserter(inputs, inputs.end())); + } + std::vector outputs; + for(auto&& p : nodes) + { + const auto& name = p.first; + const auto& node = p.second; + if(not is_valid_op(node)) + continue; + // control flow related, ignore this node + if(contains(name, "^")) + continue; + // literals are valid ops, but they are not outputs unless specified + if(node.op() == "Const") + continue; + if(inputs.count(name) == 0) + outputs.push_back(name); + } + return outputs; +} + +void tf_parser::parse_graph(const tensorflow::GraphDef& graph) +{ + nodes = get_nodes(graph, input_nodes); + for(auto&& input : input_nodes) + { + const std::string& name = input.name(); + attribute_map input_attrs = get_attributes(input); + shape::type_t shape_type = parse_type(input_attrs.at("dtype").type()); + std::vector dims = parse_dims(input_attrs.at("shape").shape()); + + if(contains(map_input_dims, name)) + { + dims = map_input_dims.at(name); + } + else + { + if(is_nhwc and dims.size() >= 4) + { + this->reorder_data(dims); + } + std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) { + return static_cast(dim) <= 0 ? batch_size : dim; + }); + } + + shape s = shape{shape_type, dims}; + instructions[name] = to_nhwc(mm->add_parameter(name, s)); + } + for(auto&& p : nodes) + { + this->parse_node(p.first); + } + auto last_ins = std::prev(mm->end()); + if(last_ins != mm->end()) + { + // Needs to add a ret instruction at the end of + // the program + if(output_node_names.empty()) + { + output_node_names = find_outputs(); + } + + std::vector output_ins; + std::transform(output_node_names.begin(), + output_node_names.end(), + std::back_inserter(output_ins), + [&](auto output_name) { + if(not contains(instructions, output_name)) + MIGRAPHX_THROW("PARSE_TF: output name " + output_name + + " not found in graph!"); + return this->to_nchw(instructions[output_name]); + }); + mm->add_return(output_ins); + } +} + +void tf_parser::parse_node(const std::string& name) +{ + if(instructions.count(name) == 0) + { + auto&& node = nodes.at(name); + if(not is_valid_op(node)) + return; + + std::vector args; + + for(auto&& input : node.input()) + { + // control dependencies (signified by ^ before the name) are ignored + if(contains(input, "^")) + continue; + if(nodes.count(input) > 0) + { + std::string iname; + // input was from a node with multiple outputs + if(contains(input, ':')) + { + iname = input.substr(0, input.find(':')); + } + else + { + iname = get_name(nodes.at(input)); + } + assert(name != iname); + this->parse_node(iname); + args.push_back(instructions.at(input)); + } + else + { + args.push_back(instructions.at(input)); + } + } + std::vector result; + if(ops.count(node.op()) == 0) + { + result.push_back(mm->add_instruction(op::unknown{node.op()}, args)); + } + else + { + result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args); + } + assert(!result.empty()); + // First output has no ":" delimiter + instructions[name] = result.front(); + for(size_t i = 1; i < result.size(); i++) + { + instructions[name + ":" + std::to_string(i)] = result.at(i); + } + } +} + +void tf_parser::parse_from(std::istream& is) +{ + tensorflow::GraphDef graph; + if(graph.ParseFromIstream(&is)) + { + this->parse_graph(graph); + } + else + { + throw std::runtime_error("Failed reading tf file"); + } +} + +shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const +{ + shape::type_t shape_type{}; + switch(t) + { + case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break; + case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break; + case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break; + case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break; + case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break; + case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break; + case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break; + case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break; + case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break; + case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break; + + case tensorflow::DataType::DT_INVALID: + case tensorflow::DataType::DT_UINT8: + case tensorflow::DataType::DT_STRING: + case tensorflow::DataType::DT_COMPLEX64: + case tensorflow::DataType::DT_BOOL: + case tensorflow::DataType::DT_QINT8: + case tensorflow::DataType::DT_QUINT8: + case tensorflow::DataType::DT_QINT32: + case tensorflow::DataType::DT_BFLOAT16: + case tensorflow::DataType::DT_QINT16: + case tensorflow::DataType::DT_QUINT16: + case tensorflow::DataType::DT_COMPLEX128: + case tensorflow::DataType::DT_RESOURCE: + case tensorflow::DataType::DT_VARIANT: + // tf pb should not use these types + case tensorflow::DataType::DT_FLOAT_REF: + case tensorflow::DataType::DT_DOUBLE_REF: + case tensorflow::DataType::DT_INT32_REF: + case tensorflow::DataType::DT_UINT8_REF: + case tensorflow::DataType::DT_INT16_REF: + case tensorflow::DataType::DT_INT8_REF: + case tensorflow::DataType::DT_STRING_REF: + case tensorflow::DataType::DT_COMPLEX64_REF: + case tensorflow::DataType::DT_INT64_REF: + case tensorflow::DataType::DT_BOOL_REF: + case tensorflow::DataType::DT_QINT8_REF: + case tensorflow::DataType::DT_QUINT8_REF: + case tensorflow::DataType::DT_QINT32_REF: + case tensorflow::DataType::DT_BFLOAT16_REF: + case tensorflow::DataType::DT_QINT16_REF: + case tensorflow::DataType::DT_QUINT16_REF: + case tensorflow::DataType::DT_UINT16_REF: + case tensorflow::DataType::DT_COMPLEX128_REF: + case tensorflow::DataType::DT_HALF_REF: + case tensorflow::DataType::DT_RESOURCE_REF: + case tensorflow::DataType::DT_VARIANT_REF: + case tensorflow::DataType::DT_UINT32_REF: + case tensorflow::DataType::DT_UINT64_REF: + case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: + case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break; + } + return shape_type; +} + +literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const +{ + std::vector dims = parse_dims(t.tensor_shape()); + size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); + if(!t.tensor_content().empty()) // has raw data + { + const std::string& s = t.tensor_content(); + switch(t.dtype()) + { + case tensorflow::DataType::DT_FLOAT: return literal{{shape::float_type, dims}, s.data()}; + case tensorflow::DataType::DT_BOOL: + case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()}; + case tensorflow::DataType::DT_UINT16: + case tensorflow::DataType::DT_INT16: return literal{{shape::int16_type, dims}, s.data()}; + case tensorflow::DataType::DT_INT32: return literal{{shape::int32_type, dims}, s.data()}; + case tensorflow::DataType::DT_INT64: return literal{{shape::int64_type, dims}, s.data()}; + case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; + case tensorflow::DataType::DT_DOUBLE: return literal{{shape::double_type, dims}, s.data()}; + case tensorflow::DataType::DT_INVALID: + case tensorflow::DataType::DT_UINT8: + case tensorflow::DataType::DT_STRING: + case tensorflow::DataType::DT_UINT32: + case tensorflow::DataType::DT_UINT64: + case tensorflow::DataType::DT_COMPLEX64: + case tensorflow::DataType::DT_COMPLEX128: + case tensorflow::DataType::DT_QINT8: + case tensorflow::DataType::DT_QUINT8: + case tensorflow::DataType::DT_QINT32: + case tensorflow::DataType::DT_BFLOAT16: + case tensorflow::DataType::DT_QINT16: + case tensorflow::DataType::DT_QUINT16: + case tensorflow::DataType::DT_RESOURCE: + case tensorflow::DataType::DT_VARIANT: + case tensorflow::DataType::DT_FLOAT_REF: + case tensorflow::DataType::DT_DOUBLE_REF: + case tensorflow::DataType::DT_INT32_REF: + case tensorflow::DataType::DT_UINT8_REF: + case tensorflow::DataType::DT_INT16_REF: + case tensorflow::DataType::DT_INT8_REF: + case tensorflow::DataType::DT_STRING_REF: + case tensorflow::DataType::DT_COMPLEX64_REF: + case tensorflow::DataType::DT_INT64_REF: + case tensorflow::DataType::DT_BOOL_REF: + case tensorflow::DataType::DT_QINT8_REF: + case tensorflow::DataType::DT_QUINT8_REF: + case tensorflow::DataType::DT_QINT32_REF: + case tensorflow::DataType::DT_BFLOAT16_REF: + case tensorflow::DataType::DT_QINT16_REF: + case tensorflow::DataType::DT_QUINT16_REF: + case tensorflow::DataType::DT_UINT16_REF: + case tensorflow::DataType::DT_COMPLEX128_REF: + case tensorflow::DataType::DT_HALF_REF: + case tensorflow::DataType::DT_RESOURCE_REF: + case tensorflow::DataType::DT_VARIANT_REF: + case tensorflow::DataType::DT_UINT32_REF: + case tensorflow::DataType::DT_UINT64_REF: + case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: + case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: + throw std::runtime_error(""); + } + MIGRAPHX_THROW("Invalid tensor type"); + } + switch(t.dtype()) + { + case tensorflow::DataType::DT_FLOAT: + return create_literal(shape::float_type, dims, get_data_vals(t.float_val(), shape_size)); + case tensorflow::DataType::DT_INT8: + return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size)); + case tensorflow::DataType::DT_UINT16: + return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size)); + case tensorflow::DataType::DT_INT16: + return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size)); + case tensorflow::DataType::DT_INT32: + return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size)); + case tensorflow::DataType::DT_INT64: + return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); + case tensorflow::DataType::DT_BOOL: + return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); + case tensorflow::DataType::DT_HALF: { + std::vector data_int32 = get_data_vals(t.half_val(), shape_size); + std::vector data_uint16(data_int32.begin(), data_int32.end()); + std::vector data_half; + std::transform(data_uint16.begin(), + data_uint16.end(), + std::back_inserter(data_half), + [](uint16_t raw_val) { return *reinterpret_cast(&raw_val); }); + return create_literal(shape::half_type, dims, data_half); + } + case tensorflow::DataType::DT_DOUBLE: + return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)}; + case tensorflow::DataType::DT_INVALID: + case tensorflow::DataType::DT_UINT8: + case tensorflow::DataType::DT_STRING: + case tensorflow::DataType::DT_UINT32: + case tensorflow::DataType::DT_UINT64: + case tensorflow::DataType::DT_COMPLEX64: + case tensorflow::DataType::DT_COMPLEX128: + case tensorflow::DataType::DT_QINT8: + case tensorflow::DataType::DT_QUINT8: + case tensorflow::DataType::DT_QINT32: + case tensorflow::DataType::DT_BFLOAT16: + case tensorflow::DataType::DT_QINT16: + case tensorflow::DataType::DT_QUINT16: + case tensorflow::DataType::DT_RESOURCE: + case tensorflow::DataType::DT_VARIANT: + case tensorflow::DataType::DT_FLOAT_REF: + case tensorflow::DataType::DT_DOUBLE_REF: + case tensorflow::DataType::DT_INT32_REF: + case tensorflow::DataType::DT_UINT8_REF: + case tensorflow::DataType::DT_INT16_REF: + case tensorflow::DataType::DT_INT8_REF: + case tensorflow::DataType::DT_STRING_REF: + case tensorflow::DataType::DT_COMPLEX64_REF: + case tensorflow::DataType::DT_INT64_REF: + case tensorflow::DataType::DT_BOOL_REF: + case tensorflow::DataType::DT_QINT8_REF: + case tensorflow::DataType::DT_QUINT8_REF: + case tensorflow::DataType::DT_QINT32_REF: + case tensorflow::DataType::DT_BFLOAT16_REF: + case tensorflow::DataType::DT_QINT16_REF: + case tensorflow::DataType::DT_QUINT16_REF: + case tensorflow::DataType::DT_UINT16_REF: + case tensorflow::DataType::DT_COMPLEX128_REF: + case tensorflow::DataType::DT_HALF_REF: + case tensorflow::DataType::DT_RESOURCE_REF: + case tensorflow::DataType::DT_VARIANT_REF: + case tensorflow::DataType::DT_UINT32_REF: + case tensorflow::DataType::DT_UINT64_REF: + case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: + case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: throw std::runtime_error(""); + } + MIGRAPHX_THROW("Invalid tensor type"); +} + +} // namespace tf +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/tmp_dir.cpp b/src/tmp_dir.cpp new file mode 100755 index 0000000000000000000000000000000000000000..20bf0b6601bac149278919a066279e92c43666d0 --- /dev/null +++ b/src/tmp_dir.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_SAVE_TEMP_DIR) + +std::string random_string(std::string::size_type length) +{ + static const std::string& chars = "0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, chars.length() - 1); + + std::string str(length, 0); + std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; }); + + return str; +} + +std::string unique_string(const std::string& prefix) +{ + auto pid = getpid(); + auto tid = std::this_thread::get_id(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + std::stringstream ss; + ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16); + return ss.str(); +} + +tmp_dir::tmp_dir(const std::string& prefix) + : path(fs::temp_directory_path() / + unique_string(prefix.empty() ? "migraphx" : "migraphx-" + prefix)) +{ + fs::create_directories(this->path); +} + +void tmp_dir::execute(const std::string& exe, const std::string& args) const +{ + process{exe + " " + args}.cwd(this->path).exec(); +} + +tmp_dir::~tmp_dir() +{ + if(!enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{})) + { + fs::remove_all(this->path); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/value.cpp b/src/value.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89907bf09fad2e25fad0e249fca1de39ab81894b --- /dev/null +++ b/src/value.cpp @@ -0,0 +1,524 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct value_base_impl : cloneable +{ + virtual value::type_t get_type() { return value::null_type; } +#define MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS(vt, cpp_type) \ + virtual const cpp_type* if_##vt() const { return nullptr; } + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS) + virtual std::vector* if_array() { return nullptr; } + virtual std::unordered_map* if_object() { return nullptr; } + virtual value_base_impl* if_value() const { return nullptr; } + value_base_impl() = default; + value_base_impl(const value_base_impl&) = default; + value_base_impl& operator=(const value_base_impl&) = default; + virtual ~value_base_impl() override {} +}; + +#define MIGRAPHX_VALUE_GENERATE_BASE_TYPE(vt, cpp_type) \ + struct vt##_value_holder : value_base_impl::share \ + { \ + vt##_value_holder(cpp_type d) : data(std::move(d)) {} \ + virtual value::type_t get_type() override { return value::vt##_type; } \ + virtual const cpp_type* if_##vt() const override { return &data; } \ + cpp_type data; \ + }; +MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_TYPE) + +struct array_value_holder : value_base_impl::derive +{ + array_value_holder() {} + array_value_holder(std::vector d) : data(std::move(d)) {} + virtual value::type_t get_type() override { return value::array_type; } + virtual std::vector* if_array() override { return &data; } + std::vector data; +}; + +struct object_value_holder : value_base_impl::derive +{ + object_value_holder() {} + object_value_holder(std::vector d, std::unordered_map l) + : data(std::move(d)), lookup(std::move(l)) + { + } + virtual value::type_t get_type() override { return value::object_type; } + virtual std::vector* if_array() override { return &data; } + virtual std::unordered_map* if_object() override { return &lookup; } + std::vector data; + std::unordered_map lookup; +}; + +value::value(const value& rhs) : x(rhs.x ? rhs.x->clone() : nullptr), key(rhs.key) {} +value& value::operator=(value rhs) +{ + std::swap(rhs.x, x); + if(not rhs.key.empty()) + std::swap(rhs.key, key); + return *this; +} + +void set_vector(std::shared_ptr& x, + const std::vector& v, + bool array_on_empty = true) +{ + if(v.empty()) + { + if(array_on_empty) + x = std::make_shared(); + else + x = std::make_shared(); + return; + } + if(v.front().get_key().empty()) + { + x = std::make_shared(v); + } + else + { + std::unordered_map lookup; + std::size_t i = 0; + for(auto&& e : v) + { + lookup[e.get_key()] = i; + i++; + } + x = std::make_shared(v, lookup); + } +} + +value::value(const std::initializer_list& i) : x(nullptr) +{ + if(i.size() == 2 and i.begin()->is_string() and i.begin()->get_key().empty()) + { + key = i.begin()->get_string(); + auto r = (i.begin() + 1)->x; + x = r ? r->clone() : nullptr; + return; + } + set_vector(x, std::vector(i.begin(), i.end())); +} + +value::value(const std::vector& v, bool array_on_empty) : x(nullptr) +{ + set_vector(x, v, array_on_empty); +} + +value::value(const std::unordered_map& m) + : value(std::vector(m.begin(), m.end()), false) +{ +} + +value::value(const std::string& pkey, const std::vector& v, bool array_on_empty) + : x(nullptr), key(pkey) +{ + set_vector(x, v, array_on_empty); +} + +value::value(const std::string& pkey, const std::unordered_map& m) + : value(pkey, std::vector(m.begin(), m.end()), false) +{ +} + +value::value(const std::string& pkey, std::nullptr_t) : x(nullptr), key(pkey) {} + +value::value(std::nullptr_t) : x(nullptr) {} + +value::value(const std::string& pkey, const value& rhs) + : x(rhs.x ? rhs.x->clone() : nullptr), key(pkey) +{ +} + +value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {} +value::value(const char* i) : value(std::string(i)) {} + +#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \ + value::value(cpp_type i) : x(std::make_shared(std::move(i))) {} \ + value::value(const std::string& pkey, cpp_type i) \ + : x(std::make_shared(std::move(i))), key(pkey) \ + { \ + } \ + value& value::operator=(cpp_type rhs) \ + { \ + x = std::make_shared(std::move(rhs)); \ + return *this; \ + } \ + bool value::is_##vt() const { return x ? x->get_type() == vt##_type : false; } \ + const cpp_type& value::get_##vt() const \ + { \ + auto* r = this->if_##vt(); \ + assert(r); \ + return *r; \ + } \ + const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; } +MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS) + +value& value::operator=(const char* c) +{ + *this = std::string{c}; + return *this; +} + +value& value::operator=(std::nullptr_t) +{ + x = nullptr; + return *this; +} + +value& value::operator=(const std::initializer_list& i) +{ + value rhs = i; + std::swap(rhs.x, x); + return *this; +} + +bool value::is_array() const { return x ? x->get_type() == array_type : false; } +const std::vector& value::value::get_array() const +{ + const auto* r = this->if_array(); + assert(r); + return *r; +} +const std::vector* value::if_array() const { return x ? x->if_array() : nullptr; } + +bool value::is_object() const { return x ? x->get_type() == object_type : false; } +const std::vector& value::get_object() const +{ + const auto* r = this->if_object(); + assert(r); + return *r; +} +const std::vector* value::if_object() const +{ + auto* r = x ? x->if_array() : nullptr; + assert(r == nullptr or + std::none_of(r->begin(), r->end(), [](auto&& v) { return v.get_key().empty(); })); + return r; +} + +bool value::is_null() const { return x == nullptr; } + +const std::string& value::get_key() const { return key; } + +std::vector* if_array_impl(const std::shared_ptr& x) +{ + if(x == nullptr) + return nullptr; + return x->if_array(); +} + +std::vector& get_array_impl(const std::shared_ptr& x) +{ + auto* a = if_array_impl(x); + assert(a); + return *a; +} + +std::vector& get_array_throw(const std::shared_ptr& x) +{ + auto* a = if_array_impl(x); + if(a == nullptr) + MIGRAPHX_THROW("Expected an array or object"); + return *a; +} + +template +T* find_impl(const std::shared_ptr& x, const std::string& key, T* end) +{ + auto* a = if_array_impl(x); + if(a == nullptr) + return end; + auto* lookup = x->if_object(); + if(lookup == nullptr) + return end; + auto it = lookup->find(key); + if(it == lookup->end()) + return end; + return std::addressof((*a)[it->second]); +} + +value* value::find(const std::string& pkey) { return find_impl(x, pkey, this->end()); } + +const value* value::find(const std::string& pkey) const { return find_impl(x, pkey, this->end()); } +bool value::contains(const std::string& pkey) const +{ + const auto* it = find(pkey); + if(it == nullptr) + return false; + if(it == end()) + return false; + return true; +} +std::size_t value::size() const +{ + auto* a = if_array_impl(x); + if(a == nullptr) + return 0; + return a->size(); +} +bool value::empty() const { return size() == 0; } +const value* value::data() const +{ + auto* a = if_array_impl(x); + if(a == nullptr) + return nullptr; + return a->data(); +} +value* value::data() +{ + auto* a = if_array_impl(x); + if(a == nullptr) + return nullptr; + return a->data(); +} +value* value::begin() +{ + // cppcheck-suppress assertWithSideEffect + assert(data() or empty()); + return data(); +} +const value* value::begin() const +{ + assert(data() or empty()); + return data(); +} +value* value::end() { return begin() + size(); } +const value* value::end() const { return begin() + size(); } + +value& value::front() +{ + assert(this->size() > 0); + return *begin(); +} +const value& value::front() const +{ + assert(this->size() > 0); + return *begin(); +} +value& value::back() +{ + assert(this->size() > 0); + return *std::prev(end()); +} +const value& value::back() const +{ + assert(this->size() > 0); + return *std::prev(end()); +} +value& value::at(std::size_t i) +{ + auto* a = if_array_impl(x); + if(a == nullptr) + MIGRAPHX_THROW("Not an array"); + return a->at(i); +} +const value& value::at(std::size_t i) const +{ + auto* a = if_array_impl(x); + if(a == nullptr) + MIGRAPHX_THROW("Not an array"); + return a->at(i); +} +value& value::at(const std::string& pkey) +{ + auto* r = find(pkey); + if(r == nullptr) + MIGRAPHX_THROW("Not an object"); + if(r == end()) + MIGRAPHX_THROW("Key not found: " + pkey); + return *r; +} +const value& value::at(const std::string& pkey) const +{ + const auto* r = find(pkey); + if(r == nullptr) + MIGRAPHX_THROW("Not an object for field: " + pkey); + if(r == end()) + MIGRAPHX_THROW("Key not found: " + pkey); + return *r; +} +value& value::operator[](std::size_t i) +{ + assert(i < this->size()); + return *(begin() + i); +} +const value& value::operator[](std::size_t i) const +{ + assert(i < this->size()); + return *(begin() + i); +} +value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; } + +void value::clear() { get_array_throw(x).clear(); } +void value::resize(std::size_t n) +{ + if(not is_array()) + MIGRAPHX_THROW("Expected an array."); + get_array_impl(x).resize(n); +} +void value::resize(std::size_t n, const value& v) +{ + if(not is_array()) + MIGRAPHX_THROW("Expected an array."); + get_array_impl(x).resize(n, v); +} + +std::pair value::insert(const value& v) +{ + if(v.key.empty()) + { + if(!x) + x = std::make_shared(); + get_array_impl(x).push_back(v); + assert(this->if_array()); + return std::make_pair(&back(), true); + } + else + { + if(!x) + x = std::make_shared(); + auto p = x->if_object()->emplace(v.key, get_array_impl(x).size()); + if(p.second) + get_array_impl(x).push_back(v); + assert(this->if_object()); + return std::make_pair(&get_array_impl(x)[p.first->second], p.second); + } +} +value* value::insert(const value* pos, const value& v) +{ + assert(v.key.empty()); + if(!x) + x = std::make_shared(); + auto&& a = get_array_impl(x); + auto it = a.insert(a.begin() + (pos - begin()), v); + return std::addressof(*it); +} + +value value::without_key() const +{ + value result = *this; + result.key = ""; + return result; +} + +value value::with_key(const std::string& pkey) const +{ + value result = *this; + result.key = pkey; + return result; +} + +template +const T& compare_decay(const T& x) +{ + return x; +} +int compare_decay(std::nullptr_t) { return 0; } + +template +bool compare(const value& x, const value& y, F f) +{ + bool result = false; + x.visit_value([&](auto&& a) { + y.visit_value([&](auto&& b) { + if constexpr(std::is_same{}) + result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)), + std::forward_as_tuple(y.get_key(), compare_decay(b))); + else + assert(false); // NOLINT + }); + }); + return result; +} + +value::type_t value::get_type() const +{ + if(!x) + return null_type; + return x->get_type(); +} + +bool operator==(const value& x, const value& y) +{ + if(x.get_type() != y.get_type()) + return false; + return compare(x, y, std::equal_to<>{}); +} +bool operator!=(const value& x, const value& y) { return not(x == y); } +bool operator<(const value& x, const value& y) +{ + if(x.get_type() != y.get_type()) + return x.get_type() < y.get_type(); + return compare(x, y, std::less<>{}); +} +bool operator<=(const value& x, const value& y) { return not(x > y); } +bool operator>(const value& x, const value& y) { return y < x; } +bool operator>=(const value& x, const value& y) { return not(x < y); } + +void print_value(std::ostream& os, std::nullptr_t) { os << "null"; } + +template +void print_value(std::ostream& os, const T& x) +{ + os << x; +} + +template +void print_value(std::ostream& os, const std::pair& x) +{ + os << x.first; + os << ": "; + print_value(os, x.second); +} + +void print_value(std::ostream& os, const std::vector& x) +{ + os << "{"; + os << to_string_range(x); + os << "}"; +} + +void print_value(std::ostream& os, const value::binary& x) +{ + // Convert binary to integers + std::vector v(x.begin(), x.end()); + os << "{"; + os << to_string_range(v); + os << "}"; +} + +std::ostream& operator<<(std::ostream& os, const value& d) +{ + d.visit([&](auto&& y) { print_value(os, y); }); + return os; +} + +void value::debug_print(bool show_type) const +{ + if(show_type) + { + switch(get_type()) + { +#define MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(vt, cpp_type) \ + case vt##_type: std::cout << #vt << ": "; break; + MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE) + MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(null, ) + MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(array, ) + MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(object, ) + } + } + std::cout << *this << std::endl; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/verify_args.cpp b/src/verify_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..159f0a986ff03bdd0186f5e3c650fdd4476f8762 --- /dev/null +++ b/src/verify_args.cpp @@ -0,0 +1,84 @@ + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +bool verify_args(const std::string& name, + const argument& ref_arg, + const argument& target_arg, + double tolerance) +{ + bool passed = true; + visit_all(ref_arg, target_arg)([&](auto ref, auto target) { + double error; + passed = verify_range(ref, target, tolerance, &error); + if(not passed) + { + // TODO: Check for nans + std::cout << "FAILED: " << name << std::endl; + std::cout << "error: " << error << std::endl; + if(ref.size() < 32) + std::cout << "ref:" << ref << std::endl; + if(target.size() < 32) + std::cout << "target:" << target << std::endl; + if(range_zero(ref)) + std::cout << "Ref data is all zeros" << std::endl; + if(range_zero(target)) + std::cout << "Target data is all zeros" << std::endl; + + auto mxdiff = max_diff(ref, target); + std::cout << "Max diff: " << mxdiff << std::endl; + + auto idx = mismatch_idx(ref, target, float_equal); + if(idx < range_distance(ref)) + { + std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] + << std::endl; + } + + auto ref_nan_idx = find_idx(ref, not_finite); + if(ref_nan_idx >= 0) + std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " + << ref[ref_nan_idx] << std::endl; + + auto target_nan_idx = find_idx(target, not_finite); + if(target_nan_idx >= 0) + std::cout << "Non finite number found in target at " << target_nan_idx << ": " + << target[target_nan_idx] << std::endl; + std::cout << std::endl; + } + else + { + if(range_zero(ref)) + std::cout << "Ref data is all zeros" << std::endl; + if(range_zero(target)) + std::cout << "Target data is all zeros" << std::endl; + + // auto mxdiff = max_diff(ref, target); + // std::cout << "Max diff: " << mxdiff << std::endl; + + // auto idx = mismatch_idx(ref, target, float_equal); + // if(idx < range_distance(ref)) + // { + // std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx] + // << std::endl; + // } + + auto ref_nan_idx = find_idx(ref, not_finite); + if(ref_nan_idx >= 0) + std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " + << ref[ref_nan_idx] << std::endl; + + auto target_nan_idx = find_idx(target, not_finite); + if(target_nan_idx >= 0) + std::cout << "Non finite number found in target at " << target_nan_idx << ": " + << target[target_nan_idx] << std::endl; + // std::cout << std::endl; + } + }); + return passed; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/version.h.in b/src/version.h.in new file mode 100644 index 0000000000000000000000000000000000000000..5f2038d45b023d44896ccc5829842f25e0f5466b --- /dev/null +++ b/src/version.h.in @@ -0,0 +1,4 @@ +// clang-format off +#define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@ +#define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@ +// clang-format on diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4650acbe2fb0f0e547eab1c964ec1f392234012e..526e53e3f7aa00bc359962ef95d733c2a0b60490 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,7 @@ find_package(Threads REQUIRED) include(ProcessorCount) ProcessorCount(N) set(CTEST_PARALLEL_LEVEL ${N} CACHE STRING "CTest parallel level") -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -j ${CTEST_PARALLEL_LEVEL} -C ${CMAKE_CFG_INTDIR} --timeout 1500) +add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -j ${CTEST_PARALLEL_LEVEL} -C ${CMAKE_CFG_INTDIR} --timeout 5000) add_custom_target(tests) find_program(MIGRAPHX_GDB gdb) @@ -44,6 +44,9 @@ function(add_test_command NAME EXE) # --args $ ${ARGN}) set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME}) file(MAKE_DIRECTORY ${TEST_DIR}) + if (NOT EXISTS ${TEST_DIR}) + message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}") + endif() file(GENERATE OUTPUT "${TEST_DIR}/run.cmake" CONTENT " # Remove previous core dump @@ -64,6 +67,7 @@ function(add_test_command NAME EXE) add_test(NAME ${NAME} COMMAND ${EXE} ${ARGN}) endif() endif() + set_tests_properties(${NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED") endfunction() function(add_test_executable TEST_NAME) @@ -82,12 +86,11 @@ function(add_test_executable TEST_NAME) add_test_command(${TEST_NAME} ${TEST_COMMAND}) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) - set_tests_properties(${TEST_NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED") - target_link_libraries(${TEST_NAME} migraphx migraphx_cpu migraphx_onnx) + target_link_libraries(${TEST_NAME} migraphx migraphx_ref migraphx_onnx) target_include_directories(${TEST_NAME} PUBLIC include) endfunction(add_test_executable) -file(GLOB TESTS *.cpp) +file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp) foreach(TEST ${TESTS}) get_filename_component(BASE_NAME ${TEST} NAME_WE) @@ -97,7 +100,7 @@ endforeach() if(MIGRAPHX_ENABLE_GPU) # gpu tests - file(GLOB GPU_TESTS gpu/*.cpp) + file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp) foreach(TEST ${GPU_TESTS}) get_filename_component(BASE_NAME ${TEST} NAME_WE) @@ -117,24 +120,27 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp) foreach(ONNX_TEST ${ONNX_TESTS}) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE) set(TEST_NAME test_${BASE_NAME}) - add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST}) + add_executable(${TEST_NAME} ${ONNX_TEST}) rocm_clang_tidy_check(${TEST_NAME}) - target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_cpu) + target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_include_directories(${TEST_NAME} PUBLIC include) - add_test(NAME ${TEST_NAME} COMMAND $ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx) + add_test(NAME ${TEST_NAME} COMMAND $ WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) endforeach() # tf test +set(TEST_TF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tf) add_executable(test_tf tf/tf_test.cpp) rocm_clang_tidy_check(test_tf) -target_link_libraries(test_tf migraphx_tf migraphx_cpu) +target_link_libraries(test_tf migraphx_tf migraphx_ref) target_include_directories(test_tf PUBLIC include) -add_test(NAME test_tf COMMAND $ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tf) +add_test(NAME test_tf COMMAND $ WORKING_DIRECTORY ${TEST_TF_DIR}) add_dependencies(tests test_tf) add_dependencies(check test_tf) +add_subdirectory(api) +add_subdirectory(verify) if(MIGRAPHX_ENABLE_PYTHON) add_subdirectory(py) endif() @@ -154,7 +160,7 @@ function(test_header NAME HEADER) endfunction() function(test_headers PREFIX) - file(GLOB HEADERS ${ARGN}) + file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN}) foreach(HEADER ${HEADERS}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) @@ -168,7 +174,7 @@ function(test_headers PREFIX) endfunction() test_headers(migraphx ${CMAKE_SOURCE_DIR}/src/include/migraphx/*.hpp) -test_headers(migraphx/cpu ${CMAKE_SOURCE_DIR}/src/targets/cpu/include/migraphx/cpu/*.hpp) +test_headers(migraphx/ref ${CMAKE_SOURCE_DIR}/src/targets/ref/include/migraphx/ref/*.hpp) if(MIGRAPHX_ENABLE_GPU) test_headers(migraphx/gpu ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/*.hpp) endif() diff --git a/test/analyze_streams.cpp b/test/analyze_streams.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cbd5cb69669d9d521c1f107a5d727d967496d403 --- /dev/null +++ b/test/analyze_streams.cpp @@ -0,0 +1,523 @@ +#include +#include +#include +#include +#include "test.hpp" +#include "basic_ops.hpp" + +struct record_event +{ + std::size_t event = 0; + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.event, "event")); + } + std::string name() const { return "record_event"; } + migraphx::shape compute_shape(const std::vector&) const { return {}; } + + migraphx::argument compute(const migraphx::shape&, const std::vector&) const + { + return {}; + } +}; + +struct wait_event +{ + std::size_t event = 0; + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.event, "event")); + } + std::string name() const { return "wait_event"; } + migraphx::shape compute_shape(const std::vector&) const { return {}; } + + migraphx::argument compute(const migraphx::shape&, const std::vector&) const + { + return {}; + } +}; + +struct set_stream +{ + std::size_t stream = 0; + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.stream, "stream")); + } + std::string name() const { return "set_stream"; } + migraphx::shape compute_shape(const std::vector&) const { return {}; } + + migraphx::argument compute(const migraphx::shape&, const std::vector&) const + { + return {}; + } +}; + +struct test_stream_model +{ + std::size_t max_stream = 0; + std::unordered_map ins2stream{}; + std::size_t get_nstream() const { return max_stream + 1; } + std::size_t get_stream(migraphx::instruction_ref ins) const { return ins2stream.at(ins); } + std::size_t get_event_id(migraphx::instruction_ref ins) const + { + auto v = ins->get_operator().to_value(); + return v["event"].to(); + } + bool has_stream(migraphx::instruction_ref ins) const { return ins2stream.count(ins) > 0; } + bool is_record(migraphx::instruction_ref ins) const { return ins->name() == "record_event"; } + bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "wait_event"; } +}; + +struct program_model +{ + migraphx::program p; + migraphx::module* mm = p.get_main_module(); + std::unordered_map ins2stream{}; + std::size_t max_stream = 0; + + template + migraphx::instruction_ref add_literal(Ts... xs) + { + return mm->add_literal(xs...); + } + + template + migraphx::instruction_ref add_instruction(Ts... xs) + { + return mm->add_instruction(xs...); + } + + template + migraphx::instruction_ref add_instruction_stream(std::size_t n, Ts... xs) + { + max_stream = std::max(max_stream, n); + auto ins = mm->add_instruction(xs...); + ins2stream[ins] = n; + return ins; + } + + template + migraphx::instruction_ref add_return(Ts... xs) + { + return mm->add_return({xs...}); + } + + template + migraphx::instruction_ref add_return_stream(std::size_t n, Ts... xs) + { + max_stream = std::max(max_stream, n); + auto ins = mm->add_return({xs...}); + ins2stream[ins] = n; + return ins; + } + + test_stream_model get_stream_model() const { return {max_stream, ins2stream}; } + + std::vector analyze() const + { + return migraphx::analyze_streams(*p.get_main_module(), get_stream_model()); + } + + void debug_print() const { p.debug_print(); } + + void debug_print(const std::vector& races) const + { + for(auto&& race : races) + { + std::cout << "Race:\n"; + mm->debug_print(race.ins); + mm->debug_print(race.before); + } + } +}; + +TEST_CASE(simple_race1) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(simple_race2) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto pass21 = pm.add_instruction_stream(1, pass_op{}, pass2); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass1, pass21); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass21}); +} + +TEST_CASE(simple_race3) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass11 = pm.add_instruction_stream(0, pass_op{}, pass1); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass11, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(simple_race4) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass11 = pm.add_instruction_stream(0, pass_op{}, pass1); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto pass21 = pm.add_instruction_stream(1, pass_op{}, pass2); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass11, pass21); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass21}); +} + +TEST_CASE(simple_race5) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto pass11 = pm.add_instruction_stream(0, pass_op{}, pass1); + auto pass21 = pm.add_instruction_stream(1, pass_op{}, pass2); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass11, pass21); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass21}); +} + +TEST_CASE(simple_race_record_wait_wrong_stream) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(0, record_event{1}); + pm.add_instruction_stream(1, wait_event{1}); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(simple_race_record_wait_same_stream1) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(1, wait_event{1}); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(simple_race_record_wait_same_stream2) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(0, record_event{1}); + pm.add_instruction_stream(0, wait_event{1}); + auto pass3 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass3}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(simple_race_sync) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(0, wait_event{1}); + pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.empty()); +} + +TEST_CASE(race_return) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + auto r = pm.add_return_stream(0, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == r}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(race_return_sync) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(0, wait_event{1}); + pm.add_return_stream(0, pass1, pass2); + auto races = pm.analyze(); + + EXPECT(races.empty()); +} + +TEST_CASE(race_double_wait1) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one); + pm.add_instruction_stream(2, wait_event{1}); + auto pass4 = pm.add_instruction_stream(2, pass_op{}, pass3, pass2); + pm.add_instruction_stream(2, record_event{2}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + pm.add_instruction_stream(0, record_event{3}); + pm.add_instruction_stream(1, wait_event{3}); + pm.add_instruction_stream(1, wait_event{2}); + pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass5}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(race_double_wait2) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one); + auto pass4 = pm.add_instruction_stream(2, pass_op{}, pass3, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(0, wait_event{1}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + pm.add_instruction_stream(0, record_event{3}); + pm.add_instruction_stream(1, wait_event{3}); + pm.add_instruction_stream(1, wait_event{2}); + pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass4}); + EXPECT(bool{races.front().before == pass2}); +} + +TEST_CASE(race_double_wait3) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one); + pm.add_instruction_stream(2, wait_event{1}); + auto pass4 = pm.add_instruction_stream(2, pass_op{}, pass3, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(0, wait_event{1}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + pm.add_instruction_stream(1, wait_event{2}); + auto pass6 = pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass6}); + EXPECT(bool{races.front().before == pass5}); +} + +TEST_CASE(race_double_wait4) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one); + pm.add_instruction_stream(2, wait_event{1}); + auto pass4 = pm.add_instruction_stream(2, pass_op{}, pass3, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(0, wait_event{1}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + pm.add_instruction_stream(0, record_event{3}); + pm.add_instruction_stream(1, wait_event{3}); + auto pass6 = pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass6}); + EXPECT(bool{races.front().before == pass4}); +} + +TEST_CASE(race_double_wait_sync) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one); + pm.add_instruction_stream(2, wait_event{1}); + auto pass4 = pm.add_instruction_stream(2, pass_op{}, pass3, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(0, wait_event{1}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass1, pass2); + pm.add_instruction_stream(0, record_event{3}); + pm.add_instruction_stream(1, wait_event{3}); + pm.add_instruction_stream(1, wait_event{2}); + pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + auto races = pm.analyze(); + + EXPECT(races.empty()); +} + +TEST_CASE(race_multi_wait1) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + pm.add_instruction_stream(0, record_event{5}); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(2, wait_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(3, wait_event{5}); + auto pass4 = pm.add_instruction_stream(3, pass_op{}, one, pass1); + pm.add_instruction_stream(3, record_event{3}); + pm.add_instruction_stream(0, wait_event{2}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass3, pass1); + pm.add_instruction_stream(0, record_event{4}); + pm.add_instruction_stream(1, wait_event{3}); + auto pass6 = pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass6}); + EXPECT(bool{races.front().before == pass5}); +} + +TEST_CASE(race_multi_wait2) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + pm.add_instruction_stream(0, record_event{5}); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(2, wait_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(3, wait_event{5}); + auto pass4 = pm.add_instruction_stream(3, pass_op{}, one, pass1); + pm.add_instruction_stream(3, record_event{3}); + pm.add_instruction_stream(0, wait_event{2}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass3, pass1); + pm.add_instruction_stream(0, record_event{4}); + pm.add_instruction_stream(1, wait_event{4}); + auto pass6 = pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass6}); + EXPECT(bool{races.front().before == pass4}); +} + +TEST_CASE(race_multi_wait3) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(2, wait_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one, pass2); + pm.add_instruction_stream(2, record_event{2}); + auto pass4 = pm.add_instruction_stream(3, pass_op{}, one, pass1); + pm.add_instruction_stream(3, record_event{3}); + pm.add_instruction_stream(0, wait_event{2}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass3, pass1); + pm.add_instruction_stream(0, record_event{4}); + pm.add_instruction_stream(1, wait_event{3}); + pm.add_instruction_stream(1, wait_event{4}); + pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + + auto races = pm.analyze(); + + EXPECT(races.size() == 1); + EXPECT(bool{races.front().ins == pass4}); + EXPECT(bool{races.front().before == pass1}); +} + +TEST_CASE(race_multi_wait_sync) +{ + program_model pm; + auto one = pm.add_literal(1); + auto pass1 = pm.add_instruction_stream(0, pass_op{}, one); + pm.add_instruction_stream(0, record_event{5}); + auto pass2 = pm.add_instruction_stream(1, pass_op{}, one); + pm.add_instruction_stream(1, record_event{1}); + pm.add_instruction_stream(2, wait_event{1}); + auto pass3 = pm.add_instruction_stream(2, pass_op{}, one, pass2); + pm.add_instruction_stream(2, record_event{2}); + pm.add_instruction_stream(3, wait_event{5}); + auto pass4 = pm.add_instruction_stream(3, pass_op{}, one, pass1); + pm.add_instruction_stream(3, record_event{3}); + pm.add_instruction_stream(0, wait_event{2}); + auto pass5 = pm.add_instruction_stream(0, pass_op{}, pass3, pass1); + pm.add_instruction_stream(0, record_event{4}); + pm.add_instruction_stream(1, wait_event{3}); + pm.add_instruction_stream(1, wait_event{4}); + pm.add_instruction_stream(1, pass_op{}, pass4, pass5); + + auto races = pm.analyze(); + + EXPECT(races.empty()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/any_ptr.cpp b/test/any_ptr.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9bc5b0b26555fbbf259e2c971cae8c94dbb2b05 --- /dev/null +++ b/test/any_ptr.cpp @@ -0,0 +1,28 @@ +#include +#include + +TEST_CASE(test_int_id) +{ + int i = 1; + migraphx::any_ptr p = &i; + EXPECT(p.get() == &i); + EXPECT(p.get(migraphx::get_type_name(i)) == &i); + EXPECT(p.unsafe_get() == &i); + EXPECT(test::throws([&] { p.get(); })); + EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); })); +} + +TEST_CASE(test_int_name) +{ + int i = 1; + void* vp = &i; + migraphx::any_ptr p{vp, migraphx::get_type_name(i)}; + EXPECT(p.get() == &i); + EXPECT(p.get(migraphx::get_type_name(i)) == &i); + EXPECT(p.unsafe_get() == &i); + EXPECT(test::throws([&] { p.get(); })); + EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); })); + EXPECT(test::throws([&] { p.get(migraphx::get_type_name(float{})); })); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/CMakeLists.txt b/test/api/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..55b51d0614e6da38cf85043ad36ab78f683f83c2 --- /dev/null +++ b/test/api/CMakeLists.txt @@ -0,0 +1,26 @@ +function(add_api_test TEST_NAME TEST_SRC TEST_DIR) + set(NAME test_api_${TEST_NAME}) + add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) + rocm_clang_tidy_check(${NAME}) + target_link_libraries(${NAME} migraphx_c migraphx) + target_include_directories(${NAME} PUBLIC ../include) + add_test(NAME ${NAME} COMMAND $ WORKING_DIRECTORY ${TEST_DIR}) + add_dependencies(tests ${NAME}) + add_dependencies(check ${NAME}) +endfunction() + +add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR}) +add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) +add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) +add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) +add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) +add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) +add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) +add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) +add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) +add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) +# GPU-based tests +if(MIGRAPHX_ENABLE_GPU) +add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) +target_link_libraries(test_api_gpu migraphx_gpu) +endif() diff --git a/test/api/test_array_base.cpp b/test/api/test_array_base.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1633188adefb06bd2b8f67f9f072022cb790412 --- /dev/null +++ b/test/api/test_array_base.cpp @@ -0,0 +1,32 @@ +#include +#include "test.hpp" + +struct array2 : migraphx::array_base +{ + std::vector v; + array2() = default; + array2(std::initializer_list x) : v(x) {} + std::size_t size() const { return v.size(); } + int operator[](std::size_t i) const { return v[i]; } +}; + +TEST_CASE(iterators) +{ + array2 a = {1, 2, 3}; + EXPECT(bool{std::equal(a.begin(), a.end(), a.v.begin())}); +} + +TEST_CASE(front_back) +{ + array2 a = {1, 2, 3}; + EXPECT(a.front() == 1); + EXPECT(a.back() == 3); +} + +TEST_CASE(empty) +{ + array2 a = {1, 2, 3}; + EXPECT(not a.empty()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_assign.cpp b/test/api/test_assign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2c4398613ffe21017719991c878a368fd34416a --- /dev/null +++ b/test/api/test_assign.cpp @@ -0,0 +1,25 @@ +#include +#include +#include "test.hpp" + +TEST_CASE(shape_assign) +{ + auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}}; + std::vector lens{2, 3}; + + // handle ptr is const, workaround to construct shape using C API + migraphx_shape_t s2; + migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size()); + auto s2_cpp = migraphx::shape(s2, migraphx::own{}); + CHECK(bool{s1_cpp != s2_cpp}); + // use C++ API for assignment + s1_cpp.assign_to_handle(s2); + CHECK(bool{s1_cpp == s2_cpp}); + + auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens}; + // use C API for assignment + migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr()); + CHECK(bool{s2_cpp == s3_cpp}); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_compile_options.cpp b/test/api/test_compile_options.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef63bf0b3ceb6fb80f08fb72c735dac121824e97 --- /dev/null +++ b/test/api/test_compile_options.cpp @@ -0,0 +1,17 @@ +#include +#include +#include +#include "test.hpp" + +TEST_CASE(compile_options_api_test) +{ + migraphx::api::compile_options options; + options.set_offload_copy(false); + options.set_fast_math(false); + const auto* s_options = reinterpret_cast( + options.get_handle_ptr()); + CHECK(s_options->fast_math == false); + CHECK(s_options->offload_copy == false); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_cpu.cpp b/test/api/test_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d50c049b0add4337f93412f611236a04a99f20a8 --- /dev/null +++ b/test/api/test_cpu.cpp @@ -0,0 +1,178 @@ +#include +#include +#include "test.hpp" + +TEST_CASE(load_and_run) +{ + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + auto shapes_before = p.get_output_shapes(); + p.compile(migraphx::target("ref")); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(shapes_before.size() == shapes_after.size()); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + auto outputs = p.eval(pp); + CHECK(shapes_before.size() == outputs.size()); + CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); +} + +TEST_CASE(load_and_run_init_list) +{ + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + auto shapes_before = p.get_output_shapes(); + p.compile(migraphx::target("ref")); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(shapes_before.size() == shapes_after.size()); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + auto param_shapes = p.get_parameter_shapes(); + EXPECT(param_shapes.size() == 3); + auto names = param_shapes.names(); + auto outputs = p.eval({{names[0], migraphx::argument::generate(param_shapes[names[0]])}, + {names[1], migraphx::argument::generate(param_shapes[names[1]])}, + {names[2], migraphx::argument::generate(param_shapes[names[2]])}}); + CHECK(shapes_before.size() == outputs.size()); + CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); +} + +TEST_CASE(quantize_fp16) +{ + auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx"); + const auto& p2 = p1; + const auto& p3 = p1; + migraphx::quantize_fp16(p1); + + migraphx::quantize_op_names names; + migraphx::quantize_fp16(p2, names); + CHECK(bool{p1 == p2}); + + names.add("dot"); + migraphx::quantize_fp16(p3, names); + CHECK(bool{p1 == p3}); +} + +TEST_CASE(quantize_int8) +{ + auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx"); + const auto& p2 = p1; + auto t = migraphx::target("ref"); + migraphx::quantize_int8_options options; + migraphx::quantize_int8(p1, t, options); + + migraphx::program_parameters pp; + auto param_shapes = p1.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + options.add_calibration_data(pp); + options.add_op_name("dot"); + + migraphx::quantize_int8(p2, t, options); + CHECK(bool{p1 == p2}); +} + +TEST_CASE(load_and_run_user_input_shape) +{ + migraphx::onnx_options options; + options.set_input_parameter_shape("0", {2, 3, 64, 64}); + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx", options); + auto shapes_before = p.get_output_shapes(); + p.compile(migraphx::target("ref")); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(shapes_before.size() == shapes_after.size()); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + auto outputs = p.eval(pp); + CHECK(shapes_before.size() == outputs.size()); + CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); +} + +TEST_CASE(zero_parameter) +{ + auto p = migraphx::parse_onnx("constant_fill_test.onnx"); + auto shapes_before = p.get_output_shapes(); + p.compile(migraphx::target("ref")); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(shapes_before.size() == shapes_after.size()); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + auto outputs = p.eval(pp); + CHECK(shapes_before.size() == outputs.size()); + CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); +} + +TEST_CASE(set_scalar_parameter) +{ + auto p1 = migraphx::parse_onnx("add_bcast_test.onnx"); + migraphx::shape s1(migraphx_shape_float_type, {3, 4}); + auto param_shapes = p1.get_parameter_shapes(); + auto s1_orig = param_shapes["1"]; + CHECK(bool{s1 == s1_orig}); + + migraphx::onnx_options option; + option.set_input_parameter_shape("1", {}); + auto p2 = migraphx::parse_onnx("add_bcast_test.onnx", option); + migraphx::shape s_scalar(migraphx_shape_float_type); + auto param_shapes_1 = p2.get_parameter_shapes(); + auto s_scalar_after = param_shapes_1["1"]; + CHECK(bool{s_scalar == s_scalar_after}); +} + +TEST_CASE(scalar_shape) +{ + auto s = migraphx::shape(migraphx_shape_float_type); + EXPECT(s.lengths().size() == 1); + EXPECT(s.strides().size() == 1); + EXPECT(s.lengths().front() == 1); + EXPECT(s.strides().front() == 0); +} + +TEST_CASE(strided_shape) +{ + std::vector lens = {2, 2}; + std::vector strides = {1, 2}; + auto s = migraphx::shape(migraphx_shape_float_type, lens, strides); + EXPECT(s.lengths() == lens); + EXPECT(s.strides() == strides); +} + +TEST_CASE(get_main_module) +{ + auto p = migraphx::parse_onnx("constant_fill_test.onnx"); + migraphx::module mm = p.get_main_module(); + mm.print(); + p.print(); +} + +TEST_CASE(set_loop_default_iter_num) +{ + migraphx::onnx_options option; + option.set_default_loop_iterations(15); + auto p = migraphx::parse_onnx("loop_default_test.onnx", option); + auto out_shapes = p.get_output_shapes(); + std::vector out_lens0 = {1}; + EXPECT(out_shapes[0].lengths() == out_lens0); + std::vector out_lens1 = {15, 1}; + EXPECT(out_shapes[1].lengths() == out_lens1); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_custom_op.cpp b/test/api/test_custom_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50472d6f38b264be89ad1d28cfcc3f6381c33983 --- /dev/null +++ b/test/api/test_custom_op.cpp @@ -0,0 +1,23 @@ +#include +#include +#include "test.hpp" + +struct simple_custom_op final : migraphx::experimental_custom_op_base +{ + virtual std::string name() const override { return "simple_custom_op"; } + virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override + { + return inputs.front(); + } +}; + +TEST_CASE(register_custom_op) +{ + simple_custom_op simple_op; + migraphx::register_experimental_custom_op(simple_op); + + auto op = migraphx::operation("simple_custom_op"); + EXPECT(op.name() == "simple_custom_op"); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_gpu.cpp b/test/api/test_gpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3fe4af85060cbb34d72363b951037680e8003c14 --- /dev/null +++ b/test/api/test_gpu.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include "test.hpp" + +TEST_CASE(load_and_run) +{ + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + auto shapes_before = p.get_output_shapes(); + migraphx::compile_options options; + options.set_offload_copy(); + p.compile(migraphx::target("gpu"), options); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(shapes_before.size() == shapes_after.size()); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + auto outputs = p.eval(pp); + CHECK(shapes_before.size() == outputs.size()); + CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); +} + +TEST_CASE(load_and_run_ctx) +{ + auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + migraphx::compile_options options; + options.set_offload_copy(); + p.compile(migraphx::target("gpu"), options); + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + for(auto&& name : param_shapes.names()) + { + pp.add(name, migraphx::argument::generate(param_shapes[name])); + } + auto ctx = p.experimental_get_context(); + EXPECT(ctx.get_queue() != nullptr); + p.eval(pp); + ctx.finish(); +} + +TEST_CASE(if_pl_test) +{ + auto run_prog = [&](auto cond) { + auto p = migraphx::parse_onnx("if_pl_test.onnx"); + auto shapes_before = p.get_output_shapes(); + migraphx::compile_options options; + options.set_offload_copy(); + p.compile(migraphx::target("gpu"), options); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 1); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + auto xs = param_shapes["x"]; + std::vector xd(xs.bytes() / sizeof(float), 1.0); + pp.add("x", migraphx::argument(xs, xd.data())); + auto ys = param_shapes["y"]; + std::vector yd(ys.bytes() / sizeof(float), 2.0); + pp.add("y", migraphx::argument(ys, yd.data())); + char ccond = cond; + pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond)); + + auto outputs = p.eval(pp); + auto output = outputs[0]; + auto lens = output.get_shape().lengths(); + auto elem_num = + std::accumulate(lens.begin(), lens.end(), 1, std::multiplies()); + float* data_ptr = reinterpret_cast(output.data()); + std::vector ret(data_ptr, data_ptr + elem_num); + + return ret; + }; + + // then branch + { + auto result_vector = run_prog(true); + std::vector gold = {2, 3, 4, 5, 6, 7}; + EXPECT(result_vector == gold); + } + + // else branch + { + auto result_vector = run_prog(false); + std::vector gold = {1, 2, 3, 4, 5, 6}; + EXPECT(result_vector == gold); + } +} + +TEST_CASE(loop_test) +{ + auto run_prog = [&](int64_t max_iter_num) { + migraphx::onnx_options parse_options; + parse_options.set_default_loop_iterations(max_iter_num); + auto p = migraphx::parse_onnx("loop_default_test.onnx", parse_options); + auto shapes_before = p.get_output_shapes(); + migraphx::compile_options options; + options.set_offload_copy(); + p.compile(migraphx::target("gpu"), options); + auto shapes_after = p.get_output_shapes(); + CHECK(shapes_before.size() == 2); + CHECK(bool{shapes_before.front() == shapes_after.front()}); + + migraphx::program_parameters pp; + auto param_shapes = p.get_parameter_shapes(); + auto aas = param_shapes["a"]; + std::vector xd = {1.0f}; + pp.add("a", migraphx::argument(aas, xd.data())); + auto bbs = param_shapes["b"]; + std::vector yd = {2.0}; + pp.add("b", migraphx::argument(bbs, yd.data())); + + auto outputs = p.eval(pp); + auto output = outputs[0]; + auto lens = output.get_shape().lengths(); + auto elem_num = + std::accumulate(lens.begin(), lens.end(), 1, std::multiplies()); + float* data_ptr = reinterpret_cast(output.data()); + std::vector> ret; + ret.push_back({data_ptr, data_ptr + elem_num}); + + output = outputs[1]; + lens = output.get_shape().lengths(); + elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies()); + data_ptr = reinterpret_cast(output.data()); + ret.push_back({data_ptr, data_ptr + elem_num}); + + return ret; + }; + + { + auto result_vector = run_prog(10); + std::vector gold0 = {2.0f}; + EXPECT(result_vector.at(0) == gold0); + std::vector gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(result_vector.at(1) == gold1); + } + + { + auto result_vector = run_prog(15); + std::vector gold0 = {2.0f}; + EXPECT(result_vector.at(0) == gold0); + std::vector gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(result_vector.at(1) == gold1); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_lookup.cpp b/test/api/test_lookup.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c7d110ba3b0586989d715a68431655eb912b293 --- /dev/null +++ b/test/api/test_lookup.cpp @@ -0,0 +1,30 @@ +#include +#include "test.hpp" + +template +std::false_type has_handle(migraphx::rank<0>, T) +{ + return {}; +} + +template +auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle{}, std::true_type{}) +{ + return {}; +} + +TEST_CASE(shape) +{ + static_assert(std::is_same, migraphx::shape>{}, "Failed"); + static_assert(std::is_same, migraphx::shape>{}, "Failed"); + static_assert(std::is_same, migraphx::shape>{}, + "Failed"); +} +TEST_CASE(non_handle) +{ + int i = 0; + EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})}); + EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)}); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_module_construct.cpp b/test/api/test_module_construct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d66e22f0ff5ebba3efac102bc8555356e60483f --- /dev/null +++ b/test/api/test_module_construct.cpp @@ -0,0 +1,73 @@ +#include +#include +#include +#include "test.hpp" + +TEST_CASE(add_literals) +{ + migraphx::program p; + migraphx::module m = p.get_main_module(); + migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}}; + std::vector x_values(9, 1); + auto x = m.add_literal(param_shape, x_values.data()); + std::vector y_values(9, -1); + auto y = m.add_literal(param_shape, y_values.data()); + auto add_op = migraphx::operation("add"); + auto r = m.add_instruction(add_op, {x, y}); + m.add_return({r}); + // run on ref target + p.compile(migraphx::target("ref")); + migraphx::program_parameters pp; + auto outputs = p.eval(pp); + auto output = outputs[0]; + std::vector expected(9, 0); + CHECK(bool(output == migraphx::argument(param_shape, expected.data()))); +} + +TEST_CASE(if_then_else_op) +{ + migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}}; + migraphx::shape cond_s{migraphx_shape_bool_type}; + auto create_program = [&]() { + migraphx::program p; + auto mm = p.get_main_module(); + auto cond = mm.add_parameter("cond", cond_s); + auto x = mm.add_parameter("x", param_shape); + auto y = mm.add_parameter("y", param_shape); + auto then_mod = p.create_module("If_0_if"); + auto x_identity = then_mod.add_instruction(migraphx::operation("identity"), {x}); + then_mod.add_return({x_identity}); + + auto else_mod = p.create_module("If_0_else"); + auto y_identity = else_mod.add_instruction(migraphx::operation("identity"), {y}); + else_mod.add_return({y_identity}); + + auto if_ins = mm.add_instruction(migraphx::operation("if"), {cond}, {then_mod, else_mod}); + auto get_tuple_op = migraphx::operation("get_tuple_elem", "{index: 0}"); + auto ret = mm.add_instruction(get_tuple_op, {if_ins}); + mm.add_return({ret}); + return p; + }; + + std::vector x_data(9, 1); + std::vector y_data(9, -1); + auto x_arg = migraphx::argument(param_shape, x_data.data()); + auto y_arg = migraphx::argument(param_shape, y_data.data()); + auto run_prog = [&](bool cond) { + auto p = create_program(); + p.compile(migraphx::target("ref")); + auto outputs = + p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}}); + return outputs[0]; + }; + + // then branch + auto then_res = run_prog(true); + CHECK(bool{then_res == x_arg}); + + // else branch + auto else_res = run_prog(false); + CHECK(bool{else_res == y_arg}); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_op_construct.cpp b/test/api/test_op_construct.cpp new file mode 100755 index 0000000000000000000000000000000000000000..80ee5df4b948e9febdf2c83d72904ff796135a91 --- /dev/null +++ b/test/api/test_op_construct.cpp @@ -0,0 +1,29 @@ +#include +#include +#include "test.hpp" + +TEST_CASE(add_op) +{ + auto add_op = migraphx::operation("add"); + EXPECT(add_op.name() == "add"); +} + +TEST_CASE(reduce_mean_without_quotes) +{ + auto rm = migraphx::operation("reduce_mean", "{axes : [1, 2, 3, 4]}"); + EXPECT(rm.name() == "reduce_mean"); +} + +TEST_CASE(reduce_mean) +{ + auto rm = migraphx::operation("reduce_mean", "{\"axes\" : [1, 2, 3, 4]}"); + EXPECT(rm.name() == "reduce_mean"); +} + +TEST_CASE(reduce_mean_with_format) +{ + auto rm = migraphx::operation("reduce_mean", "{axes : [%i, %i, %i, %i]}", 1, 2, 3, 4); + EXPECT(rm.name() == "reduce_mean"); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_save_load.cpp b/test/api/test_save_load.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62063da961777581ae217d1b7aa006cb04efbd5a --- /dev/null +++ b/test/api/test_save_load.cpp @@ -0,0 +1,37 @@ +#include +#include +#include "test.hpp" + +TEST_CASE(load_save_default) +{ + std::string filename = "migraphx_api_load_save.mxr"; + auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + auto s1 = p1.get_output_shapes(); + + migraphx::save(p1, filename.c_str()); + auto p2 = migraphx::load(filename.c_str()); + auto s2 = p2.get_output_shapes(); + EXPECT(s1.size() == s2.size()); + EXPECT(bool{s1.front() == s2.front()}); + EXPECT(bool{p1.sort() == p2.sort()}); + std::remove(filename.c_str()); +} + +TEST_CASE(load_save_json) +{ + std::string filename = "migraphx_api_load_save.json"; + auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); + auto s1 = p1.get_output_shapes(); + migraphx::file_options options; + options.set_file_format("json"); + + migraphx::save(p1, filename.c_str(), options); + auto p2 = migraphx::load(filename.c_str(), options); + auto s2 = p2.get_output_shapes(); + EXPECT(s1.size() == s2.size()); + EXPECT(bool{s1.front() == s2.front()}); + EXPECT(bool{p1.sort() == p2.sort()}); + std::remove(filename.c_str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/api/test_tf_parser.cpp b/test/api/test_tf_parser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72dc2b19ff0f76b6b66bd968ecf1a330ee4bb044 --- /dev/null +++ b/test/api/test_tf_parser.cpp @@ -0,0 +1,45 @@ +#include +#include +#include "test.hpp" + +TEST_CASE(load_tf) +{ + auto p = migraphx::parse_tf("add_test.pb"); + auto shapes = p.get_output_shapes(); + CHECK(shapes.size() == 1); +} + +TEST_CASE(load_tf_default_dim) +{ + migraphx::tf_options tf_options; + size_t batch = 2; + tf_options.set_default_dim_value(batch); + tf_options.set_nhwc(); + auto p = migraphx::parse_tf("conv_batch_test.pb", tf_options); + auto shapes = p.get_output_shapes(); + CHECK(shapes.size() == 1); + CHECK(shapes.front().lengths().front() == batch); +} + +TEST_CASE(load_tf_param_shape) +{ + migraphx::tf_options tf_options; + std::vector new_shape{1, 3}; + tf_options.set_input_parameter_shape("0", new_shape); + tf_options.set_input_parameter_shape("1", new_shape); + auto p = migraphx::parse_tf("add_test.pb", tf_options); + auto shapes = p.get_output_shapes(); + CHECK(shapes.size() == 1); + CHECK(shapes.front().lengths() == new_shape); +} + +TEST_CASE(load_tf_multi_outputs) +{ + migraphx::tf_options tf_options; + tf_options.set_output_names({"relu", "tanh"}); + auto p = migraphx::parse_tf("multi_output_test.pb", tf_options); + auto shapes = p.get_output_shapes(); + CHECK(shapes.size() == 2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/argument_test.cpp b/test/argument_test.cpp new file mode 100755 index 0000000000000000000000000000000000000000..814f4704013d035d6733ca9c5af1fe645cc76709 --- /dev/null +++ b/test/argument_test.cpp @@ -0,0 +1,202 @@ +#include +#include +#include +#include +#include +#include "test.hpp" + +migraphx::argument as_argument(migraphx::argument a) { return a; } +template +migraphx::argument as_argument(T x) +{ + return migraphx::literal{x}.get_argument(); +} +template +migraphx::argument make_tuple(Ts... xs) +{ + return migraphx::argument{{as_argument(xs)...}}; +} + +TEST_CASE(copy_eq) +{ + auto a1 = as_argument(3); + auto a2 = as_argument(3); + auto a3 = as_argument(1); + auto a4 = a1; // NOLINT + + EXPECT(a1 == a2); + EXPECT(a2 != a3); + EXPECT(a1 == a4); + EXPECT(a4 != a3); + + EXPECT(a1.get_sub_objects().empty()); + EXPECT(a2.get_sub_objects().empty()); + EXPECT(a3.get_sub_objects().empty()); + EXPECT(a4.get_sub_objects().empty()); +} + +TEST_CASE(default_construct) +{ + migraphx::argument a1{}; + migraphx::argument a2{}; + + EXPECT(a1.empty()); + EXPECT(a2.empty()); + EXPECT(a1 == a2); + + EXPECT(a1.to_string().empty()); + EXPECT(a2.to_string().empty()); + + EXPECT(a1.get_sub_objects().empty()); + EXPECT(a2.get_sub_objects().empty()); +} + +TEST_CASE(string_elems) +{ + migraphx::shape s{migraphx::shape::int64_type, {3}}; + migraphx::literal l{s, {1, 2, 3}}; + auto a = l.get_argument(); + + EXPECT(a.to_string() == "1, 2, 3"); +} + +TEST_CASE(tuple) +{ + auto a1 = make_tuple(3, 3.0); + + EXPECT(a1.get_shape().type() == migraphx::shape::tuple_type); + EXPECT(a1.get_sub_objects().size() == 2); + EXPECT(a1.get_sub_objects()[0] == as_argument(3)); + EXPECT(a1.get_sub_objects()[1] == as_argument(3.0)); + + auto a2 = make_tuple(3, 3.0); + + EXPECT(a1 == a2); + EXPECT(a1.to_string() == a2.to_string()); + + auto a3 = make_tuple(3, 4.0); + EXPECT(a1 != a3); + EXPECT(a1.to_string() != a3.to_string()); +} + +TEST_CASE(nested_tuple) +{ + auto a1 = make_tuple(3, make_tuple(5, 4)); + + EXPECT(a1.get_shape().type() == migraphx::shape::tuple_type); + EXPECT(a1.get_sub_objects().size() == 2); + EXPECT(a1.get_sub_objects()[0] == as_argument(3)); + EXPECT(a1.get_sub_objects()[1] == make_tuple(5, 4)); + + auto a2 = make_tuple(3, make_tuple(5, 4)); + + EXPECT(a1 == a2); + EXPECT(a1.to_string() == a2.to_string()); + + auto a3 = make_tuple(3, make_tuple(5, 6)); + EXPECT(a1 != a3); + EXPECT(a1.to_string() != a3.to_string()); +} + +TEST_CASE(tuple_construct) +{ + migraphx::shape s{{migraphx::shape{migraphx::shape::float_type, {4}}, + migraphx::shape{migraphx::shape::int8_type, {3}}}}; + migraphx::argument a{s}; + EXPECT(a.get_sub_objects().size() == 2); + EXPECT(a.get_shape() == s); + + auto b = a; // NOLINT + EXPECT(a.get_shape() == b.get_shape()); + EXPECT(a.get_sub_objects().size() == 2); + EXPECT(a.get_sub_objects()[0] == b.get_sub_objects()[0]); + EXPECT(a.get_sub_objects()[1] == b.get_sub_objects()[1]); + EXPECT(a == b); +} + +TEST_CASE(tuple_visit) +{ + auto a1 = make_tuple(3, 3.0); + EXPECT(test::throws([&] { a1.visit([](auto&&) {}); })); + EXPECT(test::throws([&] { a1.at(); })); + + bool reaches = false; + a1.visit([&](auto&&) { EXPECT(false); }, + [&](auto&& xs) { + reaches = true; + EXPECT(xs.size() == 2); + EXPECT(xs[0] == as_argument(3)); + EXPECT(xs[1] == as_argument(3.0)); + }); + EXPECT(reaches); +} + +TEST_CASE(tuple_visit_all) +{ + auto a1 = make_tuple(3, 3.0); + auto a2 = make_tuple(1, 2, 3); + + EXPECT(test::throws([&] { visit_all(a1, a2)([](auto&&, auto&&) {}); })); + bool reaches = false; + visit_all(a1, a2)([&](auto&&, auto&&) { EXPECT(false); }, + [&](auto&& xs, auto&& ys) { + reaches = true; + EXPECT(xs.size() == 2); + EXPECT(xs[0] == as_argument(3)); + EXPECT(xs[1] == as_argument(3.0)); + + EXPECT(ys.size() == 3); + EXPECT(ys[0] == as_argument(1)); + EXPECT(ys[1] == as_argument(2)); + EXPECT(ys[2] == as_argument(3)); + }); + EXPECT(reaches); +} + +TEST_CASE(value_argument) +{ + migraphx::shape s{migraphx::shape::int64_type, {3}}; + migraphx::literal l1{s, {1, 2, 3}}; + auto a1 = l1.get_argument(); + auto v1 = migraphx::to_value(a1); + migraphx::literal l2{1}; + auto a2 = l2.get_argument(); + auto v2 = migraphx::to_value(a2); + EXPECT(v1 != v2); + + auto a3 = migraphx::from_value(v1); + EXPECT(a3 == a1); + auto a4 = migraphx::from_value(v2); + EXPECT(a4 == a2); +} + +TEST_CASE(value_tuple) +{ + auto a1 = make_tuple(3, 3.0, make_tuple(3, 4)); + auto a2 = make_tuple(1, 2, 3); + + auto v1 = migraphx::to_value(a1); + auto v2 = migraphx::to_value(a2); + EXPECT(v1 != v2); + + auto a3 = migraphx::from_value(v1); + EXPECT(a3 == a1); + auto a4 = migraphx::from_value(v2); + EXPECT(a4 == a2); +} + +TEST_CASE(argument_share) +{ + migraphx::shape s{migraphx::shape::int64_type, {3}}; + std::vector buffer(s.bytes()); + migraphx::argument a1(s, [=]() mutable { return buffer.data(); }); + auto a2 = a1; // NOLINT + EXPECT(a1.data() != a2.data()); + + auto a3 = a1.share(); + EXPECT(a1.data() != a3.data()); + auto a4 = a3; // NOLINT + EXPECT(a4.data() == a3.data()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/auto_contiguous_test.cpp b/test/auto_contiguous_test.cpp index 6910e488211acf8e9ab85f866baa3a20105404e9..f58070fdd46ffff93b1a65114d32b9f22d5bc746 100644 --- a/test/auto_contiguous_test.cpp +++ b/test/auto_contiguous_test.cpp @@ -1,96 +1,186 @@ #include -#include -#include #include #include #include +#include + #include -void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::auto_contiguous{}}); } +void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::auto_contiguous{}}); } // TODO: Add this test case void literal_broadcast() { - migraphx::program p; - p.add_literal(get_2_broadcasted()); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().broadcasted()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().broadcasted()); + migraphx::module m; + + m.add_literal(get_2_broadcasted()); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().broadcasted()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().broadcasted()); } TEST_CASE(literal_transpose) { - migraphx::program p; - p.add_literal(get_2x2_transposed()); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); + migraphx::module m; + + m.add_literal(get_2x2_transposed()); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().transposed()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().transposed()); } TEST_CASE(after_literal_transpose) { - migraphx::program p; - auto l = p.add_literal(get_2x2()); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - p.add_instruction(pass_op{}, t); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().transposed()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + m.add_instruction(pass_op{}, t); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().transposed()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().transposed()); } TEST_CASE(after_literal_broadcast) { - migraphx::program p; - auto l1 = p.add_literal(get_2x2()); - auto l2 = p.add_literal(get_2()); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().broadcasted()); - auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); - p.add_instruction(pass_op{}, b); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().broadcasted()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().broadcasted()); + migraphx::module m; + + auto l1 = m.add_literal(get_2x2()); + auto l2 = m.add_literal(get_2()); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().broadcasted()); + auto b = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2); + m.add_instruction(pass_op{}, b); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().broadcasted()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().broadcasted()); } TEST_CASE(after_param_transpose) { - migraphx::program p; - auto l = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - p.add_instruction(pass_op{}, t); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); + migraphx::module m; + + auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().transposed()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + m.add_instruction(pass_op{}, t); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().transposed()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().transposed()); } TEST_CASE(after_param_broadcast) { - migraphx::program p; - auto l1 = p.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); - auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {2}}); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().broadcasted()); - auto b = p.add_instruction(migraphx::op::broadcast{0, l1->get_shape().lens()}, l2); - p.add_instruction(pass_op{}, b); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().broadcasted()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().broadcasted()); + migraphx::module m; + + auto l1 = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); + auto l2 = m.add_parameter("2", {migraphx::shape::float_type, {2}}); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().broadcasted()); + auto b = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2); + m.add_instruction(pass_op{}, b); + EXPECT(not m.get_output_shapes().back().standard()); + EXPECT(m.get_output_shapes().back().broadcasted()); + run_pass(m); + EXPECT(m.get_output_shapes().back().standard()); + EXPECT(not m.get_output_shapes().back().broadcasted()); +} + +TEST_CASE(two_transpose_gather) +{ + migraphx::module m1; + { + auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + auto ind = m1.add_parameter("ind", {migraphx::shape::float_type, {2, 3}}); + auto td = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data); + auto sd = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), td); + auto bd = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), bd, ind); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + auto ind = m2.add_parameter("ind", {migraphx::shape::float_type, {2, 3}}); + auto td = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data); + auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td); + auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd); + auto bd = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd); + auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd); + auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind); + m2.add_return({r}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(standard_reshape) +{ + migraphx::module m1; + { + auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + auto add = m1.add_instruction(migraphx::make_op("add"), data, data); + auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + auto add = m2.add_instruction(migraphx::make_op("add"), data, data); + auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); + auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca); + m2.add_return({r}); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(dead_instruction) +{ + migraphx::module m1; + { + auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data); + auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), + data); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data); + auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), + data); + auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r); + m2.add_return({cr}); + } + + EXPECT(m1 == m2); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/const_eval_test.cpp b/test/const_eval_test.cpp index 02d1e9b546a126b7dc029a3ae2e2de4cead930b6..c14a807e99c611168004dbf7b544e38e2104c606 100644 --- a/test/const_eval_test.cpp +++ b/test/const_eval_test.cpp @@ -52,42 +52,52 @@ struct test_context TEST_CASE(literal_test) { migraphx::program p; - auto lit = p.add_literal(1); + + auto* mm = p.get_main_module(); + auto lit = mm->add_literal(1); CHECK(lit->eval() == migraphx::literal{1}); } TEST_CASE(param_test) { migraphx::program p; - auto lit = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}}); + + auto* mm = p.get_main_module(); + auto lit = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}}); CHECK(lit->eval().empty()); } TEST_CASE(op_test1) { migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_cf_op{}, one, two); + + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_cf_op{}, one, two); CHECK(sum->eval() == migraphx::literal{3}); } TEST_CASE(op_test2) { migraphx::program p; - auto x = p.add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}}); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_cf_op{}, x, two); + + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("param", migraphx::shape{migraphx::shape::float_type, {1}}); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_cf_op{}, x, two); CHECK(sum->eval().empty()); } TEST_CASE(op_test3) { migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_cf_op{}, sum1, two); + + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum1 = mm->add_instruction(sum_op{}, one, two); + auto sum2 = mm->add_instruction(sum_cf_op{}, sum1, two); CHECK(sum2->eval().empty()); } diff --git a/test/context_test.cpp b/test/context_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..476df1237b1c81178dd8c1dd52c50f275a731bd3 --- /dev/null +++ b/test/context_test.cpp @@ -0,0 +1,17 @@ +#include +#include +#include +#include +#include + +TEST_CASE(context) +{ + migraphx::context ctx = migraphx::ref::context{}; + migraphx::value v = ctx.to_value(); + EXPECT(v.empty()); + + migraphx::context cpu_ctx = migraphx::ref::context{}; + cpu_ctx.from_value(v); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/convert_to_json.cpp b/test/convert_to_json.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3fac6b27c845f2f3907c4eacce13d4827c905ca --- /dev/null +++ b/test/convert_to_json.cpp @@ -0,0 +1,123 @@ +#include +#include + +TEST_CASE(key_int) +{ + std::string str = "{abc:{key:1}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":1}}"); +} + +TEST_CASE(key_negative_int) +{ + std::string str = "{abc:{key:-1}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":-1}}"); +} + +TEST_CASE(key_float) +{ + std::string str = "{abc:{key:1.0}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":1.0}}"); +} + +TEST_CASE(key_negative_float) +{ + std::string str = "{abc:{key:-1.0}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":-1.0}}"); +} + +TEST_CASE(key_exp) +{ + std::string str = "{abc:{key:1e+10}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":1e+10}}"); +} + +TEST_CASE(key_exp_1) +{ + std::string str = "{abc:{key:1E-10}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":1E-10}}"); +} + +TEST_CASE(key_null) +{ + std::string str = "{abc:{key:null}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":null}}"); +} + +TEST_CASE(key_inf) +{ + std::string str = "{abc:{key:inf}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":inf}}"); +} + +TEST_CASE(key_neg_inf) +{ + std::string str = "{abc:{key:-inf}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":-inf}}"); +} + +TEST_CASE(key_true) +{ + std::string str = "{abc:{key:true}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":true}}"); +} + +TEST_CASE(key_false) +{ + std::string str = "{abc:{key:false}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":false}}"); +} + +TEST_CASE(key_nan) +{ + std::string str = "{abc:{key:nan}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":nan}}"); +} + +TEST_CASE(quote_key_num) +{ + std::string str = R"({"abc":{"key":1}})"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\":{\"key\":1}}"); +} + +TEST_CASE(quote_with_space_key_num) +{ + std::string str = R"({"abc key":{"key":1}})"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc key\":{\"key\":1}}"); +} + +TEST_CASE(key_value_num_space) +{ + std::string str = "{abc : { key : 1}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\" : { \"key\" : 1}}"); +} + +TEST_CASE(key_value_str) +{ + std::string str = "{abc : {key : value}}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\" : {\"key\" : \"value\"}}"); +} + +TEST_CASE(key_space_value) +{ + std::string str = "{abc : [key, value]}"; + auto jstr = migraphx::convert_to_json(str); + EXPECT(jstr == "{\"abc\" : [\"key\", \"value\"]}"); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/cpu_ops_test.cpp b/test/cpu_ops_test.cpp deleted file mode 100644 index 86f292bf61324fe5f06feca9683c931554b74cf6..0000000000000000000000000000000000000000 --- a/test/cpu_ops_test.cpp +++ /dev/null @@ -1,2288 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "test.hpp" -#include - -float sigmoid(float x) { return 1 / (1 + expf(-x)); } - -float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); } - -TEST_CASE(slice_test) -{ - { - migraphx::program p; - std::vector data(2 * 2 * 3); - std::iota(data.begin(), data.end(), 0); - migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}}; - auto l0 = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(migraphx::op::slice{{2}, {1}, {3}}, l0); - migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; - EXPECT(p.get_shape() == s2); - p.compile(migraphx::cpu::target{}); - migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; - auto result = p.eval({}); - std::vector gold = {1, 2, 4, 5, 7, 8, 10, 11}; - std::vector results_vector(2 * 2 * 2); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(result.get_shape() == sresult); - } - { - migraphx::program p; - std::vector data(2 * 2 * 3); - std::iota(data.begin(), data.end(), 0); - migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}}; - auto l0 = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(migraphx::op::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0); - migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; - EXPECT(p.get_shape() == s2); - p.compile(migraphx::cpu::target{}); - migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; - auto result = p.eval({}); - std::vector gold = {0, 1, 3, 4, 6, 7, 9, 10}; - std::vector results_vector(2 * 2 * 2); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(result.get_shape() == sresult); - } -} - -TEST_CASE(concat_test) -{ - { - migraphx::program p; - int axis = 1; - std::vector data0 = {0, 1, 5, 6}; - std::vector data1 = {2, 3, 4, 7, 8, 9}; - std::vector data2 = {10, 20}; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - auto l1 = p.add_literal(migraphx::literal{s1, data1}); - auto l2 = p.add_literal(migraphx::literal{s2, data2}); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; - std::vector results_vector(2 * 6); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({2, 6}))); - EXPECT( - migraphx::verify_range(result.get_shape().strides(), std::vector({6, 1}))); - } - - { - migraphx::program p; - int axis = -1; - std::vector data0 = {0, 1, 5, 6}; - std::vector data1 = {2, 3, 4, 7, 8, 9}; - std::vector data2 = {10, 20}; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - auto l1 = p.add_literal(migraphx::literal{s1, data1}); - auto l2 = p.add_literal(migraphx::literal{s2, data2}); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; - std::vector results_vector(2 * 6); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({2, 6}))); - EXPECT( - migraphx::verify_range(result.get_shape().strides(), std::vector({6, 1}))); - } - - { - migraphx::program p; - int axis = 0; - std::vector data0 = {0, 1, 2, 3}; - std::vector data1 = {4, 5, 6, 7, 8, 9}; - std::vector data2 = {10, 11}; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - auto l1 = p.add_literal(migraphx::literal{s1, data1}); - auto l2 = p.add_literal(migraphx::literal{s2, data2}); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - std::vector results_vector(6 * 2); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({6, 2}))); - EXPECT( - migraphx::verify_range(result.get_shape().strides(), std::vector({2, 1}))); - } - - { - migraphx::program p; - int axis = -2; - std::vector data0 = {0, 1, 2, 3}; - std::vector data1 = {4, 5, 6, 7, 8, 9}; - std::vector data2 = {10, 11}; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - auto l1 = p.add_literal(migraphx::literal{s1, data1}); - auto l2 = p.add_literal(migraphx::literal{s2, data2}); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - std::vector results_vector(6 * 2); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, gold)); - EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({6, 2}))); - EXPECT( - migraphx::verify_range(result.get_shape().strides(), std::vector({2, 1}))); - } -} - -TEST_CASE(gather_test) -{ - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; - std::vector indices{0, 2}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data(4 * 5); - std::vector golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; - std::vector indices{-3, -1}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data(4 * 5); - std::vector golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; - std::vector indices{0, 2}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = 1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data(4 * 5); - std::vector golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; - std::vector indices{0, 2}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data(4 * 5); - std::vector golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - // scalar index - migraphx::shape s_indices{migraphx::shape::int32_type}; - std::vector indices{0}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data{}; - std::vector golden = {0.5f, 3.5f, 6.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3 * 3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - // scalar index - migraphx::shape s_indices{migraphx::shape::int32_type}; - std::vector indices{-3}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data{}; - std::vector golden = {0.5f, 3.5f, 6.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } - - { - migraphx::program p; - - std::vector data(3); - std::iota(data.begin(), data.end(), 0.5); - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto a0 = p.add_literal(migraphx::literal{s, data}); - // scalar index - migraphx::shape s_indices{migraphx::shape::int32_type}; - std::vector indices{0}; - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res_data{}; - std::vector golden = {0.5f}; - result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res_data, golden)); - } -} - -TEST_CASE(squeeze_test) -{ - { - migraphx::program p; - std::vector data(4 * 3 * 3); - migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; - auto l0 = p.add_literal(migraphx::literal{s1, data}); - p.add_instruction(migraphx::op::squeeze{{1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape() == s2); - } - { - migraphx::program p; - std::vector data(4 * 3 * 3); - migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; - auto l0 = p.add_literal(migraphx::literal{s1, data}); - p.add_instruction(migraphx::op::squeeze{{3}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape() == s2); - } - { - migraphx::program p; - std::vector data(4 * 3 * 3); - migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {4, 3, 3}}; - auto l0 = p.add_literal(migraphx::literal{s1, data}); - p.add_instruction(migraphx::op::squeeze{}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape() == s2); - } -} - -TEST_CASE(unsqueeze_test) -{ - { - migraphx::program p; - std::vector data(4 * 3 * 3); - migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; - auto l0 = p.add_literal(migraphx::literal{s1, data}); - p.add_instruction(migraphx::op::unsqueeze{{1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape() == s2); - } - { - migraphx::program p; - std::vector data(4 * 3 * 3); - migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; - auto l0 = p.add_literal(migraphx::literal{s1, data}); - p.add_instruction(migraphx::op::unsqueeze{{2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape() == s2); - } -} - -TEST_CASE(globalavgpool_test) -{ - migraphx::program p; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; - auto op = migraphx::op::pooling{"average"}; - auto lens = s.lens(); - op.lengths = {lens[2], lens[3]}; - - std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; - auto l0 = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(op, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{0.25, 0.575, 0.375}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(globalmaxpool_test) -{ - migraphx::program p; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; - auto op = migraphx::op::pooling{"max"}; - auto lens = s.lens(); - op.lengths = {lens[2], lens[3]}; - - std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; - auto l0 = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(op, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{0.4, 0.9, 0.7}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(im2col_3x3_no_pad_identity_test) -{ - std::size_t f[2] = {3, 3}; - std::size_t size[2] = {3, 3}; - std::array padding{{0, 0}}; - std::array stride{{1, 1}}; - std::array dilation{{1, 1}}; - std::size_t channels = 1; - - std::vector weights(channels * f[0] * f[1]); - std::vector input(channels * size[0] * size[1]); - std::iota(input.begin(), input.end(), 0); - - migraphx::program p; - migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_image = p.add_literal(migraphx::literal{s_image, input}); - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; - std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; - std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, input)); -} - -TEST_CASE(im2col_3x3_no_pad_test) -{ - std::size_t f[2] = {3, 3}; - std::size_t size[2] = {4, 4}; - std::array padding{{0, 0}}; - std::array stride{{1, 1}}; - std::array dilation{{1, 1}}; - std::size_t channels = 1; - - std::vector weights(channels * f[0] * f[1]); - std::vector input(channels * size[0] * size[1]); - std::iota(input.begin(), input.end(), 0); - - migraphx::program p; - migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_image = p.add_literal(migraphx::literal{s_image, input}); - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector correct = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11, - 4, 5, 6, 8, 9, 10, 12, 13, 14, 5, 6, 7, 9, 10, 11, 13, 14, 15}; - - std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; - std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; - std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, correct)); -} - -TEST_CASE(im2col_3x3_stride_2_no_pad_test) -{ - std::size_t f[2] = {3, 3}; - std::size_t size[2] = {6, 6}; - std::array padding{{0, 0}}; - std::array stride{{2, 2}}; - std::array dilation{{1, 1}}; - std::size_t channels = 1; - - std::vector weights(channels * f[0] * f[1]); - std::vector input(channels * size[0] * size[1]); - std::iota(input.begin(), input.end(), 0); - - migraphx::program p; - migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_image = p.add_literal(migraphx::literal{s_image, input}); - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector correct = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4, - 8, 9, 10, 14, 15, 16, 12, 13, 14, 18, 19, 20, - 24, 25, 26, 14, 15, 16, 20, 21, 22, 26, 27, 28}; - - std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; - std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; - std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, correct)); -} - -TEST_CASE(im2col_3x3_with_padding_test) -{ - std::size_t f[2] = {3, 3}; - std::size_t size[2] = {2, 2}; - std::array padding{{1, 1}}; - std::array stride{{1, 1}}; - std::array dilation{{1, 1}}; - std::size_t channels = 1; - - std::vector weights(channels * f[0] * f[1]); - std::vector input(channels * size[0] * size[1]); - std::iota(input.begin(), input.end(), 0); - - migraphx::program p; - migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_image = p.add_literal(migraphx::literal{s_image, input}); - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector correct = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, - 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0}; - - std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; - std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; - std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, correct)); -} - -TEST_CASE(batch_norm_inference_test) -{ - migraphx::program p; - const size_t width = 2; - const size_t height = 2; - const size_t channels = 4; - const size_t batches = 2; - const float x_val = 8.0; - const float mean_val = 2.0; - const float variance_val = 4.0; - const float scale_val = 2.0f; - const float bias_val = 1.0f; - const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val; - - migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; - migraphx::shape vars{migraphx::shape::float_type, {channels}}; - std::vector x_data(width * height * channels * batches); - std::vector scale_data(channels); - std::vector bias_data(channels); - std::vector mean_data(channels); - std::vector variance_data(channels); - - std::fill(x_data.begin(), x_data.end(), x_val); - std::fill(mean_data.begin(), mean_data.end(), mean_val); - std::fill(variance_data.begin(), variance_data.end(), variance_val); - std::fill(scale_data.begin(), scale_data.end(), scale_val); - std::fill(bias_data.begin(), bias_data.end(), bias_val); - - auto x = p.add_literal(migraphx::literal{s, x_data}); - auto scale = p.add_literal(migraphx::literal{vars, scale_data}); - auto bias = p.add_literal(migraphx::literal{vars, bias_data}); - auto mean = p.add_literal(migraphx::literal{vars, mean_data}); - auto variance = p.add_literal(migraphx::literal{vars, variance_data}); - - p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector result_vector(width * height * channels * batches); - std::vector gold(width * height * channels * batches); - std::fill(gold.begin(), gold.end(), output_val); - result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vector, gold)); -} - -TEST_CASE(im2col_3x3_with_channels_identity_test) -{ - std::size_t f[2] = {3, 3}; - std::size_t size[2] = {3, 3}; - std::array padding{{0, 0}}; - std::array stride{{1, 1}}; - std::array dilation{{1, 1}}; - std::size_t channels = 2; - - std::vector weights(channels * f[0] * f[1]); - std::vector input(channels * size[0] * size[1]); - std::iota(input.begin(), input.end(), 0); - - migraphx::program p; - migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_image = p.add_literal(migraphx::literal{s_image, input}); - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - p.add_instruction(migraphx::op::im2col{padding, stride, dilation}, l_image, l_weights); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; - std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; - std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, input)); -} - -TEST_CASE(exp_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - p.add_instruction(migraphx::op::exp{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.36787944f, 1.f, 2.71828183f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(erf_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {4}}; - auto l = - p.add_literal(migraphx::literal{s, {0.73785057, 1.58165966, -0.43597795, -0.01677432}}); - p.add_instruction(migraphx::op::erf{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.70327317, 0.97470088, -0.46247893, -0.01892602}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(sqrt_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {5}}; - auto l = p.add_literal( - migraphx::literal{s, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}}); - p.add_instruction(migraphx::op::sqrt{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(sign_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {5}}; - auto l = p.add_literal( - migraphx::literal{s, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}}); - p.add_instruction(migraphx::op::sign{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {1.0, 1.0, -1.0, -1.0, 0.0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(log_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::log{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.0f, 0.6931471806f, 1.0986122887f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(pow_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto b = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - auto e = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::pow{}, b, e); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {1.0f, 4.0f, 27.0f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(sin_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - p.add_instruction(migraphx::op::sin{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-0.84147098f, 0.f, 0.84147098f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(cos_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - p.add_instruction(migraphx::op::cos{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.54030231f, 1.f, 0.54030231f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(tan_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - p.add_instruction(migraphx::op::tan{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-1.55740772f, 0.0f, 1.55740772f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(asin_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - std::vector data{-0.5f, 0.0f, 0.9f}; - auto l = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(migraphx::op::asin{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-0.5235987756f, 0.f, 1.119769515}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(acos_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {3}}; - std::vector data{-0.8f, 0.0f, 1.0f}; - auto l = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(migraphx::op::acos{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {2.4980915448f, 1.5707963268f, 0.0f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(atan_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - p.add_instruction(migraphx::op::atan{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-0.7853981634f, 0.0f, 0.7853981634f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(add_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::add{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0, 2, 4}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(broadcast_test) -{ - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}}; - std::vector a_data{0, 0, 0, 0}; - migraphx::shape b_shape{migraphx::shape::int32_type, {2}}; - std::vector b_data{-2, -3}; - uint64_t axis = 0; - auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); - auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); - p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - auto output = result.get(); - EXPECT(output(0, 0) == -2); - EXPECT(output(0, 1) == -2); - EXPECT(output(1, 0) == -3); - EXPECT(output(1, 1) == -3); -} -TEST_CASE(add_broadcast_test) -{ - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}}; - std::vector a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - migraphx::shape b_shape{migraphx::shape::float_type, {2, 2}}; - std::vector b_data{0, -1, -2, -3}; - uint64_t axis = 0; - auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); - auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); - auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape().lens()}, l2); - p.add_instruction(migraphx::op::add{}, l1, l3); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape().packed()); - std::vector results_vector(12); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; - EXPECT(migraphx::verify_range(results_vector, gold)); - } - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}}; - std::vector a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 1}}; - std::vector b_data{0, -1, -2, -3}; - auto l1 = p.add_literal(migraphx::literal{a_shape, a_data}); - auto l2 = p.add_literal(migraphx::literal{b_shape, b_data}); - auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3}}, l1); - auto l4 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3}}, l2); - p.add_instruction(migraphx::op::add{}, l3, l4); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - EXPECT(result.get_shape().packed()); - std::vector results_vector(12); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; - EXPECT(migraphx::verify_range(results_vector, gold)); - } -} - -TEST_CASE(sub_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::sub{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-2, -2, -2}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(mul_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::mul{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-1, 0, 3}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(div_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l1 = p.add_literal(migraphx::literal{s, {-1.0f, 0.5f, 1.0f}}); - auto l2 = p.add_literal(migraphx::literal{s, {1.0f, 2.0f, 4.0f}}); - p.add_instruction(migraphx::op::div{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-1.f, 0.25f, 0.25f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(relu_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}}); - p.add_instruction(migraphx::op::relu{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.f, 0.f, 1.f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(leaky_relu_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}}); - p.add_instruction(migraphx::op::leaky_relu{0.01}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-0.01f, 0.f, 1.f}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(lrn_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}}; - auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}}); - p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1, 5}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(5); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(imagescaler_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 3, 2, 2}}; - auto img = p.add_literal(migraphx::literal{s, - {0.2, - 0.3, - 0.5, - 0.4, - - 0.7, - 0.8, - 0.1, - 0.9, - - 0.15, - 0.25, - 0.35, - 0.45}}); - auto scale_val = p.add_literal(2.f); - auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val); - auto img_scaled = p.add_instruction(migraphx::op::mul{}, img, scaled_tensor); - auto bias_vals = p.add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); - auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals); - p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(12); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.41, - 0.61, - 1.01, - 0.81, - - 1.42, - 1.62, - 0.22, - 1.82, - - 0.33, - 0.53, - 0.73, - 0.93}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(reshape_test) -{ - migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; - std::vector data(24); - std::iota(data.begin(), data.end(), -3); - { - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - std::vector new_shape = {8, 3, 1, 1}; - p.add_instruction(migraphx::op::reshape{new_shape}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, data)); - } - { - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - std::vector new_shape = {1, 3, 4, 2}; - p.add_instruction(migraphx::op::reshape{new_shape}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, data)); - } - { - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - std::vector new_shape = {1, 3, 4, 2}; - p.add_instruction(migraphx::op::reshape{new_shape}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, data)); - } -} - -TEST_CASE(maxpool_test) -{ - migraphx::program p; - std::vector a = { - -2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806, - -0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688, - 0.36873096, 1.18358743, -0.34640595, 1.22098756, 0.01946825, -0.20238149, 0.43348005, - -0.67991608, -0.83041084, 0.93537551, 0.70241445, -0.5654031, -1.30899191, -0.26735824, - -0.52444768, 1.99097753, 1.86504853, -0.26506025, 0.26236168, 0.43763575, 0.95300823, - -1.02733946, -0.74655169, -0.5374338, -0.28901565, -0.59789604, 0.5310151, 0.99125904, - 0.40609556, -1.57175648, 0.22031412, 1.45862222, 0.53217483, 1.39087725, 1.00170159, - -0.87175864, -1.7204628, -1.72008383, -0.38656762, -0.01443311, 1.46645272, -1.39995027, - 0.22505587, -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211, 1.18943918, - -0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603, 0.54316711, - 0.40899998, -0.27831686, -1.11900508, -0.0881724, 0.35483059, 2.36277103, -0.04765317, - -0.36865309, 0.73814237, 1.47151589, 1.36546791, -0.32649881, -1.0517807, 2.24768877, - 0.68883753, 0.58646208, -0.91017133, -0.50462508, -0.4013325, -0.72348958, -0.47368807, - 0.35285577, -1.01817429, -0.5152272, 0.60321307, 0.43521205, -0.23733577, 0.66427642, - 0.82949388, 0.82443929, 0.71550399, 0.34561086, 0.68570769, -0.40718508, -1.20350206, - 0.15793853, -2.31013632, -0.07934658, -0.09348056, 0.36576006, 2.46601582, 0.11090943, - 0.9144392, 0.56759721, -0.22112127, -0.21955389, 0.72474903, -1.28448462, 1.53285873, - 0.37437943, 0.31409341, 1.95433736, 0.91620457, 0.86205518, 1.24365854, 0.19248386, - 0.22526583, 0.13462132, -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345, - 1.31293464, -1.86041689, 1.06763375, -0.26541466, 1.4545635, 1.11430049, -0.66491818, - 0.87101674, 0.67768967, -1.02062869, -1.05031872, -2.2764678, -2.0200038, 0.37592548, - -0.26701379, -0.83388507, 0.19403623, 1.00968623, 0.11020003, 1.16736257, -1.1160326, - 0.47346735, 0.6126079, -0.19135755, 1.33624589, -0.29802522, -0.57873946, -1.06555879, - -0.20686582, 1.36892557, -0.19937795, 0.8649236, -1.40126073, 1.53441942, 0.34682792, - -1.31724346, -1.32898355, 2.40126371, 0.07845283, 1.35732043, -0.63678312, 0.39429256, - -1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341, - 1.20071781, -1.64647579, -0.7133292, 0.88494766, 0.52119428, -2.77387547, 2.07681108, - -0.90133125, 0.2847338, 0.6174528, -0.20616426, -0.64263535, -1.08496261, 0.54275119, - -0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746, - -0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223, - -0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682}; - std::vector c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753, - 1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311, - 1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399, - 1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635, - 1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942, - 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(36); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, c)); -} - -TEST_CASE(softmax_simple_test) -{ - migraphx::program p; - std::vector a = {0.25, 0.75}; - std::vector s = {0.377541, 0.622459}; - migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - p.add_instruction(migraphx::op::softmax{1}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(2); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(softmax_test) -{ - migraphx::program p; - std::vector a = { - -5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03, - -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01, - -9.20038462e-01, 8.47388089e-01, 2.51734018e-01, 1.50563884e+00, 2.23056650e+00, - -6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00, -2.51560897e-01, - -8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01, - 3.23284641e-02, -1.54700470e+00, 1.38096774e+00, 5.39869189e-01, -7.56884992e-01, - 1.81503093e+00, -2.11269641e+00, 1.92466557e+00, 1.77230799e+00, 2.21660900e+00, - 1.56777036e+00, -2.08995026e-03, 3.50566894e-01, -1.15042710e+00, -1.18577778e+00, - 8.90633047e-01, -6.63949102e-02, 1.44661188e+00, 1.59215283e+00, -2.56262213e-01, - 9.39079225e-01, 4.07298543e-02, 3.86590779e-01, 6.09607756e-01, 8.22331488e-01, - -2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00, - 3.27092171e-01, -1.33315325e+00, 3.62459183e-01, 3.74710828e-01, -1.30302286e+00, - 1.79680198e-01, -4.51832324e-01, 4.34282750e-01, -7.09520102e-01, 6.20333970e-01, - -1.28712380e+00, 2.04130828e-01, -7.70607769e-01, 1.61889160e+00, -1.50951004e+00, - -4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01, - -8.28408226e-02, 2.73412596e-02, 5.79780899e-03, 9.87900198e-02, -7.95276761e-01, - -1.38536084e+00, -6.63573861e-01, 3.89783204e-01, -1.30670881e+00, -7.62425125e-01, - -4.04883057e-01, 6.24344349e-01, 3.68128955e-01, -1.01577950e+00, -3.06715906e-01, - 5.67961395e-01, 2.98198581e-01, -1.63613629e+00, -3.75131965e-01, -6.75393403e-01, - 2.59172034e+00, 6.75538957e-01, 9.07939598e-02, 1.92257717e-01, -1.21592450e+00, - -2.73682117e-01, 1.25232983e+00, -1.39969170e+00, -1.91483587e-01, 2.57732719e-01, - 3.10056299e-01, 1.41833842e+00, -1.81386679e-01, 3.92868072e-01, -8.14771175e-01, - 2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01, - -6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01}; - - std::vector s = { - 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, - 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, - 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, - 0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287, - 0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055, - 0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915, - 0.3115395, 0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762, 0.4642328, - 0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609, - 0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865, - 0.17296699, 0.46923906, 0.06921105, 0.3570261, 0.4125829, 0.73165393, 0.15302512, - 0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355, - 0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659, 0.10663581, - 0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666, - 0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339, 0.49818122, 0.10656087, - 0.1813329, 0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728, - 0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739, - 0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723, - 0.42914796}; - - migraphx::shape a_shape{migraphx::shape::float_type, {5, 3, 4, 2}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - p.add_instruction(migraphx::op::softmax{}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(120); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(logsoftmax_test_axis_0) -{ - migraphx::program p; - std::vector a = { - 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, - -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, - -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, - -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, - 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, - -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, - -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, - -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; - - std::vector s = { - -0.135261, -2.843968, -0.659995, -0.488413, -1.051857, -2.812936, -0.250956, -0.353985, - -1.155980, -0.603651, -0.211969, -0.175371, -1.336552, -3.885010, -1.871544, -0.837083, - -0.887745, -0.433338, -1.158864, -4.911197, -1.147972, -0.666711, -0.996874, -0.981418, - -0.851145, -0.853988, -0.858112, -2.067420, -0.059956, -0.727436, -0.950881, -0.429689, - -0.061906, -1.505332, -1.210277, -0.377970, -0.791448, -1.655428, -1.827253, -0.304828, - -0.020762, -0.167101, -0.567346, -0.530319, -1.045094, -0.376648, -0.007391, -0.381670, - -0.720302, -0.460499, -0.469651, -0.556740, -0.554628, -0.551582}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - int axis = 0; - p.add_instruction(migraphx::op::logsoftmax{axis}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(logsoftmax_test_axis_1) -{ - migraphx::program p; - std::vector a = { - 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, - -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, - -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, - -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, - 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, - -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, - -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, - -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; - - std::vector s = { - -0.550468, -2.132973, -1.549746, -0.650533, -1.051529, -2.248570, -0.141017, -2.028357, - -1.947730, -1.511324, -0.166597, -0.379726, -1.965689, -1.172109, -1.475721, -2.700831, - -1.537011, -0.658754, -1.596017, -3.353137, -2.266743, -1.084197, -1.076214, -0.406712, - -2.743019, -0.425526, -1.079083, -2.139486, -1.270584, -1.024088, -1.154231, -3.201762, - -0.888957, -0.532855, -3.103583, -1.221339, -1.355980, -3.531678, -1.438510, -0.975194, - -0.080261, -1.162697, -1.568557, -1.398519, -1.322129, -0.470660, -0.370953, -0.907343, - -1.179017, -3.312239, -1.286363, -1.586076, -0.345100, -0.824173}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - int axis = 1; - p.add_instruction(migraphx::op::logsoftmax{axis}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(logsoftmax_test_axis_2) -{ - migraphx::program p; - std::vector a = { - 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, - -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, - -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, - -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, - 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, - -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, - -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, - -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; - - std::vector s = { - -0.495957, -1.031212, -0.245531, -2.013726, -1.339125, -2.465619, -1.356652, -0.964037, - -2.019250, -0.214522, -0.289569, -0.234392, -2.086591, -2.684439, -2.851651, -2.674176, - -1.697424, -1.889155, -0.401029, -3.064586, -1.173030, -1.306912, -2.177020, -0.834262, - -2.818177, -0.174415, -1.361105, -1.024571, -0.106766, -1.167645, -1.072650, -2.576522, - -0.569261, -1.207483, -3.679894, -2.095913, -0.504264, -3.039291, -1.290559, -1.156812, - -0.126453, -0.551493, -2.506384, -2.646261, -1.905195, -0.206994, -0.191369, -0.959754, - -1.948685, -3.671233, -0.875521, -3.111952, -1.905644, -1.6076011}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - int axis = 2; - p.add_instruction(migraphx::op::logsoftmax{axis}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(logsoftmax_test_axis_3) -{ - migraphx::program p; - std::vector a = { - 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, - -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, - -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, - -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, - 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, - -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, - -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, - -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; - - std::vector s = { - -0.336904, -3.475825, -1.366154, -0.279366, -2.208430, -2.010934, -0.225511, -2.436562, - -2.167785, -1.572415, -1.784104, -0.470789, -1.067459, -1.801948, -0.711023, -2.307197, - -1.467087, -0.400681, -0.426983, -3.740518, -1.127681, -1.078919, -2.599005, -0.534965, - -2.561400, -0.567617, -1.033025, -2.097713, -0.520463, -1.262245, -1.763230, -2.607658, - -0.281299, -0.814243, -2.627210, -0.724131, -0.655704, -2.123055, -1.018163, -2.480634, - -0.382599, -1.451479, -1.843102, -0.915303, -0.818078, -1.316929, -0.508875, -2.033541, - -1.487672, -2.417791, -0.378360, -2.568531, -0.569794, -1.028032}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - int axis = 3; - p.add_instruction(migraphx::op::logsoftmax{axis}, al); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(argmax_test_0) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmax{0}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(argmax_test_1) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {0, 0, 2, 1, 2, 0, 0, 2}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmax{1}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(argmax_test_2) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {1, 3, 2, 2, 2, 3}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmax{2}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(argmin_test_0) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmin{0}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(argmin_test_1) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {2, 2, 0, 2, 0, 1, 2, 0}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmin{1}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(argmin_test_2) -{ - migraphx::program p; - std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, - -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, - 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; - std::vector res_gold = {2, 1, 0, 3, 3, 2}; - migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto dl = p.add_literal(migraphx::literal{data_shape, data}); - p.add_instruction(migraphx::op::argmin{2}, dl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector result_vec; - result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(result_vec, res_gold)); -} - -TEST_CASE(conv2d_test) -{ - migraphx::program p; - std::vector a = { - 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, - -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, - 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, - 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, - -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, - 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, - 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, - 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, - 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, - 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, - 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, - -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, - 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, - -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; - - std::vector c = { - 2.82721668e-02, 6.44195229e-02, 1.53499246e-02, 1.72468081e-01, -6.33238107e-02, - 9.49496776e-02, 1.40258059e-01, -7.92879611e-02, -1.29301161e-01, 3.11307609e-03, - -1.90624535e-01, 1.13238767e-01, -2.80647576e-02, 3.12882811e-02, -3.52091640e-02, - 3.33581865e-02, 6.43158704e-02, 7.40238279e-02, -1.00106120e-01, -9.56912562e-02, - 1.44342467e-01, 9.40258950e-02, 6.36333972e-02, 1.66158378e-03, -8.91554281e-02, - 2.58734226e-02, 1.70919895e-02, 1.78214177e-01, 8.84564668e-02, 8.98126513e-02, - -1.63809001e-01, 1.37802169e-01, 1.66439757e-01, -1.45631135e-02, 1.88469887e-04, - 4.76950556e-02, -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01, - 1.76608220e-01, -1.50728196e-01, 1.99946314e-02, -5.88052124e-02, 1.31612435e-01, - 1.61106288e-02, -1.35080189e-01, 1.49512306e-01, 3.86456847e-02, 1.29330024e-01, - -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02}; - - std::vector s = {0.27039781, - 0.19105849, - -0.06339942, - -0.65087199, - 0.40867025, - 0.05063812, - -0.14907975, - 0.49018705, - -0.49197209, - 0.33236548, - -0.39374301, - 0.16012701, - 0.06574871, - 0.71606487, - -0.55201721, - -0.46427044}; - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - - migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - - p.add_instruction(migraphx::op::convolution{}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(16); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(conv2d_padding_test) -{ - migraphx::program p; - std::vector a = { - 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, - -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, - 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, - 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, - -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, - 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, - 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, - 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, - 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, - 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, - 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, - -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, - 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, - -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; - - std::vector c = { - -0.16115488, -0.09800646, -0.05412646, 0.10475694, 0.00555485, -0.12667653, 0.0458357, - -0.02656217, -0.16338061, 0.15037455, 0.0102711, 0.01303349, 0.05242859, 0.02034754, - 0.04751867, -0.17038961, -0.1434752, -0.10770349, 0.05676742, -0.15838449, 0.10128359, - -0.18958683, 0.11954515, 0.10758857, -0.01058291, -0.12797487, 0.08971019, 0.18793164, - -0.00881396, -0.06588994, -0.13321903, -0.03300409, 0.01439607, 0.07618178, -0.11556662, - 0.00764295, 0.12956454, -0.08937147, -0.12763587, 0.04674943, 0.05765297, 0.11336918, - 0.14747436, -0.06199479, -0.01166052, -0.12432006, -0.04494537, -0.17581205, 0.09475745, - 0.1149437, -0.1014564, 0.0274073, -0.01323579, -0.11092556}; - - std::vector s = { - -0.0201216, 0.40407312, -0.39005592, -0.0631946, 0.37963012, -0.64611685, 0.1349397, - -0.54113752, 0.28533003, 0.27667275, -0.16442731, -0.181494, 0.30564839, 0.58744538, - 0.32015014, 0.24969585, -0.27367792, -0.53308117, 0.41236052, 0.26136363, -0.01489828, - 0.57652152, -0.38506854, 0.119615, 0.0437076, 0.04779706, 0.57887721, 0.23126155, - 0.05695833, -0.68200272, 0.02063358, -0.10267162, 0.8062973, -0.38149622, -0.40134856, - -0.03353126, 0.38991132, -0.3478111, 0.03661491, 0.25783631, 0.62772679, -0.1961118, - 0.76423508, -0.36241418, -0.20994355, -0.12368261, -0.9406727, 0.02340185, -0.08793129, - -0.02471633, -0.58163726, -0.02211772, -0.42014724, 0.77525634, 0.504951, -0.20537445, - -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194, 0.49579027, - 0.46527559}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - - migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - - p.add_instruction(migraphx::op::convolution{{{1, 1}}, {{1, 1}}}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(64); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(conv2d_padding_stride_test) -{ - migraphx::program p; - std::vector a = { - 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, - -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, - 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, - 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, - -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, - 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, - 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, - 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, - 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, - 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, - 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, - -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, - 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, - -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; - - std::vector c = { - -0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512, - 0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832, - 0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622, - 0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754, - 0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272, - 0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968, - -0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193, - 0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292}; - - std::vector s = {-0.20817225, - 0.87965256, - 0.14958936, - -1.24887264, - -0.06540672, - 0.20778663, - 0.40456355, - -0.99900877, - 0.4917807, - 0.1994698, - 0.64205718, - 0.37798831, - -0.25315839, - 0.44276932, - -0.16138598, - 0.79344082}; - - migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - - migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - - p.add_instruction(migraphx::op::convolution{{{1, 1}}, {{2, 2}}}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(16); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(quant_conv2d_test) -{ - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - std::vector a(2 * 3 * 4 * 4); - std::iota(a.begin(), a.end(), 0); - auto al = p.add_literal(migraphx::literal{a_shape, a}); - - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - std::vector c(2 * 3 * 3 * 3); - std::iota(c.begin(), c.end(), 0); - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - - p.add_instruction(migraphx::op::quant_convolution{}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector s = {10197, - 10548, - 11601, - 11952, - 25506, - 26586, - 29826, - 30906, - 27045, - 27396, - 28449, - 28800, - 77346, - 78426, - 81666, - 82746}; - - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(quant_conv2d_padding_test) -{ - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - std::vector a(2 * 3 * 4 * 4); - std::iota(a.begin(), a.end(), 0); - auto al = p.add_literal(migraphx::literal{a_shape, a}); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - std::vector c(2 * 3 * 3 * 3); - std::iota(c.begin(), c.end(), 0); - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector s = { - 4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007, - 7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826, - 30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396, - 17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893, - 51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075}; - - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(quant_conv2d_padding_stride_test) -{ - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - std::vector a(2 * 3 * 4 * 4); - std::iota(a.begin(), a.end(), 0); - auto al = p.add_literal(migraphx::literal{a_shape, a}); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - std::vector c(2 * 3 * 3 * 3); - std::iota(c.begin(), c.end(), 0); - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, al, cl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector s = {4521, - 7014, - 7830, - 11952, - 10515, - 16734, - 19737, - 30906, - 13161, - 19542, - 19494, - 28800, - 34707, - 52590, - 54729, - 82746}; - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(results_vector, s)); -} - -TEST_CASE(transpose_test) -{ - migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}}; - std::vector data(12); - std::iota(data.begin(), data.end(), 0); - - { - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - std::vector perm = {0, 3, 1, 2}; - p.add_instruction(migraphx::op::transpose{perm}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - result.visit([&](auto output) { - std::vector new_lens = {1, 3, 2, 2}; - EXPECT(bool{output.get_shape().lens() == new_lens}); - }); - } - { - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - std::vector perm = {0, 3, 1, 2}; - auto result = p.add_instruction(migraphx::op::transpose{perm}, l); - p.add_instruction(migraphx::op::contiguous{}, result); - p.compile(migraphx::cpu::target{}); - auto result2 = p.eval({}); - - std::vector results_vector(12); - result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; - EXPECT(migraphx::verify_range(results_vector, gold)); - } -} - -TEST_CASE(contiguous_test) -{ - migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}}; - std::vector data(12); - std::iota(data.begin(), data.end(), 0); - - migraphx::program p; - auto l = p.add_literal(migraphx::literal{a_shape, data}); - p.add_instruction(migraphx::op::contiguous{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - - std::vector results_vector(12); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector new_lens = {1, 3, 2, 2}; - std::vector new_strides = {12, 1, 6, 3}; - EXPECT(migraphx::verify_range(results_vector, data)); -} - -TEST_CASE(identity_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - std::vector data{1, 2, 3, 4}; - auto l = p.add_literal(migraphx::literal{s, data}); - p.add_instruction(migraphx::op::identity{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - EXPECT(std::equal(data.begin(), data.end(), results_vector.begin())); -} - -TEST_CASE(abs_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 2, -3, 4}}); - p.add_instruction(migraphx::op::abs{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1, 2, 3, 4}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(sigmoid_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1, 2, -3, 4}}); - p.add_instruction(migraphx::op::sigmoid{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(sinh_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); - p.add_instruction(migraphx::op::sinh{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{sinhf(-1), sinhf(2), sinhf(-3), sinhf(4)}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(cosh_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); - p.add_instruction(migraphx::op::cosh{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{coshf(-1), coshf(2), coshf(-3), coshf(4)}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(tanh_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); - p.add_instruction(migraphx::op::tanh{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{tanhf(-1), tanhf(2), tanhf(-3), tanhf(4)}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(elu_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l = p.add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); - float alpha = 0.5; - p.add_instruction(migraphx::op::elu{alpha}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(max_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l0 = p.add_literal(migraphx::literal{s, {1, 4, 3}}); - auto l1 = p.add_literal(migraphx::literal{s, {2, 8, 6}}); - auto l2 = p.add_literal(migraphx::literal{s, {7, 5, 9}}); - auto curr_max = p.add_instruction(migraphx::op::max{}, l0, l1); - p.add_instruction(migraphx::op::max{}, curr_max, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{7, 8, 9}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(min_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l0 = p.add_literal(migraphx::literal{s, {1, 4, 3}}); - auto l1 = p.add_literal(migraphx::literal{s, {2, 8, 6}}); - auto l2 = p.add_literal(migraphx::literal{s, {7, 5, 9}}); - auto curr_min = p.add_instruction(migraphx::op::min{}, l0, l1); - p.add_instruction(migraphx::op::min{}, curr_min, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(4); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1, 4, 3}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(pad_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 2}}; - auto l0 = p.add_literal(migraphx::literal{s, {1, 2, 3, 4}}); - p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(16); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(fp16_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::half_type, {1}}; - migraphx::half a{1.5}; - migraphx::half b{2.5}; - migraphx::half c{4.0}; - auto l0 = p.add_literal(migraphx::literal{s, {a}}); - auto l1 = p.add_literal(migraphx::literal{s, {b}}); - p.add_instruction(migraphx::op::add{}, l0, l1); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(1); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{c}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(fp32_fp16_test) -{ - auto create_program = [] { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - std::vector data(2 * 3); - std::iota(data.begin(), data.end(), 1.0f); - auto l1 = p.add_literal(migraphx::literal(s, data)); - auto l2 = p.add_literal(migraphx::literal(s, data)); - p.add_instruction(migraphx::op::add{}, l1, l2); - return p; - }; - - auto test_case = [&](std::vector&& op_names) { - std::vector gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0}; - auto p = create_program(); - migraphx::quantize_fp16(p, op_names); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector res; - result.visit([&](auto output) { res.assign(output.begin(), output.end()); }); - EXPECT(migraphx::verify_range(res, gold_res)); - }; - - test_case({"all"}); - test_case({"add"}); -} - -TEST_CASE(clip_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}}); - migraphx::op::clip op; - op.max_val = 6.0; - op.min_val = 0.0; - p.add_instruction(op, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.0, 0.0, 6.0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(reduce_sum_axis0) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_sum{{0}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{15, 18, 21, 24}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_sum_axis1) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_sum{{1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{4, 6, 12, 14, 20, 22}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_sum_axis2) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_sum{{2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{3, 7, 11, 15, 19, 23}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_sum_axis02) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{33, 45}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_sum_axis12) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{10, 26, 42}; - EXPECT(results_vector == gold); -} - -TEST_CASE(rsqrt_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l = p.add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}}); - p.add_instruction(migraphx::op::rsqrt{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {0.5, 0.25, 0.125}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(reduce_mean_axis1) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_mean{{1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{2, 3, 6, 7, 10, 11}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_mean_axis2) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_mean{{2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_mean_axis02) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{5.5, 7.5}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_mean_axis12) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{2.5f, 6.5f, 10.5f}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_mean_int) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{2, 6, 10}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_min_axis1) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_min{{1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1, 2, 5, 6, 9, 10}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_min_axis02) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_min{{0, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1, 3}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_min_axis12) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_min{{1, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{1, 5, 9}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_max_axis0) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_max{{0}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{9, 10, 11, 12}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_max_axis01) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_max{{0, 1}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{11, 12}; - EXPECT(results_vector == gold); -} - -TEST_CASE(reduce_max_axis02) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; - auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; - auto l0 = p.add_literal(input); - p.add_instruction(migraphx::op::reduce_max{{0, 2}}, l0); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold{10, 12}; - EXPECT(results_vector == gold); -} - -TEST_CASE(sqdiff_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto l1 = p.add_literal(migraphx::literal{s, {-1, 0, 1}}); - auto l2 = p.add_literal(migraphx::literal{s, {1, 2, 3}}); - p.add_instruction(migraphx::op::sqdiff{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector(3); - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {4, 4, 4}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(round_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {9}}; - auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}}); - p.add_instruction(migraphx::op::round{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(ceil_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {9}}; - auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}}); - p.add_instruction(migraphx::op::ceil{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {2.0, 2.0, 2.0, -1.0, -1.0, -1.0, 0.0, 2.0, -2.0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(floor_test) -{ - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {9}}; - auto l = p.add_literal(migraphx::literal{s, {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); - p.add_instruction(migraphx::op::floor{}, l); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); - std::vector results_vector; - result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector gold = {1.0, 1.0, 0.0, -2.0, -2.0, -1.0, -0.0, 2.0, -2.0}; - EXPECT(migraphx::verify_range(results_vector, gold)); -} - -TEST_CASE(op_capture) -{ - migraphx::program p; - migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; - std::vector d1(s1.elements()); - std::vector d2(s2.elements()); - std::iota(d1.begin(), d1.end(), 0.0f); - std::iota(d2.begin(), d2.end(), 0.0f); - - auto p1 = p.add_literal(s1, d1); - auto p2 = p.add_literal(s1, d1); - auto pb = p.add_literal(s2, d2); - auto pc = p.add_literal(s2, d2); - auto pa = p.add_instruction(migraphx::op::add{}, p1, p2); - auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc); - p.add_instruction(migraphx::op::dot{}, pa, ps); - - migraphx::program capture_p = p; - migraphx::target t = migraphx::cpu::target{}; - migraphx::capture_arguments(capture_p, t, {"dot"}); - - p.compile(migraphx::cpu::target{}); - capture_p.compile(migraphx::cpu::target{}); - - auto cap_res = capture_p.eval({}); - auto res = p.eval({}); - - std::vector vec; - std::vector cap_vec; - cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); - res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); - - EXPECT(migraphx::verify_range(vec, cap_vec)); -} - -int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/cpu_rnn_ops_test.cpp b/test/cpu_rnn_ops_test.cpp deleted file mode 100644 index 3875a0e4bdf53205ebecbac41ec934cbe8f5cce4..0000000000000000000000000000000000000000 --- a/test/cpu_rnn_ops_test.cpp +++ /dev/null @@ -1,3366 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "test.hpp" - -TEST_CASE(rnn_forward) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - std::vector w_data{0.4691, - 0.3185, - -0.2227, - 0.4423, - -0.0609, - -0.2803, - 0.1744, - 0.3146, - 0.4049, - -0.3973, - -0.0890, - -0.1636}; - - std::vector r_data{-0.0456, - 0.1061, - 0.1574, - -0.4928, - -0.4300, - -0.1909, - -0.0225, - -0.2668, - 0.1840, - -0.4453, - -0.4896, - 0.1302, - -0.0929, - 0.3545, - -0.4981, - 0.0616}; - - std::vector bias_data{ - -0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990}; - std::vector ih_data(num_dirct * batch_size * hidden_size, 0); - std::vector input(seq_len * batch_size * input_size, 0); - input[0] = input[1] = 1.0; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - float clip = 0.0f; - // concatenation of hidden states as program output - { - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.37780784, - 0.61055139, - 0.55168478, - -0.5888475, - -0.37144644, - 0.31708236, - 0.13104209, - -0.18736027, - 0.03445704, - 0.19167931, - -0.3946827, - -0.30889652, - -0.22276389, - 0.44193283, - -0.16477929, - -0.11893477}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // rnn last output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r, - bias, - und, - ih); - - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{0.03445704, - 0.19167931, - -0.3946827, - -0.30889652, - -0.22276389, - 0.44193283, - -0.16477929, - -0.11893477}; - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // multiple rnn_last_output operators - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{0.03445704, - 0.19167931, - -0.3946827, - -0.30889652, - -0.22276389, - 0.44193283, - -0.16477929, - -0.11893477}; - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // 3 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - - auto out_hs = p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{ - 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // seq_len = 1 - { - seq_len = 1; - std::vector input_1(seq_len * batch_size * input_size, 0); - input_1[0] = input_1[1] = 1.0; - migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.37780784, - 0.61055139, - 0.55168478, - -0.5888475, - -0.37144644, - 0.31708236, - 0.13104209, - -0.18736027}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(rnn_reverse) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - std::vector w_data{-0.0296, - -0.1341, - 0.1761, - -0.2325, - -0.0717, - 0.1852, - 0.2720, - 0.1471, - -0.1097, - 0.3363, - -0.0587, - -0.2302}; - std::vector r_data{0.2528, - -0.2333, - 0.3973, - 0.1593, - -0.0388, - 0.1702, - 0.3829, - -0.0712, - -0.1668, - 0.3074, - -0.2854, - 0.4049, - -0.3737, - -0.1051, - 0.4482, - -0.2841}; - std::vector bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552}; - std::vector input(seq_len * batch_size * input_size, 0); - input[0] = input[1] = 1.0; - std::vector ih_data(num_dirct * batch_size * hidden_size, 0); - float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - // concatenation of hidden states as program output - { - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.29385301, - 0.16796815, - 0.51075965, - 0.40258689, - -0.13818839, - 0.44124447, - 0.14365635, - 0.14803654, - -0.0070999, - 0.46251031, - -0.20639211, - 0.37488942, - -0.0070999, - 0.46251031, - -0.20639211, - 0.37488942}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // rnn last output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, - seq, - w, - r, - bias, - und, - ih); - - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{-0.29385301, - 0.16796815, - 0.51075965, - 0.40258689, - -0.13818839, - 0.44124447, - 0.14365635, - 0.14803654}; - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } -} - -TEST_CASE(rnn_bidirectional) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - std::vector w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803, - 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636, - -0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, - 0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302}; - - std::vector r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, - -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, - -0.4981, 0.0616, 0.2528, -0.2333, 0.3973, 0.1593, -0.0388, - 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, - -0.3737, -0.1051, 0.4482, -0.2841}; - - std::vector bias_data{-0.4938, - 0.4355, - -0.3186, - 0.2094, - 0.1037, - -0.1071, - 0.4504, - -0.3990, - -0.3188, - 0.1341, - -0.4446, - 0.1389, - 0.3117, - 0.3664, - 0.2352, - 0.2552}; - - std::vector input(seq_len * batch_size * input_size, 0); - input[0] = input[1] = 1.0; - std::vector ih_data(num_dirct * batch_size * hidden_size, 0); - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - float clip = 0.0f; - // concatenation of hidden state for program output - { - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, - 0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689, - -0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931, - -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, - -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, - -0.20639211, 0.37488942}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last rnn output for program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{0.03445704, - 0.19167931, - -0.3946827, - -0.30889652, - -0.22276389, - 0.44193283, - -0.16477929, - -0.11893477, - -0.29385301, - 0.16796815, - 0.51075965, - 0.40258689, - -0.13818839, - 0.44124447, - 0.14365635, - 0.14803654}; - - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // 4 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - - auto out_hs = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias); - - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{0.03445704, - 0.19167931, - -0.3946827, - -0.30889652, - -0.22276389, - 0.44193283, - -0.16477929, - -0.11893477, - -0.29385301, - 0.16796815, - 0.51075965, - 0.40258689, - -0.13818839, - 0.44124447, - 0.14365635, - 0.14803654}; - - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // 3 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - - auto last_output = p.eval({}); - std::vector last_output_data; - last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); - - std::vector last_output_data_gold{ - 0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0., - -0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0., - 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0.}; - - EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); - } - - // concatenation of hidden state for program output - { - seq_len = 1; - std::vector input_1(seq_len * batch_size * input_size, 0); - input_1[0] = input_1[1] = 1.0; - migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.37780784, - 0.61055139, - 0.55168478, - -0.5888475, - -0.37144644, - 0.31708236, - 0.13104209, - -0.18736027, - -0.16915828, - 0.1938169, - 0.20667936, - 0.58609703, - -0.0070999, - 0.46251031, - -0.20639211, - 0.37488942}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_forward) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, - 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, - -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, - 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, - 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, - -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, - 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, - -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, - -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, - -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, - 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, - 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, - -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, - 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{ - -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - float clip = 0.0f; - // concatenation of hidden states for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - -0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238, -0.6873896, - 0.16144305, 0.31932795, 0.6104771, 0.79759157, -0.31791314, 0.5249062, - 0.08800987, 0.46404213, -0.11872687, -0.26210734, 0.34448293, -0.0176422, - 0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787, - -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.3969709, - 0.43360898, - 0.35775262, - 0.23280787, - -0.52179873, - -0.21944991, - 0.4535257, - -0.13735442, - 0.51757574, - 0.50380427}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // two rnn_last_output operators after gru - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.3969709, - 0.43360898, - 0.35775262, - 0.23280787, - -0.52179873, - -0.21944991, - 0.4535257, - -0.13735442, - 0.51757574, - 0.50380427}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output, linear_before_reset = 0 - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.53291196, - 0.50160867, - 0.39010462, - 0.39292926, - -0.5960838, - -0.38451535, - 0.454239, - -0.10620412, - 0.6014447, - 0.43445644}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_forward_args) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, - 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, - -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, - 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, - 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, - -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, - 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, - -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, - -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, - -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, - 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, - 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, - -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, - 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{ - -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - float clip = 0.0f; - - // 3 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.114674, -0.129581, -0.218156, -0.140788, -0.114242, - -0.346569, 0.321367, -0.0838253, 0.102097, 0.00232137, - -0.149055, 0.0590743, -0.0533094, -0.0446122, -0.112588, - 0.0153261, 0.168883, -0.326836, 0.0843562, 0.160872, - -0.232523, 0.00214573, 0.231693, -0.160475, -0.518952, - 0.0467166, 0.12327, -0.374162, 0.137778, 0.251976}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 4 args (bias is used) - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.273619, 0.0931375, -0.104717, 0.0203752, -0.0797887, - -0.493948, 0.472118, -0.0336318, 0.332706, 0.0182268, - -0.341684, 0.38063, 0.0589275, 0.2644, -0.115737, - -0.152324, 0.442277, -0.201626, 0.408909, 0.12905, - -0.416866, 0.377186, 0.32922, 0.162214, -0.519973, - -0.140072, 0.465076, -0.229563, 0.500164, 0.195166}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 4 args (ih is used) - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - und, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.0801064, 0.27025, -0.20704, 0.333579, -0.0452438, - -0.56265, 0.061061, 0.262172, 0.405193, 0.775226, - -0.100683, 0.258729, -0.0187297, 0.215815, -0.108936, - -0.0941018, 0.129665, -0.159421, 0.190636, 0.597412, - -0.197, 0.0885705, 0.269396, -0.0414511, -0.515137, - -0.03075, 0.158326, -0.296488, 0.177983, 0.519498}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_forward_actv_funcs) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, - 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, - -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, - 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, - 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, - -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, - 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, - -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, - -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, - -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, - 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, - 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, - -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, - 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{ - -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - float clip = 0.0f; - - // no activation function specified, so default is used. - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = p.add_instruction( - migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip, 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.3969709, - 0.43360898, - 0.35775262, - 0.23280787, - -0.52179873, - -0.21944991, - 0.4535257, - -0.13735442, - 0.51757574, - 0.50380427}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 1 activation function (sigmoid) specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.26905832, 0.5669211, 0.20464146, 0.67195725, 0.24752215, - 0.11411376, 0.12353572, 0.4245067, 0.73908687, 0.8644615, - 0.34754312, 0.61424744, 0.36769435, 0.6499579, 0.3168031, - 0.3296533, 0.3055136, 0.42514813, 0.6851256, 0.7967266, - 0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663, - 0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 1 activation function (tanh) specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = p.add_instruction( - migraphx::op::gru{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.49333298, - -0.06104589, - 0.5629142, - -0.97955984, - -0.9314696, - -0.03033514, - 0.5280315, - -0.27354342, - 0.65615714, - 0.53612584}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // seq length of 1 - { - migraphx::program p; - seq_len = 1; - migraphx::shape in_shape_one{migraphx::shape::float_type, - {seq_len, batch_size, input_size}}; - std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; - auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.27298412, - 0.42363745, - -0.09368783, - 0.4823072, - -0.02183238, - -0.6873896, - 0.16144305, - 0.31932795, - 0.6104771, - 0.79759157}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_reverse) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, - 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, - -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, - 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, - 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, - -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, - 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, - -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, - -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, - -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, - 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, - 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, - -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, - 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{ - -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - float clip = 0.0f; - - // concatenation of hidden states for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, - -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, - -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, - -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, - -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, - -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.263403, - 0.317655, - -0.00634162, - 0.200443, - -0.349125, - -0.600874, - 0.542386, - -0.0856531, - 0.55703, - 0.54711}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output, linear_before_reset = 0 - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.388654, - 0.384975, - 0.0179455, - 0.350101, - -0.456872, - -0.690085, - 0.534512, - -0.0558191, - 0.646604, - 0.463943}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // no activation function specified, so default is used. - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction( - migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, - -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, - -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, - -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, - -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, - -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // seq length of 1 - { - migraphx::program p; - seq_len = 1; - migraphx::shape in_shape_one{migraphx::shape::float_type, - {seq_len, batch_size, input_size}}; - std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; - auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.272984, - 0.423637, - -0.0936878, - 0.482307, - -0.0218324, - -0.68739, - 0.161443, - 0.319328, - 0.610477, - 0.797592}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_bidirectional) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, - 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, - 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, - 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, - 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, - - -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, - 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, - -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, - -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, - -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, - 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, - -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, - 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, - 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, - -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, - -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, - 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, - 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, - 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, - -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, - -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, - 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, - 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, - -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, - -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, - -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, - 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, - 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, - 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, - 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, - 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - - float clip = 0.0f; - - // concatenation of hidden states for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, - -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, - -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, - 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, - -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, - 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, - 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, - 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, - 0.079043, 0.322652, 0.752701, 0.243775}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, - -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, - 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, - 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // last output for output, linear_before_reset = 0 - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - -0.09280921, 0.18506107, 0.32247013, 0.17034212, -0.00115255, -0.29865006, -0.04513004, - -0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289, - -0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_bidirectional_args) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, - 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, - 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, - 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, - 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, - - -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, - 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, - -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, - -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, - -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, - 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, - -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, - 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, - 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, - -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, - -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, - 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, - 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, - 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, - -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, - -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, - 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, - 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, - -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, - -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, - -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, - 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, - 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, - 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, - 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, - 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - - float clip = 0.0f; - - // 3 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 0.190748, - -0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298, - -0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838, - 0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119, - -0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983, - 0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435, - 0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234, - 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, - 0.198708, 0.0695644, 0.211621, 0.00246037}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 4 args (bias is used) - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - -0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, 0.192707, 0.253025, - -0.515283, -0.414174, 0.227127, 0.124773, 0.284532, -0.203929, -0.120517, - -0.2794, 0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, -0.219516, - 0.0927382, -0.0760062, -0.0906231, 0.237615, -0.215638, 0.0128074, -0.425813, - -0.433378, 0.375383, -0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649, - 0.419469, 0.393515, -0.076395, 0.427436, -0.264071, -0.185829, 0.0483585, - 0.242955, 0.25233, 0.0148512, -0.304127, -0.0616653, -0.411568, -0.491748, - 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008, - 0.248674, -0.0295413, 0.291437, -0.165005}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 4 args (ih is used) - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - und, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.434692, -0.141696, - -0.164271, -0.121157, 0.863222, -0.0718357, 0.137711, 0.109221, -0.00207995, - 0.0331223, 0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.222779, - 0.179757, -0.0845316, 0.0690347, 0.10204, 0.100155, -0.190286, -0.122062, - -0.274379, 0.547281, -0.226753, -0.0397069, 0.120404, 0.171299, 0.259989, - 0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.237426, 0.0911999, - 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354, - 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917, - -0.0339407, 0.413089, 0.721238, 0.431879}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(gru_bidirectional_actv_funcs) -{ - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; - std::vector w_data{ - 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, - 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, - 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, - 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, - 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, - - -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, - 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, - -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, - -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, - -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; - - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; - std::vector r_data{ - -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, - 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, - -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, - 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, - 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, - -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, - -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, - 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, - 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, - 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, - -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, - -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, - 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, - 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, - -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; - - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - std::vector bias_data{ - -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, - -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, - -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, - 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, - 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, - 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input{-0.8432, - -0.9887, - 1.3041, - -2.6430, - -0.3306, - -0.8504, - -0.3933, - 0.5151, - -0.2951, - 0.0093, - -1.1948, - -0.1239, - 0.0373, - 1.3211, - 0.7854, - -0.4838, - -1.0536, - -0.2529}; - - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, - 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, - 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; - - float clip = 0.0f; - - // no activation function specified, so default is used. - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = p.add_instruction( - migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, - -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, - 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, - 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 1 activation function (sigmoid) specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.457928, 0.065011, 0.35986, - 0.545029, 0.859425, 0.427923, 0.667133, 0.41591, 0.540971, 0.365475, 0.482058, - 0.565495, 0.556993, 0.607649, 0.543627, 0.428915, 0.537405, 0.306046, 0.518399, - 0.403561, 0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.309389, 0.612072, - 0.360619, 0.590861, 0.366545, 0.367001, 0.433829, 0.501275, 0.72481, 0.512745, - 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275, - 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646, - 0.132732, 0.477083, 0.802206, 0.626802}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 1 activation function (tanh) specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.949841, 0.557779, - -0.745259, -1.52726, 0.946066, 0.330446, 0.301982, -0.443763, -0.0655817, - -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956, -0.329758, - 0.458253, -0.339208, 0.289109, 0.36728, -1.09574, -0.181394, -0.575781, - -0.823083, 0.804262, -0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857, - 0.844222, 0.516472, -0.191571, 0.596968, -0.545405, -0.336693, -0.0280516, - 0.339058, 1.00367, 0.12655, -0.0984504, -0.174945, -0.5365, 0.183188, - 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419, - 0.759629, 0.000258222, 0.350835, -0.682684}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 3 activation functions specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto concat_hs = p.add_instruction( - migraphx::op::gru{hidden_size, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.351019, 0.474363, 0.570719, 0.717703, 0.468843, - 1.15142, 0.457633, 0.300962, 0.361245, 0.666199, - 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, - 0.861394, 0.560799, -0.101768, 0.145142, 0.128956}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // 4 activation functions all specified - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, - -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, - -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, - 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, - -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, - 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, - 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, - 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, - 0.079043, 0.322652, 0.752701, 0.243775}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // seq length of 1 - { - migraphx::program p; - seq_len = 1; - migraphx::shape in_shape_one{migraphx::shape::float_type, - {seq_len, batch_size, input_size}}; - std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; - auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 1}, - seq, - w, - r, - bias, - und, - ih); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, - 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, - -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, - -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755}; - - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(lstm_forward) -{ - std::size_t batch_size = 3; - std::size_t seq_len = 4; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - std::vector w_data{ - 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, - 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, - -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, - -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, - -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; - - std::vector r_data{ - 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, - -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, - -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, - -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, - 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, - 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, - -0.2169, -0.1344, 0.3468, -0.2260}; - - std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, - -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, - 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, - 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, - -0.3025, 0.3637, -0.3181, -0.4655}; - - std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - std::vector ih_data{1.9104, - -1.9004, - 0.3337, - 0.5741, - 0.5671, - 0.0458, - 0.4514, - -0.8968, - -0.9201, - 0.1962, - 0.5771, - -0.5332}; - - std::vector ic_data{0.9569, - -0.5981, - 1.1312, - 1.0945, - 1.1055, - -0.1212, - -0.9097, - 0.7831, - -1.6991, - -1.9498, - -1.2567, - -0.4114}; - - std::vector pph_data{1.84369764, - 0.68413646, - -0.44892886, - -1.50904413, - 0.3860796, - -0.52186625, - 1.08474445, - -1.80867321, - 1.32594529, - 0.4336262, - -0.83699064, - 0.49162736}; - - float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - // forward, hidden state concatenation as output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - und); - p.compile(migraphx::cpu::target{}); - - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815, - 0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085, - 0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755, - -0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156, - 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, - 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, - 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // forward, last_output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - - auto last_hs = p.eval({}); - std::vector output_data; - last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - - std::vector output_data_gold{-0.0847427, - 0.0874114, - 0.304256, - -0.0585745, - -0.0223018, - 0.131113, - 0.135643, - -0.0566208, - 0.142701, - 0.0342236, - -0.198664, - 0.0702607}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // forward, last_cell_output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - und); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); - p.compile(migraphx::cpu::target{}); - - auto last_hs = p.eval({}); - std::vector output_data; - last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - - std::vector output_data_gold{-0.111454, - 0.247794, - 0.471087, - -0.220574, - -0.048196, - 0.263184, - 0.283258, - -0.14882, - 0.605585, - 0.078598, - -0.64457, - 0.119811}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } -} - -TEST_CASE(lstm_forward_more) -{ - std::size_t batch_size = 3; - std::size_t seq_len = 4; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - std::vector w_data{ - 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, - 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, - -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, - -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, - -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; - - std::vector r_data{ - 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, - -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, - -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, - -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, - 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, - 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, - -0.2169, -0.1344, 0.3468, -0.2260}; - - std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, - -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, - 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, - 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, - -0.3025, 0.3637, -0.3181, -0.4655}; - - std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - std::vector ih_data{1.9104, - -1.9004, - 0.3337, - 0.5741, - 0.5671, - 0.0458, - 0.4514, - -0.8968, - -0.9201, - 0.1962, - 0.5771, - -0.5332}; - - std::vector ic_data{0.9569, - -0.5981, - 1.1312, - 1.0945, - 1.1055, - -0.1212, - -0.9097, - 0.7831, - -1.6991, - -1.9498, - -1.2567, - -0.4114}; - - std::vector pph_data{1.84369764, - 0.68413646, - -0.44892886, - -1.50904413, - 0.3860796, - -0.52186625, - 1.08474445, - -1.80867321, - 1.32594529, - 0.4336262, - -0.83699064, - 0.49162736}; - - float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - // forward, 3 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - - auto last_hs = p.eval({}); - std::vector output_data; - last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - - std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, - 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048, - 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, - -0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544, - 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, - 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, - 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // forward, 8 args - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.compile(migraphx::cpu::target{}); - - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{ - 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, - 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168, - 0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, - -0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149, - 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, - 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, - 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } - - // seq_len = 1 - { - seq_len = 1; - migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input_data1{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - - auto hs_concat = p.eval({}); - std::vector hs_data; - hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); - - std::vector hs_data_gold{0.079753, - -0.289854, - 0.160043, - 0.115056, - 0.294074, - -0.0319677, - -0.0955337, - 0.104168, - 0.022618, - -0.121195, - -0.4065, - -0.252054}; - EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); - } -} - -TEST_CASE(lstm_reverse) -{ - std::size_t batch_size = 3; - std::size_t seq_len = 4; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - std::vector w_data{ - -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, - -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, - -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, - 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, - -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; - - std::vector r_data{ - -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, - 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, - -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, - 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, - 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, - -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, - 0.3987, -0.1687, -0.0032, -0.1038}; - - std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, - -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, - 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, - -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, - -0.4386, 0.4208, 0.0717, 0.3789}; - - std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - std::vector ih_data{1.5289, - 1.0986, - 0.6091, - 1.6462, - 0.8720, - 0.5349, - -0.1962, - -1.7416, - -0.9912, - 1.2831, - 1.0896, - -0.6959}; - - std::vector ic_data{-0.8323, - 0.3998, - 0.1831, - 0.5938, - 2.7096, - -0.1790, - 0.0022, - -0.8040, - 0.1578, - 0.0567, - 0.8069, - -0.5141}; - - std::vector pph_data{-0.8271, - -0.5683, - 0.4562, - -1.2545, - 1.2729, - -0.4082, - -0.4392, - -0.9406, - 0.7794, - 1.8194, - -0.5811, - 0.2166}; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - float clip = 0.0f; - // reverse, concatenation of hidden states as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, - 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, - 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, - 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, - 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, - 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, - 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - // reverse, 3 args, last cell output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.443077, - -0.325425, - -0.249367, - -0.270812, - 0.122913, - 0.118537, - 0.0370199, - -0.0164687, - -0.00754759, - 0.141613, - 0.348002, - 0.667298}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // reverse, 3 args, 0 actv function - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = p.add_instruction( - migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); - - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.443077, - -0.325425, - -0.249367, - -0.270812, - 0.122913, - 0.118537, - 0.0370199, - -0.0164687, - -0.00754759, - 0.141613, - 0.348002, - 0.667298}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // reverse, 3 args, 1 actv function - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, - 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231, - 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, - 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, - 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, - 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // reverse, 3 args, 2 actv functions - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.132123, - -0.37531, - -0.12943, - -0.00798307, - -0.133882, - -0.0251383, - 0.0486486, - -0.0220606, - 0.292495, - 0.233866, - 0.48646, - 0.481844}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output - { - seq_len = 1; - std::vector input_data1{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; - migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); - - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.104351, - -0.0471426, - -0.0905753, - 0.01506, - 0.059797, - 0.104239, - -0.0266768, - 0.0727547, - -0.146298, - 0.070535, - 0.327809, - 0.407388}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } -} - -TEST_CASE(lstm_bidirectional) -{ - std::size_t batch_size = 3; - std::size_t seq_len = 4; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - std::vector w_data{ - 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, - 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, - -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, - -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, - -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, - -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, - 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, - -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, - 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, - 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; - - std::vector r_data{ - 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, - -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, - -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, - -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, - 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, - 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, - -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, - 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, - 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, - 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, - 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, - -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, - 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; - - std::vector bias_data{ - 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, - -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, - 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, - -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, - 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, - -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, - -0.4386, 0.4208, 0.0717, 0.3789}; - - std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, - 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, - 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, - -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; - - std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, - -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, - -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, - 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; - - std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, - -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, - -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, - -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, - 0.7794, 1.8194, -0.5811, 0.2166}; - float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - // concatenation of hidden states as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, - 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, - 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, - 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, - 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, - -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, - 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, - 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, - 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, - -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, - 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, - 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, - -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, - -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // last hidden state as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, - 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, - 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // last cell output as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); - auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); - auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, - 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, - 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, concatenation of hidden states as program output - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, - 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, - -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, - 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, - 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, - -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, - 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, - -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, - 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, - -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, - 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, - -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, - -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // sequence length is 1, contenation of hidden state as program output - { - migraphx::program p; - seq_len = 1; - migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - std::vector input_data1{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; - auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, - -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, - -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, - -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } -} - -TEST_CASE(lstm_bidirectional_actv_func) -{ - std::size_t batch_size = 3; - std::size_t seq_len = 4; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - std::vector w_data{ - 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, - 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, - -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, - -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, - -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, - -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, - 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, - -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, - 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, - 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; - - std::vector r_data{ - 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, - -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, - -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, - -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, - 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, - 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, - -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, - 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, - 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, - 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, - 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, - -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, - 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; - - std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; - // 3 args, 0 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, - 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, - -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, - 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, - 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, - -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, - 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, - -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, - 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, - -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, - 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, - -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, - -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, 1 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - 0.227861, 0.328562, 0.277867, 0.272945, 0.204389, 0.296123, 0.223834, 0.311113, - 0.424666, 0.173974, 0.40628, 0.286631, 0.246078, 0.199709, 0.303753, 0.301178, - 0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, - 0.339438, 0.29655, 0.331832, 0.242338, 0.409384, 0.236272, 0.306045, 0.26269, - 0.261246, 0.334357, 0.23622, 0.245288, 0.301937, 0.264893, 0.254353, 0.269231, - 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, - 0.374123, 0.283167, 0.377129, 0.245726, 0.444712, 0.203168, 0.411446, 0.269965, - 0.172792, 0.296224, 0.17319, 0.352547, 0.310306, 0.262902, 0.276964, 0.295002, - 0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262, - 0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563, - 0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634, - 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, 2 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, - 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, - 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, 4 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, - 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, - 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, 5 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, - 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, - 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } - - // 3 args, 6 actv func - { - migraphx::program p; - auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); - auto w = p.add_literal(migraphx::literal{w_shape, w_data}); - auto r = p.add_literal(migraphx::literal{r_shape, r_data}); - auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - 0}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, hs); - p.compile(migraphx::cpu::target{}); - auto hs_concat = p.eval({}); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, - 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, - 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; - EXPECT(migraphx::verify_range(output_data, output_data_gold)); - } -} - -int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/dead_code_elimination_test.cpp b/test/dead_code_elimination_test.cpp index 0ee678e486f20b9a43cb9c5e9454dcc61bc21789..bfee9887bb9f26fd432dd8e21a351e840a7d26be 100644 --- a/test/dead_code_elimination_test.cpp +++ b/test/dead_code_elimination_test.cpp @@ -1,9 +1,10 @@ #include #include +#include #include -#include -#include -#include +#include +#include + #include void run_pass(migraphx::program& p) @@ -14,14 +15,14 @@ void run_pass(migraphx::program& p) TEST_CASE(simple_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == count); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -29,15 +30,15 @@ TEST_CASE(simple_test) TEST_CASE(simple_test_nop) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(nop{}); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(nop{}); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == count); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -45,15 +46,15 @@ TEST_CASE(simple_test_nop) TEST_CASE(simple_test_nop2) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(nop{}); - p.add_instruction(sum_op{}, one, two); - p.add_instruction(nop{}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(nop{}); + mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(nop{}); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{}); EXPECT(result != migraphx::literal{4}); } @@ -61,15 +62,15 @@ TEST_CASE(simple_test_nop2) TEST_CASE(duplicate_test1) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -77,16 +78,16 @@ TEST_CASE(duplicate_test1) TEST_CASE(duplicate_test2) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - p.add_instruction(minus_op{}, one, two); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(minus_op{}, one, two); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == (count - 2)); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2)); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -94,18 +95,18 @@ TEST_CASE(duplicate_test2) TEST_CASE(depth_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto x1 = p.add_instruction(sum_op{}, one, two); - auto x2 = p.add_instruction(sum_op{}, one, two); - p.add_instruction(minus_op{}, x1, x2); - p.add_instruction(minus_op{}, x1, x2); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto x1 = mm->add_instruction(sum_op{}, one, two); + auto x2 = mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(minus_op{}, x1, x2); + mm->add_instruction(minus_op{}, x1, x2); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == (count - 4)); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4)); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -113,16 +114,17 @@ TEST_CASE(depth_test) TEST_CASE(undefined_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto undef = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction(sum_op{}, one, two); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction(sum_op{}, one, two); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count - 1); - EXPECT(not p.has_instruction(undef)); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) == count - 1); + EXPECT( + std::none_of(mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "undefined"; })); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -130,52 +132,123 @@ TEST_CASE(undefined_test) TEST_CASE(duplicate_args1) { migraphx::program p; - - auto l0 = p.add_literal(0); - auto l3 = p.add_literal(3); - p.add_instruction(migraphx::op::add{}, l3, l3); - p.add_instruction(migraphx::op::identity{}, l0); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal(0); + auto l3 = mm->add_literal(3); + mm->add_instruction(migraphx::make_op("add"), l3, l3); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) != count); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) != count); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{0}); } TEST_CASE(duplicate_args2) { migraphx::program p; - - auto l0 = p.add_literal(0); - auto l3 = p.add_literal(3); - auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3); - p.add_instruction(migraphx::op::add{}, sum1, l3); - p.add_instruction(migraphx::op::identity{}, l0); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal(0); + auto l3 = mm->add_literal(3); + auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3); + mm->add_instruction(migraphx::make_op("add"), sum1, l3); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) != count); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) != count); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{0}); } TEST_CASE(duplicate_args3) { migraphx::program p; - - auto l0 = p.add_literal(0); - auto l3 = p.add_literal(3); - auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3); - auto sum2 = p.add_instruction(migraphx::op::add{}, l0, sum1); - p.add_instruction(migraphx::op::add{}, sum2, l3); - p.add_instruction(migraphx::op::identity{}, l0); - auto count = std::distance(p.begin(), p.end()); + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal(0); + auto l3 = mm->add_literal(3); + auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3); + auto sum2 = mm->add_instruction(migraphx::make_op("add"), l0, sum1); + mm->add_instruction(migraphx::make_op("add"), sum2, l3); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) != count); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + EXPECT(std::distance(mm->begin(), mm->end()) != count); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{0}); } +TEST_CASE(reused_twice) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 2, 2}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims}); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + auto epsilon = mm->add_literal(1e-12f); + auto exponent = mm->add_literal(2.0f); + + auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2); + auto mean_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast); + auto exponent_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); + auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); + auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow); + auto epsilon_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon); + auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); + mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon); + mm->add_instruction(migraphx::make_op("add"), x, y); + + auto count = std::distance(mm->begin(), mm->end()); + run_pass(p); + p.debug_print(); + EXPECT(std::distance(mm->begin(), mm->end()) != count); + EXPECT(std::distance(mm->begin(), mm->end()) == 4); +} + +TEST_CASE(unused_module) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto* m1 = p.create_module("unused"); + auto* m2 = p.create_module("used"); + auto l0 = mm->add_literal(0); + m1->add_literal(0); + m2->add_literal(0); + mm->add_instruction(mod_pass_op{}, {l0}, {m2}); + EXPECT(migraphx::contains(p.get_modules(), m1)); + EXPECT(migraphx::contains(p.get_modules(), m2)); + run_pass(p); + EXPECT(migraphx::contains(p.get_modules(), m2)); + EXPECT(not migraphx::contains(p.get_modules(), m1)); +} + +TEST_CASE(param_not_eliminated) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {2, 2}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_parameter("z", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({sum}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_program()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/dom.cpp b/test/dom.cpp new file mode 100644 index 0000000000000000000000000000000000000000..765df2bf51c79f032c124fac9064f580c244bc7e --- /dev/null +++ b/test/dom.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include + +TEST_CASE(dom1) +{ + migraphx::module mm; + auto ins1 = mm.add_parameter("entry", {migraphx::shape::float_type}); + auto ins2 = mm.add_instruction(pass_op{}, ins1); + auto ins3 = mm.add_instruction(pass_op{}, ins2); + auto ins4 = mm.add_instruction(pass_op{}, ins2); + auto ins5 = mm.add_instruction(pass_op{}, ins3, ins4); + auto ins6 = mm.add_instruction(pass_op{}, ins2); + + auto dom = migraphx::compute_dominator(mm); + EXPECT(dom.strictly_dominate(ins1, ins2)); + EXPECT(dom.strictly_dominate(ins2, ins3)); + EXPECT(dom.strictly_dominate(ins2, ins4)); + EXPECT(dom.strictly_dominate(ins2, ins5)); + EXPECT(dom.strictly_dominate(ins2, ins6)); + + EXPECT(not dom.strictly_dominate(ins3, ins6)); + EXPECT(not dom.strictly_dominate(ins4, ins6)); + EXPECT(not dom.strictly_dominate(ins3, ins5)); + EXPECT(not dom.strictly_dominate(ins4, ins5)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/dot_apply_alpha_beta_test.cpp b/test/dot_apply_alpha_beta_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..63f4287f3c030d41bfc9e33bfc3ed2a30ba3288e --- /dev/null +++ b/test/dot_apply_alpha_beta_test.cpp @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include + +TEST_CASE(dot_apply_alpha_beta_half) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto dot_res = migraphx::insert_apply_alpha_beta( + m1, m1.end(), {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f); + m1.add_instruction(migraphx::make_op("identity"), dot_res); + } + migraphx::module m2; + { + + auto ht = migraphx::shape::half_type; + auto ft = migraphx::shape::float_type; + auto x = m2.add_parameter("x", migraphx::shape{ht, {2, 2}}); + auto y = m2.add_parameter("y", migraphx::shape{ht, {2, 2}}); + auto z = m2.add_parameter("z", migraphx::shape{ht, {2, 2}}); + auto alpha_literal = m2.add_literal(3.0f); + auto alpha_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), + alpha_literal); + auto x_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), x); + auto x_alpha_float = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_float); + auto x_half = + m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), x_alpha_float); + auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_half, y); + auto beta_literal = m2.add_literal(2.0f); + auto z_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), z); + auto beta_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}), + beta_literal); + auto z_beta_float = m2.add_instruction(migraphx::make_op("mul"), z_float, beta_broadcast); + auto z_beta_half = + m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), z_beta_float); + auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_half); + m2.add_instruction(migraphx::make_op("identity"), z_add); + } + EXPECT(m1 == m2); +} + +TEST_CASE(dot_apply_alpha_beta_double) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 1}}); + auto dot_res = + migraphx::add_apply_alpha_beta(m1, {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f); + m1.add_instruction(migraphx::make_op("identity"), dot_res); + } + migraphx::module m2; + { + + auto dt = migraphx::shape::double_type; + auto x = m2.add_parameter("x", migraphx::shape{dt, {2, 2}}); + auto y = m2.add_parameter("y", migraphx::shape{dt, {2, 2}}); + auto z = m2.add_parameter("z", migraphx::shape{dt, {2, 1}}); + auto alpha_literal = m2.add_literal(3.0f); + auto alpha_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), + alpha_literal); + auto alpha_double = m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), + alpha_broadcast); + auto x_alpha_double = m2.add_instruction(migraphx::make_op("mul"), alpha_double, x); + auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_alpha_double, y); + auto z_broadcast = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), z); + auto beta_literal = m2.add_literal(2.0f); + auto beta_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", z_broadcast->get_shape().lens()}}), + beta_literal); + auto beta_double = + m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), beta_broadcast); + auto z_beta_double = m2.add_instruction(migraphx::make_op("mul"), z_broadcast, beta_double); + auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_double); + m2.add_instruction(migraphx::make_op("identity"), z_add); + } + EXPECT(m1 == m2); +} + +TEST_CASE(quant_dot_apply_alpha_beta) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}}); + auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}}); + auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); + auto dot_res = migraphx::insert_apply_alpha_beta(m1, + m1.end(), + {x, y, z}, + migraphx::make_op("quant_dot"), + migraphx::literal{int32_t{3}}, + migraphx::literal{int32_t{2}}); + m1.add_instruction(migraphx::make_op("identity"), dot_res); + } + migraphx::module m2; + { + + auto i8 = migraphx::shape::int8_type; + auto i32 = migraphx::shape::int32_type; + auto x = m2.add_parameter("x", migraphx::shape{i8, {2, 2}}); + auto y = m2.add_parameter("y", migraphx::shape{i8, {2, 2}}); + auto z = m2.add_parameter("z", migraphx::shape{i32, {2, 2}}); + auto alpha_literal = m2.add_literal(int32_t(3)); + auto alpha_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), + alpha_literal); + auto x_i32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", i32}}), x); + auto x_alpha_i32 = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_i32); + auto x_i8 = + m2.add_instruction(migraphx::make_op("convert", {{"target_type", i8}}), x_alpha_i32); + auto dot_res = m2.add_instruction(migraphx::make_op("quant_dot"), x_i8, y); + auto beta_literal = m2.add_literal(int32_t(2)); + auto beta_broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}), + beta_literal); + auto z_beta_i32 = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); + auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_i32); + m2.add_instruction(migraphx::make_op("identity"), z_add); + } + EXPECT(m1 == m2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_allocation_test.cpp b/test/eliminate_allocation_test.cpp index ab2064790d0468f73c9aabfb9c697c5bdb625de1..da40ac0b8772dddda8b951076f9abc5eb18b4824 100644 --- a/test/eliminate_allocation_test.cpp +++ b/test/eliminate_allocation_test.cpp @@ -6,10 +6,10 @@ #include #include -void run_pass(migraphx::program& p, std::size_t align = 32) +void run_pass(migraphx::module& m, std::size_t align = 32) { migraphx::run_passes( - p, {migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}}); + m, {migraphx::eliminate_allocation{"allocate", align}, migraphx::dead_code_elimination{}}); } struct allocate @@ -25,7 +25,7 @@ struct allocate std::string name() const { return "allocate"; } migraphx::shape compute_shape(const std::vector& inputs) const { - migraphx::check_shapes{inputs}.has(0); + migraphx::check_shapes{inputs, *this}.has(0); return s; } migraphx::argument compute(migraphx::context&, @@ -38,70 +38,74 @@ struct allocate TEST_CASE(basic) { - migraphx::program p; - auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}}); - auto p1 = p.add_instruction(pass_op{}, a1); + migraphx::module m; - auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); + auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {8}}}); + auto m1 = m.add_instruction(pass_op{}, a1); - auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); - p.add_instruction(pass_op{}, a3, p2); + auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {40}}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); - run_pass(p); - EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}}); - EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); + auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); + m.add_instruction(pass_op{}, a3, m2); + + run_pass(m); + EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); + EXPECT(m.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); } TEST_CASE(aligned) { - migraphx::program p; - auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); - auto p1 = p.add_instruction(pass_op{}, a1); + migraphx::module m; + + auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); + auto m1 = m.add_instruction(pass_op{}, a1); - auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); + auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); - auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); - p.add_instruction(pass_op{}, a3, p2); + auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); + m.add_instruction(pass_op{}, a3, m2); - run_pass(p); - EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}}); - EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); + run_pass(m); + EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); + EXPECT(m.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); } TEST_CASE(unaligned) { - migraphx::program p; - auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); - auto p1 = p.add_instruction(pass_op{}, a1); + migraphx::module m; - auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); + auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); + auto m1 = m.add_instruction(pass_op{}, a1); - auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); - p.add_instruction(pass_op{}, a3, p2); + auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); - run_pass(p, 1); - EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}}); - EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); + auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); + m.add_instruction(pass_op{}, a3, m2); + + run_pass(m, 1); + EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); + EXPECT(m.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); } TEST_CASE(float_aligned) { - migraphx::program p; - auto a1 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); - auto p1 = p.add_instruction(pass_op{}, a1); + migraphx::module m; + + auto a1 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1}}}); + auto m1 = m.add_instruction(pass_op{}, a1); - auto a2 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); + auto a2 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2}}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); - auto a3 = p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); - p.add_instruction(pass_op{}, a3, p2); + auto a3 = m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {200}}}); + m.add_instruction(pass_op{}, a3, m2); - run_pass(p, 4); - EXPECT(p.get_shape() == migraphx::shape{migraphx::shape::float_type, {200}}); - EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); + run_pass(m, 4); + EXPECT(m.get_output_shapes().back() == migraphx::shape{migraphx::shape::float_type, {200}}); + EXPECT(m.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_common_subexpression_test.cpp b/test/eliminate_common_subexpression_test.cpp index 2f454382a70d98f2a0ab8bd5d60e95fc9f597654..64f8797cbd35e5cc6cb777c3fe1470760956dc85 100644 --- a/test/eliminate_common_subexpression_test.cpp +++ b/test/eliminate_common_subexpression_test.cpp @@ -1,8 +1,9 @@ #include #include #include -#include #include +#include + #include void run_pass(migraphx::program& p) @@ -11,102 +12,195 @@ void run_pass(migraphx::program& p) p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); } +void run_pass(migraphx::module& m) +{ + migraphx::run_passes( + m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); +} + TEST_CASE(cse_test1) { - migraphx::program p1; + migraphx::module m1; { - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1); - p2.add_instruction(pass_op{}, sum3); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(cse_test2) { - migraphx::program p1; + migraphx::module m1; { - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p2.add_instruction(migraphx::op::add{}, two, one); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); - p2.add_instruction(pass_op{}, sum3); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), two, one); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(cse_test3) { - migraphx::program p1; + migraphx::module m1; { - auto one = p1.add_literal(1); - auto two = p1.add_literal(1); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto one = m1.add_literal(1); + auto two = m1.add_literal(1); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto one = p2.add_literal(1); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum1); - p2.add_instruction(pass_op{}, sum3); + auto one = m2.add_literal(1); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, one); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(cse_test4) { - migraphx::program p1; + migraphx::module m1; + { + auto one = m1.add_literal(1); + auto two = m1.add_literal(1); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, one); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, one); + auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum2, two); + auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum4, sum3); + m1.add_instruction(pass_op{}, sum5); + } + run_pass(m1); + + migraphx::module m2; + { + auto one = m2.add_literal(1); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, one); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, one); + auto sum5 = m2.add_instruction(migraphx::make_op("add"), sum3, sum3); + m2.add_instruction(pass_op{}, sum5); + } + EXPECT(m1 == m2); +} + +TEST_CASE(cse_test_literal) +{ + migraphx::module m1; { - auto one = p1.add_literal(1); - auto two = p1.add_literal(1); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p1.add_instruction(migraphx::op::add{}, two, one); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, one); - auto sum4 = p1.add_instruction(migraphx::op::add{}, sum2, two); - auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3); - p1.add_instruction(pass_op{}, sum5); + auto six1 = m1.add_literal(6); + auto zero1 = m1.add_literal(0); + auto six2 = m1.add_literal(6); + auto zero2 = m1.add_literal(0); + auto six3 = m1.add_literal(6); + auto zero3 = m1.add_literal(0); + + auto sum1 = m1.add_instruction(migraphx::make_op("add"), six1, zero1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), six2, zero2); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), six3, zero3); + auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum3, sum4); + m1.add_instruction(pass_op{}, sum5); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto one = p2.add_literal(1); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, one); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, one); - auto sum5 = p2.add_instruction(migraphx::op::add{}, sum3, sum3); - p2.add_instruction(pass_op{}, sum5); + auto six = m2.add_literal(6); + auto zero = m2.add_literal(0); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), six, zero); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), sum1, sum1); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); +} + +TEST_CASE(cse_test_submodule) +{ + migraphx::shape si{migraphx::shape::int64_type}; + migraphx::shape s{migraphx::shape::int64_type, {1}}; + migraphx::shape sc{migraphx::shape::bool_type}; + + auto create_program = [&](bool remove_literal = false) { + migraphx::program p; + std::vector vc = {true}; + std::vector vd = {3}; + auto* mm = p.get_main_module(); + + auto in_cond = mm->add_parameter("ccond", sc); + auto in_val = mm->add_parameter("val", s); + auto b0 = mm->add_literal(migraphx::literal(sc, vc)); + auto b1 = b0; + if(not(remove_literal)) + b1 = mm->add_literal(migraphx::literal(sc, vc)); + + auto* body1 = p.create_module("loop_module1"); + body1->add_parameter("#loop_module_in_1", sc); + auto in_v1 = body1->add_parameter("#loop_module_in_2", s); + auto l1 = body1->add_literal(migraphx::literal(si, vd)); + auto ad1 = body1->add_instruction(migraphx::make_op("add"), l1, l1); + auto val1 = body1->add_instruction(migraphx::make_op("add"), in_v1, ad1); + auto cond1 = body1->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b0); + auto cond2 = body1->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1); + body1->add_return({cond1, cond2, val1, val1}); + + auto* body2 = p.create_module("loop_module2"); + body2->add_parameter("#loop_module_in_1", sc); + auto in_v2 = body2->add_parameter("#loop_module_in_2", s); + auto l2 = body2->add_literal(migraphx::literal(si, vd)); + auto ad2 = body2->add_instruction(migraphx::make_op("add"), l2, l2); + auto val2 = body2->add_instruction(migraphx::make_op("add"), in_v2, ad2); + auto cond3 = body2->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1); + body2->add_return({cond3, val2, val2}); + + auto loop1 = mm->add_instruction( + migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body1}); + auto loop2 = mm->add_instruction( + migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body2}); + + mm->add_return({loop1, loop2}); + + return p; + }; + auto p = create_program(); + run_pass(p); + EXPECT(p == create_program(true)); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_concat_test.cpp b/test/eliminate_concat_test.cpp index f0a0d445725869b13bb3903beec7beb341bfd307..d6531123788de6c740cd5cea08eb99cdfa06dde7 100644 --- a/test/eliminate_concat_test.cpp +++ b/test/eliminate_concat_test.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include #include @@ -18,10 +20,18 @@ struct concat return migraphx::reflect(self.op, f); } + migraphx::value attributes() const + { + migraphx::value normalize; + normalize["axis"] = migraphx::value::array{migraphx::op::normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + std::string name() const { return "eliminate_concat::concat"; } - migraphx::shape compute_shape(std::vector inputs) const + migraphx::shape normalize_compute_shape(std::vector inputs) const { - return op.compute_shape(std::move(inputs)); + inputs.pop_back(); + return op.normalize_compute_shape(std::move(inputs)); } migraphx::argument compute(migraphx::context&, const migraphx::shape& output_shape, @@ -44,9 +54,9 @@ struct concat_test_optimization } }; -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, + migraphx::run_passes(m, {migraphx::eliminate_concat{concat_test_optimization{}}, migraphx::dead_code_elimination{}}); } @@ -64,7 +74,7 @@ struct allocate std::string name() const { return "allocate"; } migraphx::shape compute_shape(const std::vector& inputs) const { - migraphx::check_shapes{inputs}.has(0); + migraphx::check_shapes{inputs, *this}.has(0); return s; } migraphx::argument compute(migraphx::context&, @@ -80,7 +90,7 @@ struct simple_op std::string name() const { return "simple_op"; } migraphx::shape compute_shape(const std::vector& inputs) const { - migraphx::check_shapes{inputs}.has(1); + migraphx::check_shapes{inputs, *this}.has(1); return inputs.at(0); } migraphx::argument compute(migraphx::context&, @@ -104,164 +114,288 @@ using identity = migraphx::op::identity; TEST_CASE(simple) { auto create_test_program = [] { - migraphx::program p; - auto a1 = p.add_instruction(allocate{create_shape(1)}); - auto p1 = p.add_instruction(simple_op{}, a1); - auto a2 = p.add_instruction(allocate{create_shape(1)}); - auto p2 = p.add_instruction(simple_op{}, a2); + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(1)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(1)}); + auto m2 = m.add_instruction(simple_op{}, a2); std::size_t axis = 0; - auto a3 = p.add_instruction(allocate{create_shape(2)}); - p.add_instruction(concat(axis), p1, p2, a3); - return p; + auto a3 = m.add_instruction(allocate{create_shape(2)}); + m.add_instruction(concat(axis), m1, m2, a3); + return m; + }; + auto create_control_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(2)}); + auto l1 = m.add_instruction(load{create_shape(1), 0}, a1); + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction(load{create_shape(1), 4}, a1); + auto m2 = m.add_instruction(simple_op{}, l2); + m.add_instruction(identity{}, a1, m1, m2); + return m; + }; + + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(negative_axis1) +{ + auto create_test_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(2, 2)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(2, 2)}); + auto m2 = m.add_instruction(simple_op{}, a2); + std::size_t axis = -1; + auto a3 = m.add_instruction(allocate{create_shape(4, 2)}); + m.add_instruction(concat(axis), m1, m2, a3); + return m; + }; + auto create_control_program = create_test_program; + + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(negative_axis2) +{ + auto create_test_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(2, 2)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(2, 2)}); + auto m2 = m.add_instruction(simple_op{}, a2); + std::size_t axis = -2; + auto a3 = m.add_instruction(allocate{create_shape(4, 2)}); + m.add_instruction(concat(axis), m1, m2, a3); + return m; }; auto create_control_program = [] { - migraphx::program p; - auto a1 = p.add_instruction(allocate{create_shape(2)}); - auto l1 = p.add_instruction(load{create_shape(1), 0}, a1); - auto p1 = p.add_instruction(simple_op{}, l1); - auto l2 = p.add_instruction(load{create_shape(1), 4}, a1); - auto p2 = p.add_instruction(simple_op{}, l2); - p.add_instruction(identity{}, a1, p1, p2); - return p; + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(4, 2)}); + auto l1 = m.add_instruction(load{create_shape(2, 2), 0}, a1); + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction(load{create_shape(2, 2), 16}, a1); + auto m2 = m.add_instruction(simple_op{}, l2); + m.add_instruction(identity{}, a1, m1, m2); + return m; + }; + + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(negative_axis3) +{ + auto create_test_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(1, 2, 2)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(1, 2, 2)}); + auto m2 = m.add_instruction(simple_op{}, a2); + std::size_t axis = -2; + auto a3 = m.add_instruction(allocate{create_shape(1, 4, 2)}); + m.add_instruction(concat(axis), m1, m2, a3); + return m; + }; + auto create_control_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(1, 4, 2)}); + auto l1 = m.add_instruction(load{create_shape(1, 2, 2), 0}, a1); + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction(load{create_shape(1, 2, 2), 16}, a1); + auto m2 = m.add_instruction(simple_op{}, l2); + m.add_instruction(identity{}, a1, m1, m2); + return m; + }; + + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(reversed) +{ + auto create_test_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(1)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(1)}); + auto m2 = m.add_instruction(simple_op{}, a2); + std::size_t axis = 0; + auto a3 = m.add_instruction(allocate{create_shape(2)}); + m.add_instruction(concat(axis), m2, m1, a3); + return m; + }; + auto create_control_program = [] { + migraphx::module m; + + auto a1 = m.add_instruction(allocate{create_shape(2)}); + auto l1 = m.add_instruction(load{create_shape(1), 4}, a1); + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction(load{create_shape(1), 0}, a1); + auto m2 = m.add_instruction(simple_op{}, l2); + m.add_instruction(identity{}, a1, m2, m1); + return m; }; - auto p1 = create_test_program(); - auto p2 = create_control_program(); - run_pass(p1); + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(nested) { - auto concat_test_program = [](auto& p) { - auto a1 = p.add_instruction(allocate{create_shape(1)}); - auto p1 = p.add_instruction(simple_op{}, a1); - auto a2 = p.add_instruction(allocate{create_shape(1)}); - auto p2 = p.add_instruction(simple_op{}, a2); + auto concat_test_program = [](auto& m) { + auto a1 = m.add_instruction(allocate{create_shape(1)}); + auto m1 = m.add_instruction(simple_op{}, a1); + auto a2 = m.add_instruction(allocate{create_shape(1)}); + auto m2 = m.add_instruction(simple_op{}, a2); std::size_t axis = 0; - auto a3 = p.add_instruction(allocate{create_shape(2)}); - return p.add_instruction(concat(axis), p1, p2, a3); + auto a3 = m.add_instruction(allocate{create_shape(2)}); + return m.add_instruction(concat(axis), m1, m2, a3); }; auto create_test_program = [&] { - migraphx::program p; - auto concat1 = concat_test_program(p); - auto concat2 = concat_test_program(p); + migraphx::module m; + auto concat1 = concat_test_program(m); + auto concat2 = concat_test_program(m); std::size_t axis = 0; - auto a1 = p.add_instruction(allocate{create_shape(4)}); - p.add_instruction(concat(axis), concat1, concat2, a1); - return p; + auto a1 = m.add_instruction(allocate{create_shape(4)}); + m.add_instruction(concat(axis), concat1, concat2, a1); + return m; }; - auto concat_control_program = [](auto& p, auto a1) { - auto l1 = p.add_instruction(load{create_shape(1), 0}, a1); - auto p1 = p.add_instruction(simple_op{}, l1); - auto l2 = p.add_instruction(load{create_shape(1), 4}, a1); - auto p2 = p.add_instruction(simple_op{}, l2); - return p.add_instruction(identity{}, a1, p1, p2); + auto concat_control_program = [](auto& m, auto a1) { + auto l1 = m.add_instruction(load{create_shape(1), 0}, a1); + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction(load{create_shape(1), 4}, a1); + auto m2 = m.add_instruction(simple_op{}, l2); + return m.add_instruction(identity{}, a1, m1, m2); }; auto create_control_program = [&] { - migraphx::program p; - auto a1 = p.add_instruction(allocate{create_shape(4)}); - auto l1 = p.add_instruction(load{create_shape(2), 0}, a1); - auto concat1 = concat_control_program(p, l1); - auto l2 = p.add_instruction(load{create_shape(2), 8}, a1); - auto concat2 = concat_control_program(p, l2); - p.add_instruction(identity{}, a1, concat1, concat2); - return p; + migraphx::module m; + auto a1 = m.add_instruction(allocate{create_shape(4)}); + auto l1 = m.add_instruction(load{create_shape(2), 0}, a1); + auto concat1 = concat_control_program(m, l1); + auto l2 = m.add_instruction(load{create_shape(2), 8}, a1); + auto concat2 = concat_control_program(m, l2); + m.add_instruction(identity{}, a1, concat1, concat2); + return m; }; - auto p1 = create_test_program(); - auto p2 = create_control_program(); - run_pass(p1); + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(basic) { auto create_test_program = [] { - migraphx::program p; + migraphx::module m; auto a1 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}}); - auto p1 = p.add_instruction(simple_op{}, a1); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}}); + auto m1 = m.add_instruction(simple_op{}, a1); auto a2 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}}); - auto p2 = p.add_instruction(simple_op{}, a2); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}}); + auto m2 = m.add_instruction(simple_op{}, a2); auto a3 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}}); - auto p3 = p.add_instruction(simple_op{}, a3); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}}); + auto p3 = m.add_instruction(simple_op{}, a3); std::size_t axis = 1; - auto a4 = p.add_instruction( + auto a4 = m.add_instruction( allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); - p.add_instruction(concat(axis), p1, p2, p3, a4); - return p; + m.add_instruction(concat(axis), m1, m2, p3, a4); + return m; }; auto create_control_program = [] { - migraphx::program p; - auto a1 = p.add_instruction( + migraphx::module m; + auto a1 = m.add_instruction( allocate{migraphx::shape{migraphx::shape::float_type, {1, 10, 8, 8}}}); - auto l1 = p.add_instruction( + auto l1 = m.add_instruction( load{migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 8}}, 0}, {a1}); - auto p1 = p.add_instruction(simple_op{}, l1); - auto l2 = p.add_instruction( + auto m1 = m.add_instruction(simple_op{}, l1); + auto l2 = m.add_instruction( load{migraphx::shape{migraphx::shape::float_type, {1, 3, 8, 8}}, 512}, {a1}); - auto p2 = p.add_instruction(simple_op{}, l2); - auto l3 = p.add_instruction( + auto m2 = m.add_instruction(simple_op{}, l2); + auto l3 = m.add_instruction( load{migraphx::shape{migraphx::shape::float_type, {1, 5, 8, 8}}, 1280}, {a1}); - auto p3 = p.add_instruction(simple_op{}, l3); - p.add_instruction(identity{}, {a1, p1, p2, p3}); - return p; + auto p3 = m.add_instruction(simple_op{}, l3); + m.add_instruction(identity{}, {a1, m1, m2, p3}); + return m; }; - auto p1 = create_test_program(); - auto p2 = create_control_program(); - run_pass(p1); + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(wont_work) { auto create_test_program = [] { - migraphx::program p; + migraphx::module m; auto a1 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); - auto p1 = p.add_instruction(simple_op{}, a1); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); + auto m1 = m.add_instruction(simple_op{}, a1); auto a2 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); - auto p2 = p.add_instruction(simple_op{}, a2); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); + auto m2 = m.add_instruction(simple_op{}, a2); auto a3 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); - auto p3 = p.add_instruction(simple_op{}, a3); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); + auto p3 = m.add_instruction(simple_op{}, a3); std::size_t axis = 1; - auto a4 = p.add_instruction( + auto a4 = m.add_instruction( allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); - p.add_instruction(concat(axis), p1, p2, p3, a4); - return p; + m.add_instruction(concat(axis), m1, m2, p3, a4); + return m; }; auto create_control_program = [] { - migraphx::program p; + migraphx::module m; auto a1 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); - auto p1 = p.add_instruction(simple_op{}, a1); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 2, 8, 8}}}); + auto m1 = m.add_instruction(simple_op{}, a1); auto a2 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); - auto p2 = p.add_instruction(simple_op{}, a2); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 3, 8, 8}}}); + auto m2 = m.add_instruction(simple_op{}, a2); auto a3 = - p.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); - auto p3 = p.add_instruction(simple_op{}, a3); + m.add_instruction(allocate{migraphx::shape{migraphx::shape::float_type, {2, 5, 8, 8}}}); + auto p3 = m.add_instruction(simple_op{}, a3); std::size_t axis = 1; - auto a4 = p.add_instruction( + auto a4 = m.add_instruction( allocate{migraphx::shape{migraphx::shape::float_type, {2, 10, 8, 8}}}); - p.add_instruction(concat(axis), p1, p2, p3, a4); - return p; + m.add_instruction(concat(axis), m1, m2, p3, a4); + return m; }; - auto p1 = create_test_program(); - auto p2 = create_control_program(); - run_pass(p1); + auto m1 = create_test_program(); + auto m2 = create_control_program(); + run_pass(m1); - EXPECT(p1 == p2); + EXPECT(m1 == m2); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_contiguous_test.cpp b/test/eliminate_contiguous_test.cpp old mode 100644 new mode 100755 index 771f54e486633006fd11ae2575b9a65febba8631..e2b96171e24f0221cdef75735a0e62ff099a69c3 --- a/test/eliminate_contiguous_test.cpp +++ b/test/eliminate_contiguous_test.cpp @@ -1,115 +1,185 @@ #include #include #include -#include -#include -#include -#include -#include -#include +#include #include +#include + +#include #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes( + m, {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}}); } TEST_CASE(standard_op) { - migraphx::program p; - auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - p.add_instruction(pass_standard_op{}, c); - auto count = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count); + migraphx::module m; + + auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(pass_standard_op{}, c); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count); } TEST_CASE(standard_op_const) { - migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - p.add_instruction(pass_standard_op{}, c); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == 2); + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(pass_standard_op{}, c); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == 2); } TEST_CASE(non_standard_op) { - migraphx::program p; - auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - p.add_instruction(pass_op{}, c); - auto count = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count); + migraphx::module m; + + auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(pass_op{}, c); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count); } TEST_CASE(non_standard_op_const) { - migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - p.add_instruction(pass_op{}, c); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == 2); + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(pass_op{}, c); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == 2); } -TEST_CASE(transpose_gemm) +TEST_CASE(transpose_gem) { - migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - auto ic = p.add_instruction(migraphx::op::identity{}, c); - p.add_instruction(migraphx::op::dot{}, ic, l); - auto count = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + auto ic = m.add_instruction(migraphx::make_op("identity"), c); + m.add_instruction(migraphx::make_op("dot"), ic, l); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == (count - 1)); } TEST_CASE(transpose_standard_op) { - migraphx::program p; - auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - auto sn = p.add_instruction(migraphx::op::sin{}, c); - p.add_instruction(pass_standard_op{}, sn); - auto count = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count); + migraphx::module m; + + auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + auto sn = m.add_instruction(migraphx::make_op("sin"), c); + m.add_instruction(pass_standard_op{}, sn); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count); } TEST_CASE(transpose_standard_op_const) { - migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - auto sn = p.add_instruction(migraphx::op::sin{}, c); - p.add_instruction(pass_standard_op{}, sn); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == 3); + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + auto sn = m.add_instruction(migraphx::make_op("sin"), c); + m.add_instruction(pass_standard_op{}, sn); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == 3); } TEST_CASE(no_packed_unary_op) { + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto t = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + auto sn = m.add_instruction(migraphx::make_op("sin"), c); + m.add_instruction(pass_standard_op{}, sn); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count - 1); +} + +TEST_CASE(non_standard_return_input) +{ + migraphx::module m; + + auto l = m.add_literal(get_2x2()); + auto tl = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), tl); + m.add_return({c}); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count); +} + +TEST_CASE(non_standard_flatten_op) +{ + migraphx::module m; + + auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}}); + auto t = m.add_instruction( + migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(migraphx::make_op("flatten"), c); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == count); +} + +TEST_CASE(standard_flatten_op) +{ + migraphx::module m; + + auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 6, 6, 6}}); + auto t = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), l); + auto c = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_instruction(migraphx::make_op("flatten"), c); + auto count = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == (count - 1)); +} + +TEST_CASE(contiguous_pointwise) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 8, 8}}; migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l); - auto c = p.add_instruction(migraphx::op::contiguous{}, t); - auto sn = p.add_instruction(migraphx::op::sin{}, c); - p.add_instruction(pass_standard_op{}, sn); - auto count = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(std::distance(p.begin(), p.end()) == count - 1); + auto* mm = p.get_main_module(); + { + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}}); + auto yb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y); + auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb); + auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add")); + mm->add_instruction(pass_op{}, add); + } + auto count = std::distance(mm->begin(), mm->end()); + run_pass(*mm); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); + EXPECT(std::none_of( + mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; })); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp new file mode 100755 index 0000000000000000000000000000000000000000..884128e9b42fc1691d04c73a78ba3a56df0c4272 --- /dev/null +++ b/test/eliminate_data_type_test.cpp @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include + +#include + +void run_pass(migraphx::module& m, std::set types) +{ + migraphx::run_passes( + m, + {migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type}, + migraphx::eliminate_identity{}, + migraphx::dead_code_elimination{}}); +} + +TEST_CASE(simple) +{ + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + mm1.add_instruction(migraphx::make_op("add"), x, y); + } + run_pass(mm1, {migraphx::shape::int8_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto floatx = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x); + auto floaty = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), y); + auto add = mm2.add_instruction(migraphx::make_op("add"), floatx, floaty); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), add); + } + EXPECT(mm1 == mm2); +} + +TEST_CASE(quant) +{ + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + mm1.add_instruction(migraphx::make_op("quant_dot"), x, y); + } + run_pass(mm1, {migraphx::shape::int8_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto floatx = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x); + auto floaty = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), y); + auto add = mm2.add_instruction(migraphx::make_op("dot"), floatx, floaty); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), add); + } + EXPECT(mm1 == mm2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eliminate_identity_test.cpp b/test/eliminate_identity_test.cpp index 34b89e8935beb199ed6b36d5e9cf7cad7082a1af..b389077b27fbfbf869d52e2024dedc673ea389ff 100644 --- a/test/eliminate_identity_test.cpp +++ b/test/eliminate_identity_test.cpp @@ -3,25 +3,31 @@ #include #include #include -#include +#include + #include -void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::eliminate_identity{}}); } +void run_pass(migraphx::program& p) +{ + migraphx::run_passes(*p.get_main_module(), {migraphx::eliminate_identity{}}); +} TEST_CASE(simple_test) { migraphx::program p; - auto one = p.add_literal(1); - auto one_identity = p.add_instruction(migraphx::op::identity{}, one); - auto two = p.add_literal(2); - auto two_identity = p.add_instruction(migraphx::op::identity{}, two); - p.add_instruction(sum_op{}, one_identity, two_identity); + auto* mm = p.get_main_module(); + + auto one = mm->add_literal(1); + auto one_identity = mm->add_instruction(migraphx::make_op("identity"), one); + auto two = mm->add_literal(2); + auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two); + mm->add_instruction(sum_op{}, one_identity, two_identity); run_pass(p); - EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { + EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { return ins.name() == "identity"; })); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); } @@ -29,15 +35,17 @@ TEST_CASE(simple_test_end) { migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto ans = p.add_instruction(sum_op{}, one, two); - p.add_instruction(migraphx::op::identity{}, ans); + auto* mm = p.get_main_module(); + + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto ans = mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(migraphx::make_op("identity"), ans); run_pass(p); - EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { + EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { return ins.name() == "identity"; })); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); } @@ -45,17 +53,19 @@ TEST_CASE(simple_test_end_dependency) { migraphx::program p; - auto one = p.add_literal(1.0); - auto two = p.add_literal(2.0); - auto three = p.add_literal(3.0); - auto ans = p.add_instruction(sum_op{}, one, two); - p.add_instruction(sum_op{}, ans, three); - p.add_instruction(migraphx::op::identity{}, ans); + auto* mm = p.get_main_module(); + + auto one = mm->add_literal(1.0); + auto two = mm->add_literal(2.0); + auto three = mm->add_literal(3.0); + auto ans = mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(sum_op{}, ans, three); + mm->add_instruction(migraphx::make_op("identity"), ans); run_pass(p); - EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { + EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { return ins.name() == "identity"; })); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3.0}); } diff --git a/test/eliminate_pad_test.cpp b/test/eliminate_pad_test.cpp index 39907c0b552951ceb79e102b4ba96ca4b19275fc..80dc90821ab4eb6d201c01e53c62f30067457d3d 100644 --- a/test/eliminate_pad_test.cpp +++ b/test/eliminate_pad_test.cpp @@ -1,82 +1,109 @@ #include +#include #include #include #include #include #include +#include + #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes( + m, + {migraphx::normalize_ops{}, migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}}); } migraphx::instruction_ref -create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::program& p) +create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::module& m) { size_t f[2] = {1, 1}; std::vector weights(channels * f[0] * f[1]); - migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); - return p.add_instruction(migraphx::op::im2col{}, l_img, l_weights); + auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); + return m.add_instruction(migraphx::make_op("im2col"), l_img, l_weights); } migraphx::instruction_ref create_conv(migraphx::instruction_ref& l_img, size_t channels, - migraphx::program& p, + migraphx::module& m, migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_) { migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; std::vector weights(4 * channels * 3 * 3); - - auto l_weights = p.add_literal(migraphx::literal{s_weights, weights}); + auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); migraphx::op::convolution op; op.padding_mode = padding_mode; - return p.add_instruction(op, l_img, l_weights); + return m.add_instruction(op, l_img, l_weights); } -TEST_CASE(rewrite_test) +TEST_CASE(rewrite_pad) { - migraphx::program p; - + migraphx::module m; size_t img_dim[2] = {2, 2}; size_t channels = 1; std::vector input(channels * img_dim[0] * img_dim[1]); std::iota(input.begin(), input.end(), 0); migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; - auto l_img = p.add_literal(migraphx::literal{s_img, input}); - auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img); + auto l_img = m.add_literal(migraphx::literal{s_img, input}); + auto padded_img = + m.add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), l_img); + + auto l0 = create_im2col(padded_img, channels, m); + auto l1 = create_conv(padded_img, channels, m); + auto l2 = m.add_instruction( + migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img); + m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); - auto l0 = create_im2col(padded_img, channels, p); - auto l1 = create_conv(padded_img, channels, p); - auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img); - p.add_instruction(migraphx::op::identity{}, l0, l1, l2); + auto s0 = l0->get_shape(); + auto s1 = l1->get_shape(); + auto s2 = l2->get_shape(); + run_pass(m); + EXPECT(l0->get_shape() == s0); + EXPECT(l1->get_shape() == s1); + EXPECT(l2->get_shape() == s2); + auto op0 = l0->get_operator().to_value(); + auto om1 = l1->get_operator().to_value(); + auto om2 = l2->get_operator().to_value(); + + EXPECT(op0["padding"].to_vector() == std::vector{1, 1, 1, 1}); + EXPECT(om1["padding"].to_vector() == std::vector{1, 1, 1, 1}); + EXPECT(om2["padding"].to_vector() == std::vector{1, 1, 1, 1}); - run_pass(p); EXPECT(std::none_of( - p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); + m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); } -TEST_CASE(rewrite_test_asymmetric) +TEST_CASE(rewrite_pad_im2col_asymmetric) { - migraphx::program p; + migraphx::module m; + size_t img_dim[2] = {2, 2}; size_t channels = 1; std::vector input(channels * img_dim[0] * img_dim[1]); std::iota(input.begin(), input.end(), 0); migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; - auto l_img = p.add_literal(migraphx::literal{s_img, input}); - auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img); + auto l_img = m.add_literal(migraphx::literal{s_img, input}); + auto padded_img = + m.add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 2, 2}}}), l_img); + + auto l0 = create_im2col(padded_img, channels, m); - create_im2col(padded_img, channels, p); + auto s0 = l0->get_shape(); + run_pass(m); + EXPECT(l0->get_shape() == s0); + auto op0 = l0->get_operator().to_value(); - run_pass(p); - EXPECT(std::any_of( - p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); + EXPECT(op0["padding"].to_vector() == std::vector{0, 0, 2, 2}); + + run_pass(m); + EXPECT(std::none_of( + m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/eval_test.cpp b/test/eval_test.cpp index 9315be6a986e15bed990d9be624cc3d4096c4f4a..f98ca281a96339c92d1db4db76d367499c96beb7 100644 --- a/test/eval_test.cpp +++ b/test/eval_test.cpp @@ -71,7 +71,7 @@ struct reverse_pass { std::string name() const { return "reverse_pass"; } - void apply(migraphx::program& p) const { std::reverse(p.begin(), p.end()); } + void apply(migraphx::module& m) const { std::reverse(m.begin(), m.end()); } }; struct reverse_target @@ -89,17 +89,17 @@ struct invert_pass { std::string name() const { return "invert_pass"; } - void apply(migraphx::program& p) const + void apply(migraphx::module& m) const { - for(auto ins : migraphx::iterator_for(p)) + for(auto ins : migraphx::iterator_for(m)) { if(ins->name() == "sum") { - p.replace_instruction(ins, minus_op{}, ins->inputs()); + m.replace_instruction(ins, minus_op{}, ins->inputs()); } else if(ins->name() == "minus") { - p.replace_instruction(ins, sum_op{}, ins->inputs()); + m.replace_instruction(ins, sum_op{}, ins->inputs()); } } } @@ -130,11 +130,11 @@ struct double_invert_target TEST_CASE(literal_test1) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -142,13 +142,13 @@ TEST_CASE(literal_test1) TEST_CASE(literal_test2) { migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum1 = mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(sum_op{}, sum1, two); - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - p.add_instruction(sum_op{}, sum1, two); - - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{5}); EXPECT(result != migraphx::literal{3}); } @@ -156,10 +156,10 @@ TEST_CASE(literal_test2) TEST_CASE(print_test) { migraphx::program p; - - auto x = p.add_parameter("x", {migraphx::shape::int32_type}); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, x, two); + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, x, two); std::stringstream ss; ss << p; @@ -170,13 +170,14 @@ TEST_CASE(print_test) TEST_CASE(param_test) { migraphx::program p; - - auto x = p.add_parameter("x", {migraphx::shape::int32_type}); - auto y = p.add_parameter("y", {migraphx::shape::int32_type}); - - p.add_instruction(sum_op{}, x, y); - auto result = p.eval( - {{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}}); + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); + auto y = mm->add_parameter("y", {migraphx::shape::int32_type}); + + mm->add_instruction(sum_op{}, x, y); + auto result = p.eval({{"x", migraphx::literal{1}.get_argument()}, + {"y", migraphx::literal{2}.get_argument()}}) + .back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -184,11 +185,11 @@ TEST_CASE(param_test) TEST_CASE(param_error_test) { migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); + auto y = mm->add_parameter("y", {migraphx::shape::int32_type}); - auto x = p.add_parameter("x", {migraphx::shape::int32_type}); - auto y = p.add_parameter("y", {migraphx::shape::int32_type}); - - p.add_instruction(sum_op{}, x, y); + mm->add_instruction(sum_op{}, x, y); EXPECT(test::throws( [&] { p.eval({{"x", migraphx::literal{1}.get_argument()}}); @@ -199,11 +200,11 @@ TEST_CASE(param_error_test) TEST_CASE(param_error_shape_test) { migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); + auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); - auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 1}}); - auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); - - p.add_instruction(sum_op{}, x, y); + mm->add_instruction(sum_op{}, x, y); EXPECT(test::throws( [&] { p.eval({ @@ -217,31 +218,34 @@ TEST_CASE(param_error_shape_test) TEST_CASE(get_param1) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - p.add_instruction(sum_op{}, x, y); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(sum_op{}, x, y); EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("y") == y}); - EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); + EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); } TEST_CASE(get_param2) { migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + EXPECT(bool{p.get_parameter("nonexistent") == mm->end()}); } TEST_CASE(get_param_shapes) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::int32_type, {1, 2}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - p.add_instruction(sum_op{}, x, y); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(sum_op{}, x, y); auto m = p.get_parameter_shapes(); EXPECT(m.count("nonexistent") == 0); EXPECT(m.at("x") == s); @@ -251,14 +255,14 @@ TEST_CASE(get_param_shapes) TEST_CASE(replace_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.replace_instruction(sum, minus_op{}, two, one); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, one, two); + mm->replace_instruction(sum, minus_op{}, two, one); + EXPECT(bool{p.validate() == mm->end()}); + + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{1}); EXPECT(result != migraphx::literal{3}); } @@ -266,15 +270,15 @@ TEST_CASE(replace_test) TEST_CASE(replace_ins_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto minus = p.add_instruction(minus_op{}, two, one); - p.replace_instruction(sum, minus); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, one, two); + auto minus = mm->add_instruction(minus_op{}, two, one); + mm->replace_instruction(sum, minus); + EXPECT(bool{p.validate() == mm->end()}); + + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{1}); EXPECT(result != migraphx::literal{3}); } @@ -282,16 +286,16 @@ TEST_CASE(replace_ins_test) TEST_CASE(replace_ins_test2) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto minus = p.add_instruction(minus_op{}, two, one); - p.add_instruction(pass_op{}, minus); - p.replace_instruction(two, sum); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, one, two); + auto minus = mm->add_instruction(minus_op{}, two, one); + mm->add_instruction(pass_op{}, minus); + mm->replace_instruction(two, sum); + EXPECT(bool{p.validate() == mm->end()}); + + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{2}); EXPECT(result != migraphx::literal{3}); } @@ -299,14 +303,14 @@ TEST_CASE(replace_ins_test2) TEST_CASE(replace_op_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, two, one); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, two, one); sum->replace(minus_op{}); - EXPECT(bool{p.validate() == p.end()}); + EXPECT(bool{p.validate() == mm->end()}); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{1}); EXPECT(result != migraphx::literal{3}); } @@ -314,27 +318,27 @@ TEST_CASE(replace_op_test) TEST_CASE(replace_op_recompute_shape_throw) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, one, two); EXPECT(test::throws([&] { sum->replace(unary_pass_op{}); })); } TEST_CASE(insert_replace_test) { migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum1 = mm->add_instruction(sum_op{}, one, two); + mm->add_instruction(sum_op{}, sum1, two); - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - p.add_instruction(sum_op{}, sum1, two); + auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two); + mm->replace_instruction(sum1, minus_op{}, sum0, two); + EXPECT(bool{p.validate() == mm->end()}); - auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two); - p.replace_instruction(sum1, minus_op{}, sum0, two); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{4}); EXPECT(result != migraphx::literal{5}); } @@ -342,15 +346,15 @@ TEST_CASE(insert_replace_test) TEST_CASE(remove_test1) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto removed = p.add_instruction(minus_op{}, sum, one); - p.remove_instruction(removed); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, one, two); + auto removed = mm->add_instruction(minus_op{}, sum, one); + mm->remove_instruction(removed); + EXPECT(bool{p.validate() == mm->end()}); + + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{1}); } @@ -358,15 +362,15 @@ TEST_CASE(remove_test1) TEST_CASE(remove_test2) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto removed = p.add_instruction(minus_op{}, two, one); - p.add_instruction(sum_op{}, one, two); - p.remove_instruction(removed); - EXPECT(bool{p.validate() == p.end()}); - - auto result = p.eval({}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto removed = mm->add_instruction(minus_op{}, two, one); + mm->add_instruction(sum_op{}, one, two); + mm->remove_instruction(removed); + EXPECT(bool{p.validate() == mm->end()}); + + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{1}); } @@ -374,12 +378,12 @@ TEST_CASE(remove_test2) TEST_CASE(target_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); p.compile(id_target{}); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -387,12 +391,12 @@ TEST_CASE(target_test) TEST_CASE(invert_target_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, two, one); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, two, one); p.compile(invert_target{}); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{1}); EXPECT(result != migraphx::literal{4}); } @@ -400,12 +404,12 @@ TEST_CASE(invert_target_test) TEST_CASE(double_invert_target_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, two, one); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, two, one); p.compile(double_invert_target{}); - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(result == migraphx::literal{3}); EXPECT(result != migraphx::literal{4}); } @@ -413,10 +417,10 @@ TEST_CASE(double_invert_target_test) TEST_CASE(reverse_target_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); EXPECT(test::throws([&] { p.compile(reverse_target{}); })); } @@ -425,28 +429,30 @@ TEST_CASE(reverse_target_test) TEST_CASE(eval_context1) { migraphx::program p; + auto* mm = p.get_main_module(); id_target t{}; EXPECT(is_shared(t.ctx, t.get_context())); - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); p.compile(t); EXPECT(is_shared(t.ctx, p.get_context())); - p.eval({}); + p.eval({}).back(); EXPECT(is_shared(t.ctx, p.get_context())); } TEST_CASE(eval_context2) { migraphx::program p; + auto* mm = p.get_main_module(); id_target t{}; EXPECT(is_shared(t.ctx, t.get_context())); - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(id_ctx_op{}, one, two); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(id_ctx_op{}, one, two); p.compile(t); EXPECT(is_shared(t.ctx, p.get_context())); - p.eval({}); + p.eval({}).back(); // id_ctx_op will modify the context EXPECT(not is_shared(t.ctx, p.get_context())); } @@ -454,16 +460,17 @@ TEST_CASE(eval_context2) TEST_CASE(eval_context3) { migraphx::program p; + auto* mm = p.get_main_module(); id_target t{}; EXPECT(is_shared(t.ctx, t.get_context())); - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(id_ctx_final_op{}, one, two); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(id_ctx_final_op{}, one, two); p.compile(t); // Finalizer will modify the context EXPECT(not is_shared(t.ctx, p.get_context())); auto ctx = p.get_context(); - p.eval({}); + p.eval({}).back(); EXPECT(is_shared(ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context())); } @@ -494,22 +501,24 @@ std::string capture_output(F f) TEST_CASE(debug_print_test) { migraphx::program p; - auto one = p.add_literal(1); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); std::vector onev = {one}; migraphx::program p2; - auto one2 = p2.add_literal(1); + auto* mm2 = p2.get_main_module(); + auto one2 = mm2->add_literal(1); - auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); })); - auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); })); - auto inss_out = migraphx::trim(capture_output([&] { p.debug_print(onev); })); - auto end_out = migraphx::trim(capture_output([&] { p.debug_print(p.end()); })); - auto p2_ins_out = migraphx::trim(capture_output([&] { p.debug_print(one2); })); + auto program_out = migraphx::trim(capture_output([&] { mm->debug_print(); })); + auto ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one); })); + auto inss_out = migraphx::trim(capture_output([&] { mm->debug_print(onev); })); + auto end_out = migraphx::trim(capture_output([&] { mm->debug_print(mm->end()); })); + auto p2_ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one2); })); EXPECT(program_out == ins_out); EXPECT(inss_out == ins_out); EXPECT(end_out == "End instruction"); - EXPECT(p2_ins_out == "Instruction not part of program"); + EXPECT(p2_ins_out == "Instruction not part of module"); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/float_equal.cpp b/test/float_equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..075d371b0c56e59daf5bfed80129457c66fcc3af --- /dev/null +++ b/test/float_equal.cpp @@ -0,0 +1,95 @@ +#include +#include +#include "test.hpp" + +#include + +template +struct float_equal_expression +{ + T lhs; + U rhs; + + operator bool() const { return migraphx::float_equal(lhs, rhs); } + + bool operator not() const { return not bool(*this); } + + friend std::ostream& operator<<(std::ostream& s, const float_equal_expression& self) + { + s << "migraphx::float_equal(" << self.lhs << ", " << self.rhs << ")"; + return s; + } +}; + +template +auto test_float_equal(T x, U y) +{ + return test::make_lhs_expression(float_equal_expression{x, y}); +} + +template +void test_equality() +{ + auto x1 = T(0.1); + auto x2 = U(0.0); + auto x3 = U(1.0); + EXPECT(test_float_equal(x1, x1)); + EXPECT(test_float_equal(x2, x2)); + EXPECT(test_float_equal(x3, x3)); + + EXPECT(not test_float_equal(x1, x2)); + EXPECT(not test_float_equal(x2, x1)); + EXPECT(not test_float_equal(x1, x3)); + EXPECT(not test_float_equal(x3, x1)); + EXPECT(not test_float_equal(x2, x3)); + EXPECT(not test_float_equal(x3, x2)); +} + +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); + +template +void test_limits() +{ + auto max1 = std::numeric_limits::max(); + auto max2 = std::numeric_limits::max(); + + auto min1 = std::numeric_limits::lowest(); + auto min2 = std::numeric_limits::lowest(); + + EXPECT(test_float_equal(max1, max1)); + EXPECT(test_float_equal(max2, max2)); + + EXPECT(not test_float_equal(max1, max2)); + EXPECT(not test_float_equal(max2, max1)); + + EXPECT(test_float_equal(min1, min1)); + EXPECT(test_float_equal(min2, min2)); + + EXPECT(not test_float_equal(min1, min2)); + EXPECT(not test_float_equal(min2, min1)); + + EXPECT(not test_float_equal(max1, min1)); + EXPECT(not test_float_equal(min1, max1)); + EXPECT(not test_float_equal(max2, min2)); + EXPECT(not test_float_equal(min2, max2)); + + EXPECT(not test_float_equal(max1, min2)); + EXPECT(not test_float_equal(min2, max1)); + + EXPECT(not test_float_equal(max2, min1)); + EXPECT(not test_float_equal(min1, max2)); +} + +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp new file mode 100755 index 0000000000000000000000000000000000000000..ce9f27aeed8c8b983c3688540daca4c599daa766 --- /dev/null +++ b/test/fuse_pointwise.cpp @@ -0,0 +1,280 @@ +#include "migraphx/dead_code_elimination.hpp" +#include +#include +#include +#include +#include +#include + +#include +#include + +void run_pass(migraphx::program& p) +{ + migraphx::run_passes(p, {migraphx::fuse_pointwise{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(single) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add")); + mm->add_return({add2}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(double_add) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + mm->add_return({fadd}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(double_add_without_return) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("add"), add1, z); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + mm->add_instruction(migraphx::make_op("identity"), fadd); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(used_twice_not_fused) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, y); + auto add3 = mm->add_instruction(migraphx::make_op("add"), pass, add2); + mm->add_return({add3}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto pass = mm->add_instruction(pass_op{}, add1); + auto fadd = add_pointwise( + p2, "main:pointwise1", {add1, y, pass}, [=](auto* pm, const auto& inputs) { + auto add2 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), inputs[2], add2); + }); + mm->add_return({fadd}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(used_twice_fused) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, x); + auto add3 = mm->add_instruction(migraphx::make_op("add"), add1, y); + auto add4 = mm->add_instruction(migraphx::make_op("add"), add2, add3); + mm->add_return({add4}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fadd = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto add2 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[0]); + auto add3 = pm->add_instruction(migraphx::make_op("add"), add1, inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add2, add3); + }); + mm->add_return({fadd}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(duplicate_inputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, x); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, y); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[0]); + }); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = add_pointwise(p2, "main:pointwise1", {pass, y}, single_pointwise("add")); + mm->add_return({add2}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(scalar_input) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto one = mm->add_literal(1.0f); + auto y = + mm->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), one); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({add1}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) { + auto y = pm->add_literal(1.0f); + return pm->add_instruction(migraphx::make_op("add"), inputs[0], y); + }); + mm->add_return({add1}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(contiguous_input) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto one = mm->add_literal(1.0f); + auto yb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one); + auto y = mm->add_instruction(migraphx::make_op("contiguous"), yb); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({add1}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) { + auto y = pm->add_literal(1.0f); + return pm->add_instruction(migraphx::make_op("add"), inputs[0], y); + }); + mm->add_return({add1}); + } + EXPECT(p1 == p2); +} + +TEST_CASE(all_scalar_input) +{ + migraphx::shape s{migraphx::shape::float_type}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_return({add1}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, [=](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + }); + mm->add_return({add1}); + } + EXPECT(p1.get_output_shapes().size() == 1); + EXPECT(p1.get_output_shapes().front().scalar()); + EXPECT(p1.get_output_shapes() == p2.get_output_shapes()); + EXPECT(p1 == p2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/generate.cpp b/test/generate.cpp index a526e88a51f10c5c5387006c7e80497d69ef67bf..fb651cf067ad28432524261fc4f77968c9b4f5b8 100644 --- a/test/generate.cpp +++ b/test/generate.cpp @@ -8,4 +8,34 @@ TEST_CASE(generate) EXPECT(migraphx::generate_literal(s, 1) != migraphx::generate_argument(s, 0)); } +TEST_CASE(fill_tuple) +{ + migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::bool_type, {3, 2}}; + migraphx::shape s({s0, s1, s2}); + auto arg = migraphx::fill_argument(s, 1); + const auto& args = arg.get_sub_objects(); + EXPECT(args.at(0) == migraphx::fill_argument(s0, 1)); + EXPECT(args.at(1) == migraphx::fill_argument(s1, 1)); + EXPECT(args.at(2) == migraphx::fill_argument(s2, 1)); +} + +TEST_CASE(generate_tuple) +{ + migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::bool_type, {3, 2}}; + migraphx::shape s({s0, s1, s2}); + auto arg = migraphx::generate_argument(s, 1); + const auto& args = arg.get_sub_objects(); + EXPECT(args.at(0) == migraphx::generate_argument(s0, 1)); + EXPECT(args.at(1) == migraphx::generate_argument(s1, 1)); + EXPECT(args.at(2) == migraphx::generate_argument(s2, 1)); + + EXPECT(args.at(0) != migraphx::generate_argument(s0, 0)); + EXPECT(args.at(1) != migraphx::generate_argument(s1, 2)); + EXPECT(args.at(2) != migraphx::generate_argument(s2, 0)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/adjust_allocation.cpp b/test/gpu/adjust_allocation.cpp index 5e0a8a0ae90534cf19f2450ff37722066d0c1abb..46d27933bb72adba0cfcac7ba89d2962e2702176 100644 --- a/test/gpu/adjust_allocation.cpp +++ b/test/gpu/adjust_allocation.cpp @@ -1,41 +1,48 @@ -#include -#include -#include +#include #include -#include +#include +#include +#include #include +#include #include +#include +#include #include #include -#include #include -#include -#include #include +#include +#include +#include #include #include -void run_lowering(migraphx::program& p) +void run_lowering(migraphx::program& p, bool offload_copy = false) { auto ctx = migraphx::gpu::context{}; - migraphx::run_passes(p, - {migraphx::auto_contiguous{}, - migraphx::gpu::lowering{&ctx, false}, - migraphx::dead_code_elimination{}, - migraphx::eliminate_contiguous{}, - migraphx::dead_code_elimination{}}); + migraphx::run_passes( + *p.get_main_module(), + {migraphx::auto_contiguous{}, + migraphx::gpu::lowering{&ctx, offload_copy}, + migraphx::dead_code_elimination{}, + migraphx::eliminate_contiguous{"gpu::contiguous"}, + migraphx::dead_code_elimination{}, + migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}, offload_copy}, + migraphx::dead_code_elimination{}}); } TEST_CASE(tanh_shape) { auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto x = p.add_parameter("x", s); - auto tx = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); - auto txh = p.add_instruction(migraphx::op::tanh{}, tx); - auto sum = p.add_instruction(migraphx::op::add{}, txh, txh); - p.add_instruction(migraphx::op::contiguous{}, sum); + auto x = mm->add_parameter("x", s); + auto tx = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x); + auto txh = mm->add_instruction(migraphx::op::tanh{}, tx); + auto sum = mm->add_instruction(migraphx::op::add{}, txh, txh); + mm->add_instruction(migraphx::op::contiguous{}, sum); return p; }; @@ -49,7 +56,7 @@ TEST_CASE(tanh_shape) EXPECT(p1 == p2); - for(auto ins : iterator_for(p1)) + for(auto ins : iterator_for(*p1.get_main_module())) { if(ins->name() == "hip::allocate") { @@ -59,8 +66,46 @@ TEST_CASE(tanh_shape) } EXPECT(p1 != p2); - migraphx::run_passes(p2, - {migraphx::gpu::adjust_allocation{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(*p2.get_main_module(), + {migraphx::adjust_allocation{migraphx::gpu::gpu_allocation_model{}}, + migraphx::dead_code_elimination{}}); + EXPECT(p1 == p2); +} + +TEST_CASE(no_copy_dead_param) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, x); + mm->add_return({sum}); + + return p; + }; + + auto create_gpu_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_parameter("y", s); + auto xb = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}})); + auto gx = mm->add_instruction(migraphx::make_op("hip::copy_to_gpu"), x, xb); + auto ab = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}})); + auto sum = mm->add_instruction(migraphx::make_op("gpu::add"), gx, gx, ab); + auto r = mm->add_instruction(migraphx::make_op("hip::copy_from_gpu"), sum); + mm->add_return({r}); + + return p; + }; + + auto p1 = create_program(); + auto p2 = create_gpu_program(); + + run_lowering(p1, true); EXPECT(p1 == p2); } diff --git a/test/gpu/context_serialize.cpp b/test/gpu/context_serialize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e38b7504b24a3c27402c8e3f86b6b0198fa5e654 --- /dev/null +++ b/test/gpu/context_serialize.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include +#include +#include "test.hpp" + +TEST_CASE(gpu_context_serialize) +{ + migraphx::context ctx = migraphx::gpu::context{0, 3}; + + auto v = ctx.to_value(); + EXPECT(v.size() == 2); + + EXPECT(v.contains("events")); + EXPECT(v.at("events").without_key().to() == 0); + + EXPECT(v.contains("streams")); + EXPECT(v.at("streams").without_key().to() == 3); + + migraphx::gpu::context g_ctx; + g_ctx.from_value(v); + + auto v1 = g_ctx.to_value(); + EXPECT(v == v1); +} + +TEST_CASE(context_queue) +{ + migraphx::context ctx = migraphx::gpu::context{0, 3}; + EXPECT(ctx.get_queue().get() != nullptr); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1192363f9196dfe7937fccd6803627d9fd47714 --- /dev/null +++ b/test/gpu/jit.cpp @@ -0,0 +1,332 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// NOLINTNEXTLINE +const std::string write_2s = R"__migraphx__( +#include + +extern "C" { +__global__ void write(int8_t* data) +{ + int num = threadIdx.x + blockDim.x * blockIdx.x; + data[num] = 2; +} + +} + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string add_2s_binary = R"__migraphx__( +#include + +extern "C" { +__global__ void add_2(std::int8_t* x, std::int8_t* y) +{ + int num = threadIdx.x + blockDim.x * blockIdx.x; + y[num] = x[num] + 2; +} + +} + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string simple_pointwise_increment = R"__migraphx__( +#include +#include + +using namespace migraphx; + +extern "C" { +__global__ void kernel(void* x, void* y) +{ + make_tensors()(x, y)([](auto xt, auto yt) __device__ { + auto idx = make_index(); + const auto stride = idx.nglobal(); + for(index_int i = idx.global; i < xt.get_shape().elements(); i += stride) + { + yt[i] = xt[i] + 1; + } + }); +} + +} + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string check_define = R"__migraphx__( + +#ifndef __DEFINE__ +#error __DEFINE__ was not defined +#endif + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string unused_param = R"__migraphx__( + +extern "C" { +__global__ void kernel(void* x, void* y) +{} +} + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string incorrect_program = R"__migraphx__( + +extern "C" { +__global__ void kernel(void* x) +{ + x += y; +} +} + +int main() {} + +)__migraphx__"; + +// NOLINTNEXTLINE +const std::string math_template = R"__migraphx__( +#include +#include + +extern "C" { +__global__ void kernel(${type}* p) +{ + auto x = *p; + *p = migraphx::implicit_conversion(migraphx::${invoke}); + +} +} + +int main() {} + +)__migraphx__"; + +migraphx::src_file make_src_file(const std::string& name, const std::string& content) +{ + return {name, std::make_pair(content.data(), content.data() + content.size())}; +} + +TEST_CASE(simple_compile_hip) +{ + auto binaries = migraphx::gpu::compile_hip_src( + {make_src_file("main.cpp", write_2s)}, "", migraphx::gpu::get_device_name()); + EXPECT(binaries.size() == 1); + + migraphx::argument input{{migraphx::shape::int8_type, {5}}}; + auto ginput = migraphx::gpu::to_gpu(input); + migraphx::gpu::kernel k{binaries.front(), "write"}; + k.launch(nullptr, input.get_shape().elements(), 1024)(ginput.cast()); + auto output = migraphx::gpu::from_gpu(ginput); + + EXPECT(output != input); + auto data = output.get(); + EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; })); +} + +auto check_target(const std::string& arch) +{ + auto define = "__" + arch + "__"; + auto content = migraphx::replace_string(check_define, "__DEFINE__", define); + return migraphx::gpu::compile_hip_src({make_src_file("main.cpp", content)}, "", arch); +} + +TEST_CASE(compile_target) +{ + EXPECT(not check_target("gfx900").empty()); + EXPECT(not check_target("gfx906").empty()); +} + +TEST_CASE(compile_errors) +{ + EXPECT(test::throws([&] { + migraphx::gpu::compile_hip_src( + {make_src_file("main.cpp", incorrect_program)}, "", migraphx::gpu::get_device_name()); + })); +} + +TEST_CASE(compile_warnings) +{ + auto compile = [](const std::string& params) { + return migraphx::gpu::compile_hip_src( + {make_src_file("main.cpp", unused_param)}, params, migraphx::gpu::get_device_name()); + }; + + EXPECT(not compile("").empty()); + EXPECT(not compile("-Wunused-parameter -Wno-error").empty()); + EXPECT(not compile("-Wno-unused-parameter -Werror").empty()); + EXPECT(test::throws([&] { compile("-Werror=unused-parameter"); })); + EXPECT(test::throws([&] { compile("-Wunused-parameter -Werror"); })); +} + +TEST_CASE(code_object_hip) +{ + auto binaries = migraphx::gpu::compile_hip_src( + {make_src_file("main.cpp", add_2s_binary)}, "", migraphx::gpu::get_device_name()); + EXPECT(binaries.size() == 1); + + migraphx::shape input{migraphx::shape::int8_type, {5}}; + + std::vector expected_inputs = {input, input}; + auto co = migraphx::make_op("gpu::code_object", + {{"code_object", migraphx::value::binary{binaries.front()}}, + {"symbol_name", "add_2"}, + {"global", input.elements()}, + {"local", 1024}, + {"expected_inputs", migraphx::to_value(expected_inputs)}, + {"output", migraphx::to_value(input)}}); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto input_literal = migraphx::generate_literal(input); + auto output_literal = migraphx::transform(input_literal, [](auto x) { return x + 2; }); + auto x = mm->add_literal(input_literal); + auto y = mm->add_parameter("output", input); + mm->add_instruction(co, x, y); + migraphx::compile_options options; + p.compile(migraphx::gpu::target{}, options); + + auto result = + migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); + + EXPECT(result == output_literal.get_argument()); +} + +TEST_CASE(compile_code_object_hip) +{ + migraphx::shape input{migraphx::shape::float_type, {5, 2}}; + migraphx::gpu::hip_compile_options options; + options.global = 256 * 1024; + options.local = 1024; + options.inputs = {input, input}; + options.output = input; + + auto co = migraphx::gpu::compile_hip_code_object(simple_pointwise_increment, options); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto input_literal = migraphx::generate_literal(input); + auto output_literal = migraphx::transform(input_literal, [](auto x) { return x + 1; }); + auto x = mm->add_literal(input_literal); + auto y = mm->add_parameter("output", input); + mm->add_instruction(co, x, y); + p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); + + auto result = + migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); + + EXPECT(result == output_literal.get_argument()); +} + +TEST_CASE(compile_pointwise) +{ + migraphx::shape input{migraphx::shape::float_type, {5, 2}}; + + migraphx::gpu::context ctx; + auto co = migraphx::gpu::compile_op( + "pointwise", ctx, {input, input}, {{"lambda", "[](auto x) { return x + 1; }"}}); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto input_literal = migraphx::generate_literal(input); + auto output_literal = migraphx::transform(input_literal, [](auto x) { return x + 1; }); + auto x = mm->add_literal(input_literal); + auto y = mm->add_parameter("output", input); + mm->add_instruction(co, x, y); + p.compile(migraphx::gpu::target{}, migraphx::compile_options{}); + + auto result = + migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front()); + + EXPECT(result == output_literal.get_argument()); +} + +TEST_CASE(compile_math) +{ + std::vector math_invoke = { + // clang-format off + "abs(x)", + "acos(x)", + "acosh(x)", + "asin(x)", + "asinh(x)", + "atan(x)", + "atanh(x)", + "ceil(x)", + "cos(x)", + "cosh(x)", + "erf(x)", + "exp(x)", + "floor(x)", + "isnan(x)", + "log(x)", + "max(x, x)", + "min(x, x)", + "pow(x, 0)", + "pow(x, x)", + "round(x)", + "rsqrt(x)", + "sin(x)", + "sinh(x)", + "sqrt(x)", + "tan(x)", + "tanh(x)", + "where(true, x, x)", + // clang-format on + }; + std::vector data_types; + auto vec_sizes = {2, 4, 6}; + for(auto&& t : migraphx::shape::types()) + { + if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + continue; + auto name = migraphx::shape::cpp_type(t); + if(t == migraphx::shape::half_type) + name.insert(0, "migraphx::"); + data_types.push_back(name); + migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { + return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; + }); + } + migraphx::shape input{migraphx::shape::float_type, {5, 2}}; + migraphx::gpu::hip_compile_options options; + options.global = 1024; + options.local = 1024; + options.inputs = {input}; + options.output = input; + migraphx::par_for(math_invoke.size() * data_types.size(), 1, [&](auto i) { + const auto& t = data_types[i % data_types.size()]; + const auto& invoke = math_invoke[i / data_types.size()]; + auto src = migraphx::interpolate_string(math_template, {{"type", t}, {"invoke", invoke}}); + auto co = migraphx::gpu::compile_hip_code_object(src, options); + (void)co; + }); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/literal.cpp b/test/gpu/literal.cpp index 3490ef6ffc72d683fcce6391aa1c1da9c3f8e8d0..cb0b31910af8d2fb635493f9572e81de1faad71c 100644 --- a/test/gpu/literal.cpp +++ b/test/gpu/literal.cpp @@ -9,13 +9,14 @@ void gpu_literal_test() { migraphx::program p; + auto* mm = p.get_main_module(); auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_literal(lit); + mm->add_literal(lit); p.compile(migraphx::gpu::target{}); auto scratch = p.get_parameter("scratch"); - if(scratch == p.end()) + if(scratch == mm->end()) { - auto result = p.eval({}); + auto result = p.eval({}).back(); EXPECT(lit == migraphx::gpu::from_gpu(result)); } else diff --git a/test/gpu/ops_test.cpp b/test/gpu/ops_test.cpp deleted file mode 100644 index 3d3d9d817fbcb8feafd17aa2d9d40c1386754df0..0000000000000000000000000000000000000000 --- a/test/gpu/ops_test.cpp +++ /dev/null @@ -1,4190 +0,0 @@ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -#include - -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wglobal-constructors" -#endif - -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_GPU_COMPILE) - -// An improved async, that doesn't block -template -std::future::type> detach_async(Function&& f, - bool parallel = true) -{ - if(parallel) - { - using result_type = typename std::result_of::type; - std::packaged_task task(std::forward(f)); - auto fut = task.get_future(); - std::thread(std::move(task)).detach(); - return std::move(fut); - } - return std::async(std::launch::deferred, std::forward(f)); -} - -struct auto_print -{ - static void set_terminate_handler(const std::string& name) - { - static std::string pname; - pname = name; - std::set_terminate(+[] { - std::cout << "FAILED: " << pname << std::endl; - try - { - std::rethrow_exception(std::current_exception()); - } - catch(const std::exception& e) - { - std::cout << " what(): " << e.what() << std::endl; - } - std::cout << std::endl; - for(auto&& handle : auto_print::handlers) - handle(); - }); - } - static std::array, 2> handlers; - int index; - template - auto_print(T& x, int i) : index(i) - { - handlers[index] = [&x] { std::cout << x << std::endl; }; - } - - ~auto_print() - { - handlers[index] = [] {}; - } -}; -std::array, 2> auto_print::handlers = {}; - -template -auto get_hash(const T& x) -{ - return std::hash{}(x); -} - -void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false) -{ - auto name = t.name(); - auto s = p.get_shape(); - std::stringstream ss; - migraphx::compile_options options; - options.trace = migraphx::tracer{ss}; - p.compile(t, options); - if(p.get_shape() != s) - { - std::cout << ss.str() << std::endl; - throw std::runtime_error("Compiling program with " + name + " alters its shape"); - } - if(show_trace) - { - std::cout << ss.str() << std::endl; - } -} - -template -migraphx::argument run_cpu(migraphx::program& p) -{ - V v; - p = v.create_program(); - auto_print pp{p, 0}; - compile_check(p, migraphx::cpu::target{}); - migraphx::program::parameter_map m; - for(auto&& x : p.get_parameter_shapes()) - { - m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first)); - } - return p.eval(m); -} - -template -migraphx::argument run_gpu(migraphx::program& p) -{ - V v; - p = v.create_program(); - auto_print pp{p, 1}; - compile_check(p, migraphx::gpu::target{}, migraphx::enabled(MIGRAPHX_TRACE_GPU_COMPILE{})); - migraphx::program::parameter_map m; - for(auto&& x : p.get_parameter_shapes()) - { - m[x.first] = - migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first))); - } - // Program should have an output parameter - EXPECT(bool{m.find("output") != m.end()}); - // Ensure the program doesn't modify the context in a dry run - auto ctx = p.get_context(); - assert(&ctx != &p.get_context()); - EXPECT(is_shared(ctx, p.get_context())); - p.dry_run(m); - EXPECT(is_shared(ctx, p.get_context())); - p.eval(m); - return migraphx::gpu::from_gpu(p.eval(m)); -} - -template -void run_verify_program() -{ - auto_print::set_terminate_handler(migraphx::get_type_name()); - // std::cout << migraphx::get_type_name() << std::endl; - migraphx::program cpu_prog; - migraphx::program gpu_prog; - auto cpu_arg_f = detach_async([&] { return run_cpu(cpu_prog); }); - auto gpu_arg = run_gpu(gpu_prog); - auto cpu_arg = cpu_arg_f.get(); - bool passed = verify_args(migraphx::get_type_name(), cpu_arg, gpu_arg); - if(not passed) - { - V v; - auto p = v.create_program(); - std::cout << p << std::endl; - std::cout << "cpu:\n" << cpu_prog << std::endl; - std::cout << "gpu:\n" << gpu_prog << std::endl; - std::cout << std::endl; - } - std::set_terminate(nullptr); -} - -template -int auto_register_verify_program() -{ - test::add_test_case(migraphx::get_type_name(), [] { run_verify_program(); }); - return 0; -} - -template -struct verify_program -{ - static int static_register; - // This typedef ensures that the static member will be instantiated if - // the class itself is instantiated - using static_register_type = - std::integral_constant; -}; - -template -int verify_program::static_register = auto_register_verify_program(); // NOLINT - -struct test_literals : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = p.add_literal( - generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); - auto weights = p.add_literal( - generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); - auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); - p.add_instruction(migraphx::op::relu{}, conv); - return p; - } -}; - -struct test_add : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - p.add_instruction(migraphx::op::add{}, x, y); - return p; - } -}; - -struct test_add_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::half_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - p.add_instruction(migraphx::op::add{}, x, y); - return p; - } -}; - -struct test_mul : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - p.add_instruction(migraphx::op::mul{}, x, y); - return p; - } -}; - -struct test_exp : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {6}}; - auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s)); - p.add_instruction(migraphx::op::exp{}, x); - return p; - } -}; - -struct test_erf : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - p.add_instruction(migraphx::op::erf{}, param); - return p; - } -}; - -struct test_sqrt : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - auto param_abs = p.add_instruction(migraphx::op::abs{}, param); - p.add_instruction(migraphx::op::sqrt{}, param_abs); - return p; - } -}; - -struct test_sign : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - p.add_instruction(migraphx::op::sign{}, param); - return p; - } -}; - -struct test_log : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {6}}; - auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s)); - p.add_instruction(migraphx::op::log{}, x); - return p; - } -}; - -struct test_pow : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {6}}; - std::vector vec_e(s.elements(), 2.0f); - auto b = p.add_parameter("x", s); - auto e = p.add_literal(migraphx::literal(s, vec_e)); - p.add_instruction(migraphx::op::pow{}, b, e); - return p; - } -}; - -struct test_sin : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {10}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::sin{}, x); - return p; - } -}; - -struct test_cos : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {8}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::cos{}, x); - return p; - } -}; - -struct test_tan : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::tan{}, x); - return p; - } -}; - -struct test_sinh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::sinh{}, x); - return p; - } -}; - -struct test_cosh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::cosh{}, x); - return p; - } -}; - -struct test_tanh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::tanh{}, x); - return p; - } -}; - -struct test_trans_tanh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); - auto tanhx = p.add_instruction(migraphx::op::tanh{}, tx); - auto r = p.add_instruction(migraphx::op::add{}, tanhx, tanhx); - p.add_instruction(migraphx::op::contiguous{}, r); - - return p; - } -}; - -struct test_slice_sin : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); - auto t = p.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, l); - p.add_instruction(migraphx::op::sin{}, t); - - return p; - } -}; - -struct test_asin : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::asin{}, x); - return p; - } -}; - -struct test_acos : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::acos{}, x); - return p; - } -}; - -struct test_atan : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::double_type, {16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::atan{}, x); - return p; - } -}; - -struct test_scale : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", migraphx::shape::float_type); - auto scale = p.add_instruction(migraphx::op::scalar{s.lens()}, y); - p.add_instruction(migraphx::op::mul{}, x, scale); - return p; - } -}; - -struct test_slice : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::int32_type, {2, 2, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", {migraphx::shape::int32_type, {2, 2, 2}}); - auto slice0 = p.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, x); - p.add_instruction(migraphx::op::add{}, y, slice0); - - return p; - } -}; - -struct test_triadd : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", s); - auto sum = p.add_instruction(migraphx::op::add{}, x, y); - p.add_instruction(migraphx::op::add{}, sum, z); - return p; - } -}; - -struct test_triadd2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::shape b{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", b); - auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z); - auto sum = p.add_instruction(migraphx::op::add{}, x, y); - p.add_instruction(migraphx::op::add{}, sum, zb); - return p; - } -}; - -struct test_mul_add : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::shape bs{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto a = p.add_parameter("a", bs); - auto b = p.add_parameter("b", bs); - auto ab = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a); - auto bb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b); - auto mul = p.add_instruction(migraphx::op::mul{}, x, ab); - p.add_instruction(migraphx::op::add{}, mul, bb); - return p; - } -}; - -struct test_add_broadcast : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); - auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y); - p.add_instruction(migraphx::op::add{}, x, by); - return p; - } -}; - -struct test_add_broadcast2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); - auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y); - p.add_instruction(migraphx::op::add{}, x, by); - return p; - } -}; - -struct test_add_broadcast3 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); - auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y); - p.add_instruction(migraphx::op::add{}, x, by); - return p; - } -}; - -struct test_add_broadcast4 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {3}}); - auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y); - p.add_instruction(migraphx::op::add{}, x, by); - return p; - } -}; - -struct test_add_broadcast5 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {4}}); - auto by = p.add_instruction(migraphx::op::broadcast{1, x->get_shape().lens()}, y); - p.add_instruction(migraphx::op::add{}, x, by); - return p; - } -}; - -struct test_triadd_broadcast : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {2, 2}}); - auto z = p.add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); - auto by = p.add_instruction(migraphx::op::broadcast{0, x->get_shape().lens()}, y); - auto sum = p.add_instruction(migraphx::op::add{}, x, by); - p.add_instruction(migraphx::op::add{}, sum, z); - return p; - } -}; - -struct test_sub : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", s); - auto diff = p.add_instruction(migraphx::op::sub{}, x, y); - p.add_instruction(migraphx::op::sub{}, diff, z); - return p; - } -}; - -struct test_sub2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::shape b{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", b); - auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z); - auto diff = p.add_instruction(migraphx::op::sub{}, x, y); - p.add_instruction(migraphx::op::sub{}, diff, zb); - return p; - } -}; - -struct test_div : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", s); - auto diff = p.add_instruction(migraphx::op::div{}, x, y); - p.add_instruction(migraphx::op::div{}, diff, z); - return p; - } -}; - -struct test_div2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - migraphx::shape b{migraphx::shape::float_type, {3}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto z = p.add_parameter("z", b); - auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z); - auto diff = p.add_instruction(migraphx::op::div{}, x, y); - p.add_instruction(migraphx::op::div{}, diff, zb); - return p; - } -}; - -struct test_softmax1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}}); - p.add_instruction(migraphx::op::softmax{0}, x); - return p; - } -}; - -struct test_softmax2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}}); - p.add_instruction(migraphx::op::softmax{}, x); - return p; - } -}; - -template -struct test_softmax : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{T, {512, 4, 1067, 6}}; - auto param = p.add_parameter("0", s); - p.add_instruction(migraphx::op::softmax{Axis}, param); - - return p; - } -}; - -template struct test_softmax<0, migraphx::shape::float_type>; -template struct test_softmax<2, migraphx::shape::float_type>; -template struct test_softmax<1, migraphx::shape::double_type>; -template struct test_softmax<3, migraphx::shape::double_type>; -template struct test_softmax<0, migraphx::shape::half_type>; -template struct test_softmax<1, migraphx::shape::half_type>; -template struct test_softmax<2, migraphx::shape::half_type>; -template struct test_softmax<3, migraphx::shape::half_type>; - -template -struct test_arg_ops : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}}; - auto param = p.add_parameter("data", s); - p.add_instruction(T{Axis}, param); - - return p; - } -}; - -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -struct test_conv : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::convolution{}, input, weights); - return p; - } -}; - -struct test_conv2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}); - p.add_instruction(migraphx::op::convolution{{0, 0}, {1, 1}, {1, 1}}, input, weights); - return p; - } -}; - -struct test_group_conv : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}}); - migraphx::op::convolution op; - op.group = 4; - p.add_instruction(op, input, weights); - return p; - } -}; - -struct test_conv_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); - p.add_instruction(migraphx::op::relu{}, conv); - return p; - } -}; - -struct test_conv_relu_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); - p.add_instruction(migraphx::op::relu{}, conv); - return p; - } -}; - -struct test_conv_bias_clipped_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, - {2.0f, 2.0f, 2.0f, 2.0f}}; - auto bias = p.add_literal(l0); - auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); - auto bcast_add = - p.add_instruction(migraphx::op::broadcast{1, conv->get_shape().lens()}, bias); - auto bias_add = p.add_instruction(migraphx::op::add{}, conv, bcast_add); - p.add_instruction(migraphx::op::clip{6.0f, 0.0f}, bias_add); - return p; - } -}; - -struct test_conv_add : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto w = p.add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = p.add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2)); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(migraphx::op::exp{}, sum); - return p; - } -}; - -struct test_conv_add_1x1_diff_strides : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); - auto w = p.add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = p.add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2)); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(migraphx::op::exp{}, sum); - return p; - } -}; - -struct test_conv_bn_add : verify_program -{ - static migraphx::instruction_ref add_bn(migraphx::program& p, - migraphx::instruction_ref x, - std::size_t channels, - std::size_t seed = 1) - { - migraphx::shape vars{migraphx::shape::float_type, {channels}}; - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); - return p.add_instruction( - migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); - } - - migraphx::program create_program() const - { - migraphx::program p; - std::size_t ichannels = 64; - std::size_t ochannels = 256; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto w = p.add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto v = p.add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); - auto relu1 = p.add_instruction(migraphx::op::relu{}, x); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, relu1, w); - auto bn1 = add_bn(p, conv1, ochannels, 1); - auto relu2 = p.add_instruction(migraphx::op::relu{}, y); - auto conv2 = p.add_instruction(migraphx::op::convolution{}, relu2, v); - auto bn2 = add_bn(p, conv2, ochannels, 1); - auto sum = p.add_instruction(migraphx::op::add{}, bn1, bn2); - p.add_instruction(migraphx::op::relu{}, sum); - return p; - } -}; - -struct test_add_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto add = p.add_instruction(migraphx::op::add{}, x, y); - p.add_instruction(migraphx::op::relu{}, add); - return p; - } -}; - -struct test_add_sigmoid : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto add = p.add_instruction(migraphx::op::add{}, x, y); - p.add_instruction(migraphx::op::sigmoid{}, add); - return p; - } -}; - -struct test_add_tanh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto add = p.add_instruction(migraphx::op::add{}, x, y); - p.add_instruction(migraphx::op::tanh{}, add); - return p; - } -}; - -struct test_triadd_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto sum = p.add_instruction(migraphx::op::add{}, x, y); - auto triadd = p.add_instruction(migraphx::op::add{}, sum, z); - p.add_instruction(migraphx::op::relu{}, triadd); - return p; - } -}; - -struct test_triadd_sigmoid : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto sum = p.add_instruction(migraphx::op::add{}, x, y); - auto triadd = p.add_instruction(migraphx::op::add{}, sum, z); - p.add_instruction(migraphx::op::sigmoid{}, triadd); - return p; - } -}; - -struct test_triadd_tanh : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto sum = p.add_instruction(migraphx::op::add{}, x, y); - auto triadd = p.add_instruction(migraphx::op::add{}, sum, z); - p.add_instruction(migraphx::op::tanh{}, triadd); - return p; - } -}; - -struct test_sigmoid : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::sigmoid{}, x); - return p; - } -}; - -struct test_abs : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::abs{}, x); - return p; - } -}; - -struct test_trans_abs : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto tx = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); - auto absx = p.add_instruction(migraphx::op::abs{}, tx); - auto r = p.add_instruction(migraphx::op::add{}, absx, absx); - p.add_instruction(migraphx::op::contiguous{}, r); - - return p; - } -}; - -struct test_leaky_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::leaky_relu{0.01}, x); - return p; - } -}; - -struct test_elu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::leaky_relu{1.0}, x); - return p; - } -}; - -struct test_relu_lrn : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); - auto y = p.add_instruction(migraphx::op::relu{}, x); - p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y); - return p; - } -}; - -struct test_conv_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}}); - auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto conv = p.add_instruction(migraphx::op::convolution{}, input, weights); - auto pooling = p.add_instruction(migraphx::op::pooling{"max"}, conv); - p.add_instruction(migraphx::op::relu{}, pooling); - return p; - } -}; - -struct test_concat_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}}); - auto transpose = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input); - auto concat = p.add_instruction(migraphx::op::concat{3}, transpose); - auto concat_t = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat); - - auto pooling = - p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t); - p.add_instruction(migraphx::op::relu{}, pooling); - return p; - } -}; - -struct test_global_avg_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto op = migraphx::op::pooling{"average"}; - auto lens = input->get_shape().lens(); - op.lengths = {lens[2], lens[3]}; - p.add_instruction(op, input); - return p; - } -}; - -struct test_global_max_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto op = migraphx::op::pooling{"max"}; - auto lens = input->get_shape().lens(); - op.lengths = {lens[2], lens[3]}; - p.add_instruction(op, input); - return p; - } -}; - -struct test_gemm : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); - p.add_instruction(migraphx::op::dot{}, a, b); - return p; - } -}; - -struct test_gemm_copy : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - auto dr = p.add_instruction(migraphx::op::dot{}, pa, pb, pc); - p.add_instruction(migraphx::op::add{}, dr, dr); - - return p; - } -}; - -struct test_gemm_ex : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); - p.add_instruction(migraphx::op::dot{}, a, b); - return p; - } -}; - -struct test_gemm_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}}); - p.add_instruction(migraphx::op::dot{}, a, b); - return p; - } -}; - -struct test_gemm_ld //: verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = - p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}, {10, 1}}); - auto b = - p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}, {20, 1}}); - p.add_instruction(migraphx::op::dot{}, a, b); - return p; - } -}; - -struct test_gemm_transposeb : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); - auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b); - p.add_instruction(migraphx::op::dot{}, a, bt); - return p; - } -}; - -struct test_gemm_transposeb_ex : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); - auto bt = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, b); - p.add_instruction(migraphx::op::dot{}, a, bt); - return p; - } -}; - -struct test_gemm_transposea : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); - auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a); - p.add_instruction(migraphx::op::dot{}, at, b); - return p; - } -}; - -struct test_gemm_transposea_ex : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); - auto at = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a); - p.add_instruction(migraphx::op::dot{}, at, b); - return p; - } -}; - -struct test_gemm_transposeab : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); - auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); - auto at = p.add_instruction(migraphx::op::transpose{{1, 0}}, a); - auto bt = p.add_instruction(migraphx::op::transpose{{1, 0}}, b); - p.add_instruction(migraphx::op::dot{}, at, bt); - return p; - } -}; - -struct gemm_multi_dim_2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, l1, l2); - - return p; - } -}; - -struct gemm_2args_mm_1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2); - - p.add_instruction(migraphx::op::dot{}, l1, bl2); - - return p; - } -}; - -struct gemm_2args_mm_2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2); - - p.add_instruction(migraphx::op::dot{}, l1, bl2); - - return p; - } -}; - -struct gemm_2args_mm_3 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, bl1, l2); - - return p; - } -}; - -struct gemm_2args_mm_4 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, bl1, l2); - - return p; - } -}; - -struct gemm_2args_mm_5 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, bl1, l2); - - return p; - } -}; - -struct gemm_2args_mm_6 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2); - - p.add_instruction(migraphx::op::dot{}, bl1, bl2); - - return p; - } -}; - -struct gemm_2args_mm_7 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, bl1, l2); - - return p; - } -}; - -struct gemm_multi_dim_2_3 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - - p.add_instruction(migraphx::op::dot{}, l1, l2); - - return p; - } -}; - -struct gemm_2args_vv : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {8}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {8}}; - auto l1 = p.add_parameter("1", m1_shape); - auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); - float alpha = 0.23f; - - auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2); - auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res); - p.add_instruction(migraphx::op::squeeze{{0}}, sres); - - return p; - } -}; - -struct gemm_2args_mv : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); - - p.add_instruction(migraphx::op::dot{}, l1, ul2); - - return p; - } -}; - -struct gemm_2args_bmv : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); - auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2); - - p.add_instruction(migraphx::op::dot{}, l1, bul2); - - return p; - } -}; - -struct gemm_2args_vm : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); - auto l2 = p.add_parameter("2", m2_shape); - - auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2); - p.add_instruction(migraphx::op::squeeze{{0}}, res); - - return p; - } -}; - -struct gemm_2args_vbm : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); - auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1); - - auto l2 = p.add_parameter("2", m2_shape); - - auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2); - p.add_instruction(migraphx::op::squeeze{{2}}, res); - - return p; - } -}; - -struct gemm_multi_3args : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; - - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto l3 = p.add_parameter("3", m3_shape); - float alpha = 0.35; - float beta = 0.41; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - - return p; - } -}; - -struct gemm_multi_3args_c25 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}}; - - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto l3 = p.add_parameter("3", m3_shape); - float alpha = 0.35; - float beta = 0.41; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - - return p; - } -}; - -struct gemm_multi_3args_beta0 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto l3 = p.add_parameter("3", m3_shape); - - float alpha = 1.0f; - float beta = 0.0f; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - - return p; - } -}; - -struct gemm_multi_3args_alpha0 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; - auto l1 = p.add_parameter("1", m1_shape); - auto l2 = p.add_parameter("2", m2_shape); - auto l3 = p.add_parameter("3", m3_shape); - - float alpha = 0.0f; - float beta = 1.0f; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - - return p; - } -}; - -struct quant_dot_3args_1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto l2 = p.add_parameter("b", m2_shape); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3); - return p; - } -}; - -struct quant_dot_3args_2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_parameter("b", m2_shape); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3); - return p; - } -}; - -struct quant_dot_3args_3 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto l2 = p.add_parameter("b", m2_shape); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3); - return p; - } -}; - -struct quant_dot_3args_4 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_parameter("b", m2_shape); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3); - return p; - } -}; - -struct batch_quant_dot_1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1); - auto l2 = p.add_parameter("b", m2_shape); - auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3); - return p; - } -}; - -struct batch_quant_dot_2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; - - auto l1 = p.add_parameter("a", m1_shape); - auto l2 = p.add_parameter("b", m2_shape); - auto l3 = p.add_parameter("c", m3_shape); - p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3); - return p; - } -}; - -struct test_contiguous : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::contiguous{}, x); - EXPECT(p.get_shape().standard()); - return p; - } -}; - -struct test_contiguous_broadcast : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::contiguous{}, x); - EXPECT(p.get_shape().standard()); - return p; - } -}; - -struct test_contiguous_broadcast_transpose : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}}; - auto x = p.add_parameter("x", s); - p.add_instruction(migraphx::op::contiguous{}, x); - EXPECT(p.get_shape().standard()); - return p; - } -}; - -struct test_transpose : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}}; - auto x = p.add_parameter("x", s); - std::vector perm = {0, 2, 3, 1}; - auto l = p.add_instruction(migraphx::op::transpose{perm}, x); - p.add_instruction(migraphx::op::contiguous{}, l); - return p; - } -}; - -struct test_batchnorm_inference_2 : verify_program -{ - const size_t width = 14; - const size_t height = 14; - const size_t channels = 256; - const size_t batches = 1; - - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; - migraphx::shape vars{migraphx::shape::float_type, {channels}}; - auto x = p.add_parameter("x", s); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); - return p; - } -}; - -struct test_batchnorm_inference : verify_program -{ - const size_t width = 3; - const size_t height = 3; - const size_t channels = 3; - const size_t batches = 4; - - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; - migraphx::shape vars{migraphx::shape::float_type, {channels}}; - auto x = p.add_parameter("x", s); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); - return p; - } -}; - -struct test_clip : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); - p.add_instruction(migraphx::op::clip{6.0, 0.0}, x); - return p; - } -}; - -struct test_conv_bn : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; - auto x = p.add_parameter("x", xs); - auto w = p.add_parameter("w", ws); - auto conv = p.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); - return p; - } -}; - -struct test_conv_bn_relu_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; - auto x = p.add_parameter("x", xs); - auto w = p.add_parameter("w", ws); - auto conv = p.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto bn = p.add_instruction( - migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); - auto relu = p.add_instruction(migraphx::op::relu{}, bn); - p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); - return p; - } -}; - -struct quant_conv : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - auto pa = p.add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - auto pc = p.add_parameter("c", c_shape); - p.add_instruction(migraphx::op::quant_convolution{}, pa, pc); - return p; - } -}; - -struct quant_conv_default_mode : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - auto pa = p.add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - auto pc = p.add_parameter("c", c_shape); - p.add_instruction( - migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, - pa, - pc); - return p; - } -}; - -struct quant_conv_valid_mode : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - auto pa = p.add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - auto pc = p.add_parameter("c", c_shape); - p.add_instruction( - migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, - pa, - pc); - return p; - } -}; - -struct quant_conv_padding : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - auto pa = p.add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - auto pc = p.add_parameter("c", c_shape); - p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc); - return p; - } -}; - -struct quant_conv_padding_stride : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; - auto pa = p.add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; - auto pc = p.add_parameter("c", c_shape); - p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc); - - return p; - } -}; - -struct test_concat_axis_1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 1; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; - auto l0 = p.add_parameter("x", s0); - auto l1 = p.add_parameter("y", s1); - auto l2 = p.add_parameter("z", s2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_axis_neg_1 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = -1; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; - auto l0 = p.add_parameter("x", s0); - auto l1 = p.add_parameter("y", s1); - auto l2 = p.add_parameter("z", s2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_axis_0 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 0; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; - auto l0 = p.add_parameter("x", s0); - auto l1 = p.add_parameter("y", s1); - auto l2 = p.add_parameter("z", s2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_transpose : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 1; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; - auto l0 = p.add_parameter("x", s0); - auto lp1 = p.add_parameter("y", s1); - auto l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp1); - auto l2 = p.add_parameter("z", s2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_transpose2 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 1; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; - migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; - auto l0 = p.add_parameter("x", s0); - auto l1 = p.add_parameter("y", s1); - auto lp2 = p.add_parameter("z", s2); - auto l2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_transpose3 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 1; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; - auto l0 = p.add_parameter("x", s0); - auto lp1 = p.add_parameter("y", s1); - auto l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp1); - auto lp2 = p.add_parameter("z", s2); - auto l2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp2); - p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); - return p; - } -}; - -struct test_concat_relu : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - int axis = 0; - migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 2}}; - auto l0 = p.add_parameter("x", s0); - auto l1 = p.add_parameter("y", s1); - auto l2 = p.add_parameter("z", s2); - auto r0 = p.add_instruction(migraphx::op::relu{}, l0); - auto r1 = p.add_instruction(migraphx::op::relu{}, l1); - auto r2 = p.add_instruction(migraphx::op::relu{}, l2); - auto c0 = p.add_instruction(migraphx::op::concat{axis}, r0, r1, r2); - p.add_instruction(migraphx::op::relu{}, c0); - return p; - } -}; - -struct test_pad : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}}; - std::vector pads0 = {0, 0, 0, 0, 0, 0, 1, 1}; - std::vector pads1 = {0, 0, 0, 0, 1, 1, 1, 1}; - std::vector pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; - std::vector pads3 = {1, 0, 1, 0, 1, 0, 2, 0}; - auto l0 = p.add_parameter("x", s0); - p.add_instruction(migraphx::op::pad{pads0}, l0); - p.add_instruction(migraphx::op::pad{pads1}, l0); - p.add_instruction(migraphx::op::pad{pads2}, l0); - p.add_instruction(migraphx::op::pad{pads3}, l0); - return p; - } -}; - -struct test_pad_int8 : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - std::vector data0 = {0, 1, 2, 3}; - migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - migraphx::op::pad op{}; - op.value = std::numeric_limits::lowest(); - op.pads = {0, 0, 1, 1}; - p.add_instruction(op, l0); - return p; - } -}; - -struct test_pooling_autopad : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}}; - auto l0 = p.add_parameter("x", s0); - migraphx::op::pooling op{"max"}; - op.padding_mode = migraphx::op::padding_mode_t::same; - op.lengths = {2, 2}; - op.stride = {2, 2}; - p.add_instruction(op, l0); - return p; - } -}; - -struct test_gather : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; - std::vector indices{1, 2, 2, 1}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -struct test_gather_neg_axis : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; - std::vector indices{1, 2, 2, 1}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -struct test_gather_neg_indices : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; - std::vector indices{-2, -1, -1, -2}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -struct test_gather_scalar_output : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3}}; - migraphx::shape s_indices{migraphx::shape::int32_type}; - std::vector indices{1}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -struct test_gather_scalar_index : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s_indices{migraphx::shape::int32_type}; - std::vector indices{1}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -struct test_gather_1d_index : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - migraphx::shape s_indices{migraphx::shape::int32_type, {1}}; - std::vector indices{1}; - auto a0 = p.add_parameter("data", s); - auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); - int axis = -1; - p.add_instruction(migraphx::op::gather{axis}, a0, a1); - return p; - } -}; - -void manual_identity() -{ - migraphx::program p; - std::vector data0 = {0, 1, 2, 3}; - migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - p.add_instruction(migraphx::op::identity{}, l0); - p.compile(migraphx::gpu::target{}); - migraphx::program::parameter_map m; - for(auto&& x : p.get_parameter_shapes()) - { - m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second)); - } - auto result = migraphx::gpu::from_gpu(p.eval(m)); - std::cout << result << std::endl; -} - -void manual_test_concat_relu() -{ - migraphx::program p; - int axis = 0; - std::vector data0 = {0, 1, 2, 3}; - std::vector data1 = {4, 5, 6, 7, 8, 9}; - std::vector data2 = {10, 11}; - migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 2}}; - auto l0 = p.add_literal(migraphx::literal{s0, data0}); - auto l1 = p.add_literal(migraphx::literal{s1, data1}); - auto l2 = p.add_literal(migraphx::literal{s2, data2}); - auto r0 = p.add_instruction(migraphx::op::relu{}, l0); - auto r1 = p.add_instruction(migraphx::op::relu{}, l1); - auto r2 = p.add_instruction(migraphx::op::relu{}, l2); - auto c0 = p.add_instruction(migraphx::op::concat{axis}, r0, r1, r2); - p.add_instruction(migraphx::op::relu{}, c0); - - p.compile(migraphx::gpu::target{}); - migraphx::program::parameter_map m; - for(auto&& x : p.get_parameter_shapes()) - { - m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second)); - } - auto result = migraphx::gpu::from_gpu(p.eval(m)); - std::cout << result << std::endl; -} - -struct test_conv_bn_relu_pooling2 : verify_program -{ - static migraphx::instruction_ref - add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels) - { - migraphx::shape vars{migraphx::shape::float_type, {channels}}; - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + channels))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + channels))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + channels))); - auto variance = - p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + channels))); - return p.add_instruction( - migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); - } - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; - migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; - migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; - migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; - auto x1 = p.add_parameter("x1", xs1); - auto w1 = p.add_parameter("w1", ws1); - auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1); - auto bn1 = add_bn(p, conv1, 2048); - auto x2 = p.add_parameter("x2", xs2); - auto w2 = p.add_parameter("w2", ws2); - auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2); - auto bn2 = add_bn(p, conv2, 2048); - auto add = p.add_instruction(migraphx::op::add{}, bn1, bn2); - auto relu = p.add_instruction(migraphx::op::relu{}, add); - p.add_instruction(migraphx::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); - return p; - } -}; - -struct test_rnn_forward : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_rnn_forward10 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 10; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_rnn_reverse : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_rnn_reverse2 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_rnn_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_rnn_4args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 5; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias); - - return p; - } -}; - -struct test_rnn_5args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 10; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_rnn_bidirectional : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_rnn_bidirectional10 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 10; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_rnn_bi_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 10; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto output = - p.add_instruction(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_gru_forward_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_gru_forward_hs : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_gru_forward_3args_und : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - und, - und, - und); - - return p; - } -}; - -struct test_gru_forward_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_forward_seq1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_forward_default_actv : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_forward_default_actv1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::gru{ - hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_gru_reverse_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_gru_reverse_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_bidirct_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_gru_bidirct_hs : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_gru_bidirct_3args_und : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - und, - und, - und); - - return p; - } -}; - -struct test_gru_bidirct_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_bidirct_seq1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_bidirct_default_actv : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_gru_bidirct_default_actv1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_lstm_forward_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto ic = p.add_parameter("ic", ic_shape); - auto pph = p.add_parameter("pph", pph_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_lstm_forward_hs : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto ic = p.add_parameter("ic", ic_shape); - auto pph = p.add_parameter("pph", pph_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - - return p; - } -}; - -struct test_lstm_forward_3args_und : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - und, - und, - und, - und, - und); - - return p; - } -}; - -struct test_lstm_forward_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_forward_seq1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_forward_default_actv : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_forward_default_actv1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction( - migraphx::op::lstm{ - hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_lstm_reverse_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto ic = p.add_parameter("ic", ic_shape); - auto pph = p.add_parameter("pph", pph_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_lstm_reverse_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_reverse_3args_cell_output : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto hs = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); - - return p; - } -}; - -struct test_lstm_bidirct_last : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto ic = p.add_parameter("ic", ic_shape); - auto pph = p.add_parameter("pph", pph_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto output = p.add_instruction( - migraphx::op::lstm{ - hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, output); - - return p; - } -}; - -struct test_lstm_bidirct_hs : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_lstm_bidirct_3args_und : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - p.add_instruction( - migraphx::op::gru{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - und, - und, - und, - und, - und); - - return p; - } -}; - -struct test_lstm_bidirct_3args : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_bidirct_seq1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_bidirct_default_actv : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 1; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - p.add_instruction( - migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, - seq, - w, - r); - - return p; - } -}; - -struct test_lstm_bidirct_default_actv1 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -struct test_lstm_bidirct_default_actv2 : verify_program -{ - migraphx::program create_program() const - { - std::size_t batch_size = 2; - std::size_t seq_len = 3; - std::size_t hidden_size = 5; - std::size_t input_size = 8; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::program p; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 4 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - auto seq = p.add_parameter("seq", in_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", b_shape); - auto ih = p.add_parameter("ih", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - p.add_instruction(migraphx::op::lstm{hidden_size, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - und, - ih); - - return p; - } -}; - -template -struct test_logsoftmax : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{T, {10, 4, 2080, 6}}; - auto param = p.add_parameter("0", s); - p.add_instruction(migraphx::op::logsoftmax{Axis}, param); - - return p; - } -}; - -template struct test_logsoftmax<0, migraphx::shape::float_type>; -template struct test_logsoftmax<1, migraphx::shape::float_type>; -template struct test_logsoftmax<2, migraphx::shape::float_type>; -template struct test_logsoftmax<3, migraphx::shape::float_type>; -template struct test_logsoftmax<1, migraphx::shape::double_type>; -template struct test_logsoftmax<3, migraphx::shape::double_type>; -template struct test_logsoftmax<1, migraphx::shape::half_type>; -template struct test_logsoftmax<0, migraphx::shape::half_type>; -template struct test_logsoftmax<2, migraphx::shape::half_type>; -template struct test_logsoftmax<3, migraphx::shape::half_type>; - -struct test_fp32_fp16_lall : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - std::vector data(2 * 3); - std::iota(data.begin(), data.end(), 1.0f); - auto l1 = p.add_literal(migraphx::literal(s, data)); - auto l2 = p.add_parameter("p2", s); - p.add_instruction(migraphx::op::add{}, l1, l2); - migraphx::quantize_fp16(p, {"all"}); - return p; - }; -}; - -struct test_fp32_fp16_ladd : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - std::vector data(2 * 3); - std::iota(data.begin(), data.end(), 1.0f); - auto l1 = p.add_literal(migraphx::literal(s, data)); - auto l2 = p.add_parameter("p2", s); - p.add_instruction(migraphx::op::add{}, l1, l2); - migraphx::quantize_fp16(p, {"add"}); - return p; - }; -}; - -struct test_fp32_fp16_add : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); - auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); - p.add_instruction(migraphx::op::add{}, diff, p1); - migraphx::quantize_fp16(p, {"add"}); - - return p; - }; -}; - -struct test_fp32_fp16_sub : verify_program -{ - migraphx::program create_program() - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); - auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); - p.add_instruction(migraphx::op::add{}, diff, p1); - migraphx::quantize_fp16(p, {"sub"}); - - return p; - }; -}; - -template -struct test_reduce_op_large : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{T, {3, 1026, 4, 3}}; - auto x = p.add_parameter("x", s); - p.add_instruction(Op{{1}}, x); - return p; - }; -}; - -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; - -template -struct test_reduce_op_small : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{T, {3, 4, 8, 8}}; - auto x = p.add_parameter("x", s); - p.add_instruction(Op{{1}}, x); - return p; - }; -}; -template struct test_reduce_op_small; -template struct test_reduce_op_small; -template struct test_reduce_op_small; -template struct test_reduce_op_small; - -template struct test_reduce_op_small; -template struct test_reduce_op_small; -template struct test_reduce_op_small; -template struct test_reduce_op_small; - -struct test_rsqrt : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}}; - auto x = p.add_parameter("x", s); - auto l0 = p.add_instruction(migraphx::op::clip{std::numeric_limits::max(), 1.0}, x); - p.add_instruction(migraphx::op::rsqrt{}, l0); - return p; - }; -}; - -struct test_round : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - p.add_instruction(migraphx::op::round{}, param); - return p; - }; -}; - -struct test_ceil : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - p.add_instruction(migraphx::op::ceil{}, param); - return p; - }; -}; - -struct test_floor : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; - auto param = p.add_parameter("x", s); - p.add_instruction(migraphx::op::floor{}, param); - return p; - }; -}; - -struct test_convert : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - migraphx::shape sa{migraphx::shape::float_type, {8, 24}}; - migraphx::shape sb{migraphx::shape::float_type, {24, 6}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa); - auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb); - p.add_instruction(migraphx::op::quant_dot{}, ia, ib); - - return p; - }; -}; - -int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/pack_args.cpp b/test/gpu/pack_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b32723aa289f3d9777698580e35e20ad573d66a5 --- /dev/null +++ b/test/gpu/pack_args.cpp @@ -0,0 +1,43 @@ +#include +#include + +template +std::size_t packed_sizes() +{ + return sizeof(T); +} + +template +std::size_t packed_sizes() +{ + return sizeof(T) + packed_sizes(); +} + +template +std::size_t sizes() +{ + return migraphx::gpu::pack_args({Ts{}...}).size(); +} + +template +std::size_t padding() +{ + EXPECT(sizes() >= packed_sizes()); + return sizes() - packed_sizes(); +} + +struct float_struct +{ + float x, y; +}; + +TEST_CASE(alignment_padding) +{ + EXPECT(padding() == 0); + EXPECT(padding() == 0); + EXPECT(padding() == 2); + EXPECT(padding() == 2); + EXPECT(padding() == 1); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/pack_int8_args.cpp b/test/gpu/pack_int8_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..585d6ca3bb774ced032dba2fb2d1bc1deaf1a29e --- /dev/null +++ b/test/gpu/pack_int8_args.cpp @@ -0,0 +1,433 @@ +#include "migraphx/instruction_ref.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void run_passes(migraphx::module& m) +{ + auto ctx = migraphx::gpu::context{}; + migraphx::run_passes(m, + {migraphx::auto_contiguous{}, + migraphx::gpu::lowering{&ctx, false}, + migraphx::dead_code_elimination{}, + migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}}, + migraphx::dead_code_elimination{}, + migraphx::gpu::pack_int8_args{}, + migraphx::dead_code_elimination{}}); +} + +bool get_int8_x4_format() +{ + bool int8_x4_format = true; +#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 + auto ctx = migraphx::gpu::context{}; + rocblas_gemm_flags flag; + rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); + int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4); +#endif + return int8_x4_format; +} + +TEST_CASE(quant_dot) +{ + auto create_module = [] { + migraphx::module m("test"); + migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}}; + + auto l1 = m.add_parameter("a", m1_shape); + auto l2 = m.add_parameter("b", m2_shape); + auto l3 = m.add_parameter("c", m3_shape); + auto r = + migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); + m.add_return({r}); + return m; + }; + + auto create_optimized_int8_x4 = [](bool int8_x4) { + migraphx::module m("test"); + migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}}; + + auto l1 = m.add_parameter("a", m1_shape); + auto l2 = m.add_parameter("b", m2_shape); + auto l3 = m.add_parameter("c", m3_shape); + auto beta = m.add_literal(1); + auto output = m.add_parameter("test:#output_0", m3_shape); + auto gemm_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); + + auto packa = l2; + if(int8_x4) + { + auto alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}})); + packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc); + } + auto gemm = + m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), + l1, + packa, + gemm_alloc); + + auto beta_broadcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta); + auto beta_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); + auto beta_contiguous = + m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc); + auto mul_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); + auto m3_beta = + m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc); + auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output); + m.add_return({gemm_add}); + + return m; + }; + + auto m1 = create_module(); + run_passes(m1); + + bool flag = get_int8_x4_format(); + auto m2 = create_optimized_int8_x4(flag); + EXPECT(m1 == m2); +} + +TEST_CASE(quant_dot_trans) +{ + auto create_module = [] { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}}; + migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}}; + + auto l1 = m.add_parameter("a", s1); + auto tl1 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + auto l2 = m.add_parameter("b", s2); + auto tl2 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); + m.add_return({r}); + return m; + }; + + auto create_optimized_int8_x4 = [](bool int8_x4) { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}}; + migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}}; + migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}}; + + auto l1 = m.add_parameter("a", s1); + auto l2 = m.add_parameter("b", s2); + auto alpha = m.add_literal(3); + auto output = m.add_parameter("test:#output_0", s3); + + auto tl1 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}}; + auto alloca = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); + auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca); + + auto tl2 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}}; + auto allocb = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); + auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb); + + auto alpha_broadcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha); + auto alpha_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", + {{"shape", + migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}})); + auto alpha_contiguous = + m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc); + // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert + // back result to int8 + auto tl1_convert_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}})); + auto tl1_convert = m.add_instruction( + migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}), + conta, + tl1_convert_alloc); + auto mul_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); + auto tl1_alpha_int32 = m.add_instruction( + migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc); + // convert mul_res to int8 + auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); + auto tl1_alpha_int8 = m.add_instruction( + migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}), + tl1_alpha_int32, + tl1_alpha_int8_alloc); + + auto packb = contb; + if(int8_x4) + { + auto allocpb = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); + packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb); + } + + auto gemm = + m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), + tl1_alpha_int8, + packb, + output); + m.add_return({gemm}); + + return m; + }; + + auto m1 = create_module(); + bool flag = get_int8_x4_format(); + auto m2 = create_optimized_int8_x4(flag); + + run_passes(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(quant_dot_pad) +{ + auto create_module = [] { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {5, 6}}; + migraphx::shape s2{migraphx::shape::int8_type, {6, 7}}; + migraphx::shape s3{migraphx::shape::int32_type, {5, 7}}; + + auto l1 = m.add_parameter("a", s1); + auto l2 = m.add_parameter("b", s2); + auto l3 = m.add_parameter("c", s3); + auto r = + migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); + m.add_return({r}); + return m; + }; + + auto create_optimized_int8_x4 = [](bool int8_x4) { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {5, 6}}; + migraphx::shape ps1{migraphx::shape::int8_type, {5, 8}}; + migraphx::shape s2{migraphx::shape::int8_type, {6, 7}}; + migraphx::shape ps2{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape s3{migraphx::shape::int32_type, {5, 7}}; + + auto l1 = m.add_parameter("a", s1); + auto l2 = m.add_parameter("b", s2); + auto l3 = m.add_parameter("c", s3); + auto beta = m.add_literal(1); + auto output = m.add_parameter("test:#output_0", s3); + + auto pl1 = l1; + auto packa = l2; + migraphx::instruction_ref pl2{}; + if(int8_x4) + { + auto po1 = m.insert_instruction( + l1, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}})); + pl1 = m.add_instruction( + migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 2, 0, 0}}, {"value", 0}}), + l1, + po1); + + auto po2 = m.insert_instruction( + l2, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); + pl2 = m.insert_instruction( + std::next(l2), + migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {2, 0, 0, 0}}, {"value", 0}}), + l2, + po2); + } + + auto gemm_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); + + if(int8_x4) + { + auto alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); + packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc); + } + + auto gemm = + m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), + pl1, + packa, + gemm_alloc); + + auto beta_broadcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta); + auto beta_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); + auto beta_contiguous = + m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc); + auto mul_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); + auto m3_beta = + m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc); + auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output); + m.add_return({gemm_add}); + return m; + }; + + auto m1 = create_module(); + bool flag = get_int8_x4_format(); + auto m2 = create_optimized_int8_x4(flag); + + run_passes(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(quant_dot_trans_pad) +{ + auto create_module = [] { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}}; + migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}}; + + auto l1 = m.add_parameter("a", s1); + auto tl1 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + auto l2 = m.add_parameter("b", s2); + auto tl2 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); + m.add_return({r}); + return m; + }; + + auto create_optimized_int8_x4 = [](bool int8_x4) { + migraphx::module m("test"); + migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}}; + migraphx::shape ps1{migraphx::shape::int8_type, {3, 2, 5, 12}}; + migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}}; + migraphx::shape ps2{migraphx::shape::int8_type, {3, 2, 12, 7}}; + migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}}; + + auto l1 = m.add_parameter("a", s1); + auto l2 = m.add_parameter("b", s2); + auto alpha = m.add_literal(3); + auto output = m.add_parameter("test:#output_0", s3); + + auto tl1 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}}; + auto ta = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); + auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta); + + auto tl2 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; + auto tb = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); + + migraphx::instruction_ref ptb{}; + if(int8_x4) + { + ptb = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); + } + auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb); + auto pb = contb; + if(int8_x4) + { + pb = m.add_instruction( + migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}), + contb, + ptb); + } + + auto alpha_broadcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha); + auto alpha_alloc = m.add_instruction( + migraphx::make_op("hip::allocate", + {{"shape", + migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, + conta->get_shape().lens()))}})); + auto alpha_contiguous = + m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc); + + // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert + // back result to int8 + auto tl1_convert_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}})); + auto tl1_convert = m.add_instruction( + migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}), + conta, + tl1_convert_alloc); + auto mul_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}})); + auto tl1_alpha_int32 = m.add_instruction( + migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc); + // convert mul_res to int8 + auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( + "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}})); + + migraphx::instruction_ref pta{}; + if(int8_x4) + { + pta = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}})); + } + + auto tl1_alpha_int8 = m.add_instruction( + migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}), + tl1_alpha_int32, + tl1_alpha_int8_alloc); + + auto pa = tl1_alpha_int8; + if(int8_x4) + { + pa = m.add_instruction( + migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}), + tl1_alpha_int8, + pta); + } + + auto packb = pb; + if(int8_x4) + { + auto allocpb = m.add_instruction( + migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); + packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb); + } + + auto gemm = m.add_instruction( + migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), pa, packb, output); + m.add_return({gemm}); + + return m; + }; + + auto m1 = create_module(); + bool flag = get_int8_x4_format(); + auto m2 = create_optimized_int8_x4(flag); + + run_passes(m1); + + EXPECT(m1 == m2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/quantization.cpp b/test/gpu/quantization.cpp index 09fa547d3178979dfaccc8b391554eeef8c39c2e..e9f39b126cfefa18fa72fbe7f297567b770dae4f 100644 --- a/test/gpu/quantization.cpp +++ b/test/gpu/quantization.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include @@ -12,23 +12,23 @@ #include #include #include -#include "test.hpp" +#include #include TEST_CASE(gpu_target_copy) { migraphx::target gpu_t = migraphx::gpu::target{}; - migraphx::target cpu_t = migraphx::cpu::target{}; + migraphx::target ref_t = migraphx::ref::target{}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; - auto cpu_arg_orig = migraphx::generate_argument(s, 0x123456L); - auto gpu_arg = gpu_t.copy_to(cpu_arg_orig); - auto cpu_arg_final = gpu_t.copy_from(gpu_arg); + auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); + auto gpu_arg = gpu_t.copy_to(ref_arg_orig); + auto ref_arg_final = gpu_t.copy_from(gpu_arg); std::vector val_orig; - cpu_arg_orig.visit([&](auto v) { val_orig.assign(v.begin(), v.end()); }); + ref_arg_orig.visit([&](auto v) { val_orig.assign(v.begin(), v.end()); }); std::vector val_final; - cpu_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); + ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); EXPECT(migraphx::verify_range(val_orig, val_final)); } @@ -37,13 +37,13 @@ TEST_CASE(int8_quantization) { auto run_prog = [](migraphx::program p, const migraphx::target& t, - migraphx::program::parameter_map& m_in, + migraphx::parameter_map& m_in, std::vector& res) { - std::vector cali_data; + std::vector cali_data; cali_data.push_back(m_in); migraphx::quantize_int8(p, t, cali_data); p.compile(t); - migraphx::program::parameter_map m; + migraphx::parameter_map m; for(auto&& x : p.get_parameter_shapes()) { if(m_in.count(x.first) > 0) @@ -56,41 +56,40 @@ TEST_CASE(int8_quantization) } } - auto result = t.copy_from(p.eval(m)); + auto result = t.copy_from(p.eval(m).back()); result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); }; auto create_program = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::float_type, {5, 16}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - p.add_instruction(migraphx::op::dot{}, pa, pb, pc); + migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + mm->add_instruction(migraphx::op::dot{}, pa, pb); return p; }; { auto p = create_program(); - migraphx::program::parameter_map m; - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; + migraphx::parameter_map m; + migraphx::shape sa{migraphx::shape::float_type, {5, 16}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; + migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; m["a"] = migraphx::generate_argument(sa); m["b"] = migraphx::generate_argument(sb); - m["c"] = migraphx::generate_argument(sc); - std::vector cpu_result; - migraphx::target cpu_t = migraphx::cpu::target{}; - run_prog(p, cpu_t, m, cpu_result); + std::vector ref_result; + migraphx::target ref_t = migraphx::ref::target{}; + run_prog(p, ref_t, m, ref_result); std::vector gpu_result; migraphx::target gpu_t = migraphx::gpu::target{}; run_prog(p, gpu_t, m, gpu_result); - EXPECT(migraphx::verify_range(cpu_result, gpu_result)); + EXPECT(migraphx::verify_range(ref_result, gpu_result)); } } diff --git a/test/include/basic_ops.hpp b/test/include/basic_ops.hpp old mode 100644 new mode 100755 index 11a889d66826936be48bdf29276813bd77fedb5c..e48e1acede0f1b25b20729c69b9efca6613bdc6f --- a/test/include/basic_ops.hpp +++ b/test/include/basic_ops.hpp @@ -79,6 +79,29 @@ struct pass_op return {}; return inputs.front(); } + int output_alias(const std::vector& s) const { return s.empty() ? -1 : 0; } +}; + +struct mod_pass_op +{ + std::string name() const { return "mod_pass"; } + + migraphx::shape compute_shape(std::vector inputs, + std::vector mods) const + { + if(!mods.empty()) + { + auto out_shapes = mods[0]->get_output_shapes(); + return out_shapes[0]; + } + if(!inputs.empty()) + { + return inputs.front(); + } + + return {}; + } + int output_alias(const std::vector&) const { return 0; } }; diff --git a/test/include/pointwise.hpp b/test/include/pointwise.hpp new file mode 100755 index 0000000000000000000000000000000000000000..14b358295be962206da2a7767da478e6a34153b6 --- /dev/null +++ b/test/include/pointwise.hpp @@ -0,0 +1,34 @@ +#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP +#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP + +#include +#include +#include + +template +migraphx::instruction_ref add_pointwise(migraphx::program& p, + const std::string& name, + std::vector inputs, + F f) +{ + auto* pm = p.create_module(name); + auto* mm = p.get_main_module(); + pm->set_bypass(); + std::vector params; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { + return pm->add_parameter("x" + std::to_string(params.size()), + migraphx::shape{input->get_shape().type()}); + }); + auto r = f(pm, params); + pm->add_return({r}); + return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm}); +} + +inline auto single_pointwise(const std::string& name) +{ + return [=](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op(name), inputs); + }; +} + +#endif // MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP diff --git a/test/include/rob.hpp b/test/include/rob.hpp index 44f95ef94679c7c10ba5447d7b62cd352f09893b..c6cd443c1b6933ec99d779de393501591851eb04 100644 --- a/test/include/rob.hpp +++ b/test/include/rob.hpp @@ -10,18 +10,22 @@ template struct stowed { + // NOLINTNEXTLINE static typename Tag::type value; }; template +// NOLINTNEXTLINE typename Tag::type stowed::value; template struct stow_private { stow_private() noexcept { stowed::value = X; } + // NOLINTNEXTLINE static stow_private instance; }; template +// NOLINTNEXTLINE stow_private stow_private::instance; template diff --git a/test/include/test.hpp b/test/include/test.hpp index 2e451f2017e4cc2150882e7e3668bdaaf38ab661..15ce4776ad60210ca4668a563e919eba1d261580 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -1,23 +1,44 @@ +#include #include #include #include #include #include +#include +#include #include #include +#ifdef __linux__ +#include +#endif + #ifndef MIGRAPHX_GUARD_TEST_TEST_HPP #define MIGRAPHX_GUARD_TEST_TEST_HPP namespace test { +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_BINARY_OPERATORS(m) \ + m(==, equal) \ + m(!=, not_equal) \ + m(<=, less_than_equal) \ + m(>=, greater_than_equal) \ + m(<, less_than) \ + m(>, greater_than) \ + m(and, and_op) \ + m(or, or_op) +// clang-format on + +// clang-format off // NOLINTNEXTLINE -#define TEST_FOREACH_OPERATOR(m) \ - m(==, equal) m(!=, not_equal) m(<=, less_than_equal) m(>=, greater_than_equal) m(<, less_than) \ - m(>, greater_than) +#define TEST_FOREACH_UNARY_OPERATORS(m) \ + m(not, not_op) +// clang-format on // NOLINTNEXTLINE -#define TEST_EACH_OPERATOR_OBJECT(op, name) \ +#define TEST_EACH_BINARY_OPERATOR_OBJECT(op, name) \ struct name \ { \ static std::string as_string() { return #op; } \ @@ -28,26 +49,97 @@ namespace test { } \ }; -TEST_FOREACH_OPERATOR(TEST_EACH_OPERATOR_OBJECT) +// NOLINTNEXTLINE +#define TEST_EACH_UNARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x) \ + { \ + return op x; \ + } \ + }; + +TEST_FOREACH_BINARY_OPERATORS(TEST_EACH_BINARY_OPERATOR_OBJECT) +TEST_FOREACH_UNARY_OPERATORS(TEST_EACH_UNARY_OPERATOR_OBJECT) + +struct nop +{ + static std::string as_string() { return ""; } + template + static auto call(T&& x) + { + return static_cast(x); + } +}; + +struct function +{ + static std::string as_string() { return ""; } + template + static decltype(auto) call(T&& x) + { + return x(); + } +}; + +template +inline Stream& stream_range(Stream& s, Iterator start, Iterator last) +{ + if(start != last) + { + s << *start; + std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; }); + } + return s; +} -inline std::ostream& operator<<(std::ostream& s, std::nullptr_t) +template +inline Stream& operator<<(Stream& s, std::nullptr_t) { s << "nullptr"; return s; } -template -inline std::ostream& operator<<(std::ostream& s, const std::vector& v) +template {}>::type> +inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end())) { s << "{ "; - for(auto&& x : v) - { - s << x << ", "; - } + stream_range(s, v.begin(), v.end()); s << "}"; return s; } +template +const T& get_value(const T& x) +{ + return x; +} + +template +struct lhs_expression; + +template +lhs_expression make_lhs_expression(T&& lhs); + +template +lhs_expression make_lhs_expression(T&& lhs, Operator); + +// NOLINTNEXTLINE +#define TEST_EXPR_BINARY_OPERATOR(op, name) \ + template \ + auto operator op(const V& rhs2) const \ + { \ + return make_expression(*this, rhs2, name{}); /* NOLINT */ \ + } + +// NOLINTNEXTLINE +#define TEST_EXPR_UNARY_OPERATOR(op, name) \ + auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } + template struct expression { @@ -56,11 +148,16 @@ struct expression friend std::ostream& operator<<(std::ostream& s, const expression& self) { - s << " [ " << self.lhs << " " << Operator::as_string() << " " << self.rhs << " ]"; + s << self.lhs << " " << Operator::as_string() << " " << self.rhs; return s; } - decltype(auto) value() const { return Operator::call(lhs, rhs); }; + friend decltype(auto) get_value(const expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; + + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) }; // TODO: Remove rvalue references @@ -70,9 +167,6 @@ expression make_expression(T&& rhs, U&& lhs, Operator) return {std::forward(rhs), std::forward(lhs)}; } -template -struct lhs_expression; - // TODO: Remove rvalue reference template lhs_expression make_lhs_expression(T&& lhs) @@ -80,7 +174,13 @@ lhs_expression make_lhs_expression(T&& lhs) return lhs_expression{std::forward(lhs)}; } -template +template +lhs_expression make_lhs_expression(T&& lhs, Operator) +{ + return lhs_expression{std::forward(lhs)}; +} + +template struct lhs_expression { T lhs; @@ -88,20 +188,20 @@ struct lhs_expression friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) { + std::string op = Operator::as_string(); + if(not op.empty()) + s << Operator::as_string() << " "; s << self.lhs; return s; } - T value() const { return lhs; } -// NOLINTNEXTLINE -#define TEST_LHS_OPERATOR(op, name) \ - template \ - auto operator op(const U& rhs) const \ - { \ - return make_expression(lhs, rhs, name{}); /* NOLINT */ \ - } + friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs)); } + + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) - TEST_FOREACH_OPERATOR(TEST_LHS_OPERATOR) // NOLINTNEXTLINE #define TEST_LHS_REOPERATOR(op) \ template \ @@ -116,11 +216,65 @@ struct lhs_expression TEST_LHS_REOPERATOR(%) TEST_LHS_REOPERATOR(&) TEST_LHS_REOPERATOR(|) - TEST_LHS_REOPERATOR (^) - TEST_LHS_REOPERATOR(&&) - TEST_LHS_REOPERATOR(||) + TEST_LHS_REOPERATOR(^) +}; + +template +struct predicate +{ + std::string msg; + F f; + + friend std::ostream& operator<<(std::ostream& s, const predicate& self) + { + s << self.msg; + return s; + } + + decltype(auto) operator()() const { return f(); } + + operator decltype(auto)() const { return f(); } }; +template +auto make_predicate(const std::string& msg, F f) +{ + return make_lhs_expression(predicate{msg, f}, function{}); +} + +inline std::string as_string(bool x) +{ + if(x) + return "true"; + return "false"; +} + +template +std::string as_string(const T& x) +{ + std::stringstream ss; + ss << x; + return ss.str(); +} + +template +std::string as_string(Iterator start, Iterator last) +{ + std::stringstream ss; + stream_range(ss, start, last); + return ss.str(); +} + +template +auto make_function(const std::string& name, F f) +{ + return [=](auto&&... xs) { + std::vector args = {as_string(xs)...}; + return make_predicate(name + "(" + as_string(args.begin(), args.end()) + ")", + [=] { return f(xs...); }); + }; +} + struct capture { template @@ -128,16 +282,49 @@ struct capture { return make_lhs_expression(x); } + + template + auto operator->*(const lhs_expression& x) const + { + return x; + } }; +enum class color +{ + reset = 0, + bold = 1, + underlined = 4, + fg_red = 31, + fg_green = 32, + fg_yellow = 33, + fg_blue = 34, + fg_default = 39, + bg_red = 41, + bg_green = 42, + bg_yellow = 43, + bg_blue = 44, + bg_default = 49 +}; +inline std::ostream& operator<<(std::ostream& os, const color& c) +{ +#ifndef _WIN32 + static const bool use_color = isatty(STDOUT_FILENO) != 0; + if(use_color) + return os << "\033[" << static_cast(c) << "m"; +#endif + return os; +} + template void failed(T x, const char* msg, const char* func, const char* file, int line, F f) { - if(!x.value()) + if(!bool(x.value())) { std::cout << func << std::endl; std::cout << file << ":" << line << ":" << std::endl; - std::cout << " FAILED: " << msg << " " << x << std::endl; + std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " " + << "[ " << x << " ]" << std::endl; f(); } } @@ -170,10 +357,17 @@ bool throws(F f, const std::string& msg = "") } } +template +auto near(T px, U py, double ptol = 1e-6f) +{ + return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( + px, py, ptol); +} + using string_map = std::unordered_map>; template -string_map parse(std::vector as, Keyword keyword) +string_map generic_parse(std::vector as, Keyword keyword) { string_map result; @@ -189,18 +383,22 @@ string_map parse(std::vector as, Keyword keyword) { flag = f.front(); result[flag]; // Ensure the flag exists + flag = f.back(); } } return result; } +using test_case = std::function; + inline auto& get_test_cases() { - static std::vector>> cases; + // NOLINTNEXTLINE + static std::vector> cases; return cases; } -inline void add_test_case(std::string name, std::function f) +inline void add_test_case(std::string name, test_case f) { get_test_cases().emplace_back(std::move(name), std::move(f)); } @@ -214,54 +412,263 @@ struct auto_register_test_case } }; -inline void run_test_case(const std::string& name, const std::function& f) +struct failure_error { - std::cout << "[ RUN ] " << name << std::endl; - f(); - std::cout << "[ COMPLETE ] " << name << std::endl; -} +}; -inline void run(int argc, const char* argv[]) +[[noreturn]] inline void fail() { throw failure_error{}; } + +struct driver { - std::vector as(argv + 1, argv + argc); + driver() + { + add_flag({"--help", "-h"}, "Show help"); + add_flag({"--list", "-l"}, "List all test cases"); + add_flag({"--continue", "-c"}, "Continue after failure"); + add_flag({"--quiet", "-q"}, "Don't print out extra output"); + } + struct argument + { + std::vector flags = {}; + std::string help = ""; + int nargs = 1; + }; - auto args = parse(as, [](auto &&) -> std::vector { return {}; }); - auto cases = args[""]; - if(cases.empty()) + void add_arg(const std::vector& flags, const std::string& help = "") { - for(auto&& tc : get_test_cases()) - run_test_case(tc.first, tc.second); + arguments.push_back(argument{flags, help, 1}); } - else + + void add_flag(const std::vector& flags, const std::string& help = "") { - std::unordered_map> m(get_test_cases().begin(), - get_test_cases().end()); - for(auto&& name : cases) + arguments.push_back(argument{flags, help, 0}); + } + + void show_help(const std::string& exe) const + { + std::cout << std::endl; + std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; + std::cout << " "; + std::cout << exe << " ... " << std::endl; + std::cout << std::endl; + + std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl; + std::cout << " "; + std::cout << color::fg_green << "..." << color::reset; + std::cout << std::endl; + std::cout << " " + << "Test case name to run" << std::endl; + std::cout << std::endl; + std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl; + for(auto&& arg : arguments) { - auto f = m.find(name); - if(f == m.end()) - std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl; + std::string prefix = " "; + std::cout << color::fg_green; + for(const std::string& a : arg.flags) + { + std::cout << prefix; + std::cout << a; + prefix = ", "; + } + std::cout << color::reset << std::endl; + std::cout << " " << arg.help << std::endl; + } + } + + std::ostream& out() const + { + struct null_buffer : std::streambuf + { + virtual int overflow(int c) override { return c; } + }; + static null_buffer buffer; + static std::ostream null_stream(&buffer); + if(quiet) + return null_stream; + return std::cout; + } + + string_map parse(int argc, const char* argv[]) const + { + std::vector args(argv + 1, argv + argc); + string_map keys; + for(auto&& arg : arguments) + { + for(auto&& flag : arg.flags) + { + keys[flag] = {arg.flags.front()}; + if(arg.nargs == 0) + keys[flag].push_back(""); + } + } + auto result = generic_parse(args, [&](auto&& s) -> std::vector { + if(keys.count(s) > 0) + return keys[s]; else - run_test_case(name, f->second); + return {}; + }); + result["__exe__"].push_back(argv[0]); + return result; + } + + static std::string create_command(const string_map& args) + { + std::stringstream ss; + ss << args.at("__exe__").front(); + if(args.count("") > 0) + { + for(auto&& arg : args.at("")) + ss << " \"" << arg << "\""; + } + for(auto&& p : args) + { + if(p.first == "__exe__") + continue; + if(p.first.empty()) + continue; + ss << " " << p.first; + for(auto&& arg : p.second) + ss << " \"" << arg << "\""; + } + return ss.str(); + } + + static std::string fork(const std::string& name, string_map args) + { + std::string msg; + args[""] = {name}; + args.erase("--continue"); + args["--quiet"]; + auto cmd = create_command(args); + auto r = std::system(cmd.c_str()); // NOLINT + if(r != 0) + msg = "Exited with " + std::to_string(r); + return msg; + } + + void run_test_case(const std::string& name, const test_case& f, const string_map& args) + { + ran++; + out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name + << color::reset << std::endl; + std::string msg; + if(args.count("--continue") > 0) + { + msg = fork(name, args); + } + else + { + try + { + f(); + } + catch(const failure_error&) + { + msg = "Test failure"; + } + } + if(msg.empty()) + { + out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name + << color::reset << std::endl; + } + else + { + failed.push_back(name); + out() << color::fg_red << "[ FAILED ] " << color::reset << color::bold << name + << color::reset << ": " << color::fg_yellow << msg << color::reset << std::endl; + } + } + + void run(int argc, const char* argv[]) + { + auto args = parse(argc, argv); + if(args.count("--help") > 0) + { + show_help(args.at("__exe__").front()); + return; + } + if(args.count("--list") > 0) + { + for(auto&& tc : get_test_cases()) + out() << tc.first << std::endl; + return; + } + + if(args.count("--quiet") > 0) + quiet = true; + + auto cases = args[""]; + if(cases.empty()) + { + for(auto&& tc : get_test_cases()) + run_test_case(tc.first, tc.second, args); + } + else + { + std::unordered_map m(get_test_cases().begin(), + get_test_cases().end()); + for(auto&& iname : cases) + { + for(auto&& name : get_case_names(iname)) + { + auto f = m.find(name); + if(f == m.end()) + { + out() << color::fg_red << "[ ERROR ] Test case '" << name + << "' not found." << color::reset << std::endl; + failed.push_back(name); + } + else + run_test_case(name, f->second, args); + } + } + } + out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran" + << color::reset << std::endl; + if(not failed.empty()) + { + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size() + << " tests failed" << color::reset << std::endl; + for(auto&& name : failed) + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name + << color::reset << std::endl; + std::exit(1); } } + + std::function(const std::string&)> get_case_names = + [](const std::string& name) -> std::vector { return {name}; }; + std::vector arguments = {}; + std::vector failed = {}; + std::size_t ran = 0; + bool quiet = false; +}; + +inline void run(int argc, const char* argv[]) +{ + driver d{}; + d.run(argc, argv); } } // namespace test +// NOLINTNEXTLINE +#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__ + // NOLINTNEXTLINE #define CHECK(...) \ test::failed( \ test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \ }) // NOLINTNEXTLINE -#define EXPECT(...) \ - test::failed(test::capture{}->*__VA_ARGS__, \ - #__VA_ARGS__, \ - __PRETTY_FUNCTION__, \ - __FILE__, \ - __LINE__, \ - &std::abort) +#define EXPECT(...) \ + test::failed(TEST_CAPTURE(__VA_ARGS__), \ + #__VA_ARGS__, \ + __PRETTY_FUNCTION__, \ + __FILE__, \ + __LINE__, \ + &test::fail) // NOLINTNEXTLINE #define STATUS(...) EXPECT((__VA_ARGS__) == 0) diff --git a/test/inline_module_test.cpp b/test/inline_module_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f964bb0e9bbc1e1a4c5b0d66c397a71f136c24ce --- /dev/null +++ b/test/inline_module_test.cpp @@ -0,0 +1,492 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +void run_pass(migraphx::program& p) +{ + migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(cannot_inline_both) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", sd); + + std::vector one(sd.elements(), 1); + std::vector two(sd.elements(), 2); + + auto* then_smod = p.create_module("then_smod"); + auto l1 = then_smod->add_literal(migraphx::literal{sd, one}); + auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1); + then_smod->add_return({r1}); + + auto* else_smod = p.create_module("else_smod"); + auto l2 = else_smod->add_literal(migraphx::literal{sd, two}); + auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2); + else_smod->add_return({r2}); + + migraphx::shape s_cond{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_parameter("cond", s_cond); + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod}); + mm->add_return({ret}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + + EXPECT(p == create_program()); +} + +TEST_CASE(cannot_inline_one) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto cond = mm->add_parameter("cond", cond_s); + auto x = mm->add_parameter("x", s); + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_return({l1, x}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2); + else_mod->add_return({s2, l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({ret}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + + EXPECT(p == create_program()); +} + +TEST_CASE(inline_if_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {1})); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + auto l1 = mm->add_literal(s, ones); + std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; + auto l2 = mm->add_literal(s, rand); + auto x = mm->add_parameter("x", s); + auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x); + auto y = mm->add_parameter("y", s); + + auto* then_mod = p.create_module("If_5_if"); + auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, sm); + then_mod->add_outline(s); + then_mod->add_return({rt}); + + auto* else_mod = p.create_module("If_5_else"); + auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({re}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + return p; + }; + + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + auto l1 = mm->add_literal(s, ones); + std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; + mm->add_literal(s, rand); + auto x = mm->add_parameter("x", s); + auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x); + mm->add_parameter("y", s); + auto r = mm->add_instruction(migraphx::make_op("add"), x, sm); + mm->add_return({r}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +TEST_CASE(inline_else_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {0})); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + auto l1 = mm->add_literal(s, ones); + std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; + auto l2 = mm->add_literal(s, rand); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + + auto* then_mod = p.create_module("If_5_if"); + auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); + then_mod->add_return({rt}); + + auto* else_mod = p.create_module("If_5_else"); + else_mod->add_parameter("e", s); + else_mod->add_literal(migraphx::literal(s, ones)); + auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({re}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + return p; + }; + + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + mm->add_literal(s, ones); + std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; + auto l2 = mm->add_literal(s, rand); + mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_parameter("e", s); + auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2); + mm->add_return({r}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +TEST_CASE(if_recursive_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + auto cond = mm->add_literal(migraphx::literal(cond_s, {0})); + auto x1 = mm->add_parameter("x1", xs); + auto x2 = mm->add_parameter("x2", xs); + auto y2 = mm->add_parameter("y2", ys); + auto cond1 = mm->add_parameter("cond", cond_s); + + auto* then_mod = p.create_module("If_5_if"); + auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx); + then_mod->add_return({a1, l1}); + + auto* then_mod1 = p.create_module("If_6_if"); + auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); + auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); + then_mod1->add_return({a11, l11}); + + auto* else_mod1 = p.create_module("If_6_else"); + auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); + auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); + else_mod1->add_return({l21, a21}); + + auto* else_mod = p.create_module("If_5_else"); + auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); + auto a2 = + else_mod->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); + auto a3 = + else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2); + else_mod->add_return({l2, a3}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r}); + + return p; + }; + + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + mm->add_parameter("x1", xs); + auto x2 = mm->add_parameter("x2", xs); + auto y2 = mm->add_parameter("y2", ys); + auto cond1 = mm->add_parameter("cond", cond_s); + + auto* then_mod1 = p.create_module("If_6_if"); + auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); + auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); + then_mod1->add_return({a11, l11}); + + auto* else_mod1 = p.create_module("If_6_else"); + auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); + auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); + else_mod1->add_return({l21, a21}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +TEST_CASE(if_recursive_cond0_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + auto cond = mm->add_literal(migraphx::literal(cond_s, {0})); + auto x1 = mm->add_parameter("x1", xs); + auto x2 = mm->add_parameter("x2", xs); + auto y2 = mm->add_parameter("y2", ys); + + auto* then_mod = p.create_module("If_5_if"); + auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx); + then_mod->add_return({a1, l1}); + + auto* then_mod1 = p.create_module("If_6_if"); + auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); + auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); + then_mod1->add_return({a11, l11}); + + auto* else_mod1 = p.create_module("If_6_else"); + auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); + auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); + else_mod1->add_return({l21, a21}); + + auto* else_mod = p.create_module("If_5_else"); + auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); + auto a2 = + else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1}); + auto a3 = + else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2); + else_mod->add_return({l2, a3}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r}); + + return p; + }; + + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + mm->add_parameter("x1", xs); + mm->add_parameter("x2", xs); + auto y2 = mm->add_parameter("y2", ys); + auto m = mm->add_instruction(migraphx::make_op("mul"), y2, ly); + mm->add_return({m}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +TEST_CASE(inline_tuple_true_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {1})); + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto* then_mod = p.create_module("If_6_if"); + auto m1 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); + auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); + auto m2 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); + auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); + then_mod->add_return({add0, mul0}); + + auto* else_mod = p.create_module("If_6_else"); + auto me1 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); + auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); + auto me2 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); + auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); + else_mod->add_return({mul1, add1}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r0, r1}); + + return p; + }; + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto m1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); + auto add = mm->add_instruction(migraphx::make_op("add"), x, m1); + auto m2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); + auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2); + mm->add_return({add, mul}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +TEST_CASE(inline_tuple_false_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {0})); + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto* then_mod = p.create_module("If_6_if"); + auto m1 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); + auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); + auto m2 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); + auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); + then_mod->add_return({add0, mul0}); + + auto* else_mod = p.create_module("If_6_else"); + auto me1 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); + auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); + auto me2 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); + auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); + else_mod->add_return({mul1, add1}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r0, r1}); + + return p; + }; + + auto create_inline = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + migraphx::shape sd{migraphx::shape::float_type, {1}}; + mm->add_literal(migraphx::literal(sd, {1})); + mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto m1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); + auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1); + auto m2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); + auto add = mm->add_instruction(migraphx::make_op("add"), y, m2); + mm->add_return({mul, add}); + + return p; + }; + + auto p = create_program(); + run_pass(p); + EXPECT(p == create_inline()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/insert_pad_test.cpp b/test/insert_pad_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e07ee803b65ffdfa2cb3ee1be1b9471e6f68e449 --- /dev/null +++ b/test/insert_pad_test.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +void run_pass(migraphx::module& m) +{ + migraphx::run_passes( + m, {migraphx::normalize_ops{}, migraphx::insert_pad{}, migraphx::dead_code_elimination{}}); +} + +migraphx::instruction_ref +create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::module& m) +{ + size_t f[2] = {1, 1}; + std::vector weights(channels * f[0] * f[1]); + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); + return m.add_instruction( + migraphx::make_op("im2col", {{"padding", {0, 0, 1, 1}}}), l_img, l_weights); +} + +migraphx::instruction_ref +create_conv(migraphx::instruction_ref& l_img, + size_t channels, + migraphx::module& m, + migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_) +{ + migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; + std::vector weights(4 * channels * 3 * 3); + auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); + migraphx::op::convolution op; + op.padding_mode = padding_mode; + op.padding = {0, 0, 1, 1}; + return m.add_instruction(op, l_img, l_weights); +} + +TEST_CASE(rewrite_pad) +{ + migraphx::module m; + size_t img_dim[2] = {2, 2}; + size_t channels = 1; + std::vector input(channels * img_dim[0] * img_dim[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; + auto l_img = m.add_literal(migraphx::literal{s_img, input}); + + auto l0 = create_im2col(l_img, channels, m); + auto l1 = create_conv(l_img, channels, m); + auto l2 = m.add_instruction( + migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 1, 1}}}), + l_img); + m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); + + run_pass(m); + + EXPECT(std::any_of( + m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); +} + +TEST_CASE(rewrite_pad_symmetric) +{ + migraphx::module m; + + size_t img_dim[2] = {2, 2}; + size_t channels = 1; + std::vector input(channels * img_dim[0] * img_dim[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; + auto l_img = m.add_literal(migraphx::literal{s_img, input}); + + m.add_instruction( + migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, {"padding", {1, 1, 1, 1}}}), + l_img); + + run_pass(m); + EXPECT(std::none_of( + m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/jit.cpp b/test/jit.cpp new file mode 100755 index 0000000000000000000000000000000000000000..ee2f17ac978f381d9d83ae62a293d6b0067f1831 --- /dev/null +++ b/test/jit.cpp @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include + +// NOLINTNEXTLINE +const std::string add_42_src = R"migraphx( +extern "C" int add(int x) +{ + return x+42; +} +)migraphx"; + +// NOLINTNEXTLINE +const std::string preamble = R"migraphx( +#include +)migraphx"; + +template +std::function +compile_function(const std::string& src, const std::string& flags, const std::string& fname) +{ + migraphx::src_compiler compiler; + compiler.flags = flags + "-std=c++14 -fPIC -shared"; + compiler.output = "libsimple.so"; + migraphx::src_file f; + f.path = "main.cpp"; + f.content = std::make_pair(src.data(), src.data() + src.size()); + auto image = compiler.compile({f}); + return migraphx::dynamic_loader{image}.get_function(fname); +} + +template +std::function compile_module(const migraphx::module& m, const std::string& flags = "") +{ + migraphx::cpp_generator g; + g.fmap([](auto&& name) { return "std::" + name; }); + g.create_function(g.generate_module(m).set_attributes({"extern \"C\""})); + + return compile_function(preamble + g.str(), flags, m.name()); +} + +TEST_CASE(simple_run) +{ + auto f = compile_function(add_42_src, "", "add"); + EXPECT(f(8) == 50); + EXPECT(f(10) == 52); +} + +TEST_CASE(generate_module) +{ + migraphx::module m("foo"); + auto x = m.add_parameter("x", migraphx::shape::float_type); + auto y = m.add_parameter("y", migraphx::shape::float_type); + auto sum = m.add_instruction(migraphx::make_op("add"), x, y); + m.add_instruction(migraphx::make_op("sqrt"), sum); + + auto f = compile_module(m); + + EXPECT(test::near(f(2, 2), 2)); + EXPECT(test::near(f(10, 6), 4)); + EXPECT(test::near(f(1, 2), std::sqrt(3))); +} + +TEST_CASE(generate_module_with_literals) +{ + migraphx::module m("foo"); + auto x = m.add_parameter("x", migraphx::shape::float_type); + auto y = m.add_parameter("y", migraphx::shape::float_type); + auto z = m.add_literal(1.f); + auto sum1 = m.add_instruction(migraphx::make_op("add"), x, z); + auto sum2 = m.add_instruction(migraphx::make_op("add"), sum1, y); + m.add_instruction(migraphx::make_op("sqrt"), sum2); + + auto f = compile_module(m); + + EXPECT(test::near(f(1, 2), 2)); + EXPECT(test::near(f(9, 6), 4)); + EXPECT(test::near(f(0, 2), std::sqrt(3))); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/json_test.cpp b/test/json_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa0c4212a8a065d80cd1f36c24b957b2c98dcb3a --- /dev/null +++ b/test/json_test.cpp @@ -0,0 +1,248 @@ +#include +#include +#include +#include +#include +#include +#include + +TEST_CASE(null_value) +{ + migraphx::value v; + auto json_str = migraphx::to_json_string(v); + EXPECT(json_str == "null"); +} + +TEST_CASE(null_value_rev) +{ + std::string json_str = "null"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value ev; + EXPECT(v == ev); +} + +TEST_CASE(null_array) +{ + migraphx::value v; + migraphx::value arr = {v, v}; + auto json_str = migraphx::to_json_string(arr); + EXPECT(json_str == "[null,null]"); +} + +TEST_CASE(null_array_rev) +{ + std::string json_str = "[null,null]"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value e; + migraphx::value ev = {e, e}; + EXPECT(ev == v); +} + +TEST_CASE(empty_object1) +{ + migraphx::value val = migraphx::from_json_string("{}"); + EXPECT(val == migraphx::value::object{}); + EXPECT(migraphx::to_json_string(migraphx::value::object{}) == "{}"); +} + +TEST_CASE(empty_array1) +{ + migraphx::value val = migraphx::from_json_string("[]"); + EXPECT(val == migraphx::value::array{}); + EXPECT(migraphx::to_json_string(migraphx::value::array{}) == "[]"); +} + +TEST_CASE(int_value) +{ + migraphx::value v = -1; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "-1"); +} + +TEST_CASE(int_value_rev) +{ + std::string json_str = "-1"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value ev = -1; + EXPECT(v == ev); +} + +TEST_CASE(unsigned_value) +{ + migraphx::value v = 1; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "1"); +} + +TEST_CASE(unsigned_value_rev) +{ + std::string json_str = "1"; + migraphx::value v = migraphx::from_json_string(json_str); + EXPECT(v.is_uint64()); + EXPECT(v.get_uint64() == 1); +} + +TEST_CASE(float_value) +{ + migraphx::value v = 1.5; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "1.5"); +} + +TEST_CASE(float_value_rev) +{ + std::string json_str = "1.5"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value ev = 1.5; + EXPECT(v == ev); +} + +TEST_CASE(array_value) +{ + migraphx::value v = {1, 2}; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "[1,2]"); +} + +TEST_CASE(array_value_rev) +{ + std::string json_str = "[1,2]"; + migraphx::value v = migraphx::from_json_string(json_str); + EXPECT(v.is_array()); + EXPECT(v.size() == 2); + EXPECT(v[0].get_uint64() == 1); + EXPECT(v[1].get_uint64() == 2); +} + +TEST_CASE(object_value) +{ + migraphx::value v = {{"a", 1.2}, {"b", true}}; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "{\"a\":1.2,\"b\":true}"); +} + +TEST_CASE(object_value_rev) +{ + std::string json_str = R"({"a":1.2,"b":true})"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value ev = {{"a", 1.2}, {"b", true}}; + EXPECT(v == ev); +} + +TEST_CASE(null_object) +{ + migraphx::value v; + migraphx::value v1 = {{"a", v}}; + std::string json_str = migraphx::to_json_string(v1); + EXPECT(json_str == "{\"a\":null}"); +} + +TEST_CASE(null_object_rev) +{ + std::string json_str = R"({"a":null})"; + migraphx::value eo = migraphx::from_json_string(json_str); + migraphx::value v; + migraphx::value ev = {{"a", v}}; + EXPECT(eo == ev); +} + +TEST_CASE(string_value) +{ + migraphx::value v = "string_test"; + std::string json_str = migraphx::to_json_string(v); + EXPECT(json_str == "\"string_test\""); +} + +TEST_CASE(string_value_rev) +{ + std::string json_str = "\"string_test\""; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value ev = "string_test"; + EXPECT(v == ev); +} + +TEST_CASE(array_of_objects) +{ + migraphx::value obj1 = {"key1", uint64_t{1}}; + migraphx::value obj2 = {"key2", uint64_t{2}}; + migraphx::value arr = {obj1, obj2}; + std::string json_str = migraphx::to_json_string(arr); + EXPECT(json_str == "{\"key1\":1,\"key2\":2}"); +} + +TEST_CASE(array_of_objects_rev) +{ + std::string json_str = R"({"key1":1,"key2":2})"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value obj1 = {"key1", uint64_t{1}}; + migraphx::value obj2 = {"key2", uint64_t{2}}; + migraphx::value arr = {obj1, obj2}; + EXPECT(arr == v); +} + +TEST_CASE(object_of_array) +{ + migraphx::value obj1 = {"key1", 1}; + migraphx::value obj2 = {"key2", 2}; + migraphx::value obj; + obj["key"] = {obj1, obj2}; + std::string json_str = migraphx::to_json_string(obj); + EXPECT(json_str == "{\"key\":{\"key1\":1,\"key2\":2}}"); +} + +TEST_CASE(object_of_array_rev) +{ + std::string json_str = R"({"key":{"key1":1,"key2":2}})"; + migraphx::value v = migraphx::from_json_string(json_str); + migraphx::value obj1 = {"key1", uint64_t{1}}; + migraphx::value obj2 = {"key2", uint64_t{2}}; + migraphx::value obj; + obj["key"] = {obj1, obj2}; + EXPECT(v == obj); +} + +TEST_CASE(shape_value) +{ + migraphx::shape s{migraphx::shape::int32_type, {2, 3, 4, 5}}; + migraphx::value val = migraphx::to_value(s); + std::string json_str = migraphx::to_json_string(val); + migraphx::value val_rev = migraphx::from_json_string(json_str); + migraphx::shape s_rev; + migraphx::from_value(val_rev, s_rev); + + EXPECT(s == s_rev); +} + +TEST_CASE(argument_value) +{ + migraphx::shape s{migraphx::shape::int32_type, {2, 3, 4, 5}}; + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 1); + migraphx::argument argu = migraphx::argument(s, data.data()); + + migraphx::value val = migraphx::to_value(argu); + std::string json_str = migraphx::to_json_string(val); + migraphx::value val_rev = migraphx::from_json_string(json_str); + migraphx::argument argu_rev; + migraphx::from_value(val_rev, argu_rev); + + EXPECT(argu == argu_rev); +} + +TEST_CASE(literal_value) +{ + migraphx::shape s{migraphx::shape::int32_type, {2, 3, 4, 5}}; + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 1); + migraphx::literal l = migraphx::literal(s, data); + + migraphx::value val = migraphx::to_value(l); + std::string json_str = migraphx::to_json_string(val); + migraphx::value val_rev = migraphx::from_json_string(json_str); + migraphx::literal l_rev; + migraphx::from_value(val_rev, l_rev); + + EXPECT(l == l_rev); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/literal_test.cpp b/test/literal_test.cpp old mode 100644 new mode 100755 index a45e43dce552788c2b696f068f8a48386804efb0..af10118679f4c76d79b7682f9100912c3a267863 --- a/test/literal_test.cpp +++ b/test/literal_test.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include "test.hpp" @@ -119,4 +120,19 @@ TEST_CASE(literal_visit_empty) EXPECT(test::throws([&] { x.visit_at([](auto) {}); })); } +TEST_CASE(value_literal) +{ + migraphx::shape s{migraphx::shape::int64_type, {3}}; + migraphx::literal l1{s, {1, 2, 3}}; + auto v1 = migraphx::to_value(l1); + migraphx::literal l2{1}; + auto v2 = migraphx::to_value(l2); + EXPECT(v1 != v2); + + auto l3 = migraphx::from_value(v1); + EXPECT(l3 == l1); + auto l4 = migraphx::from_value(v2); + EXPECT(l4 == l2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/main.cpp b/test/main.cpp old mode 100644 new mode 100755 index e51dccb508cc071b6373d79e767dc8988a6794e7..64a546934a0f88a076ac4286628135c91b3fdf2c --- a/test/main.cpp +++ b/test/main.cpp @@ -1,2 +1,3 @@ +#include "test.hpp" -int main() {} +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/marker.cpp b/test/marker.cpp new file mode 100755 index 0000000000000000000000000000000000000000..60f4def169d1f9349540376f7e7f45745530525b --- /dev/null +++ b/test/marker.cpp @@ -0,0 +1,57 @@ +#include +#include +#include +#include +#include +#include + +#include "test.hpp" + +struct mock_marker +{ + std::shared_ptr ss = std::make_shared(); + + void mark_start(migraphx::instruction_ref ins_ref) + { + std::string text = "Mock marker instruction start:" + ins_ref->name(); + (*ss) << text; + } + void mark_stop(migraphx::instruction_ref) + { + std::string text = "Mock marker instruction stop."; + (*ss) << text; + } + void mark_start(const migraphx::program&) + { + std::string text = "Mock marker program start."; + (*ss) << text; + } + void mark_stop(const migraphx::program&) + { + std::string text = "Mock marker program stop."; + (*ss) << text; + } +}; + +TEST_CASE(marker) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(migraphx::make_op("add"), one, two); + p.compile(migraphx::ref::target{}); + + mock_marker temp_marker; + p.mark({}, temp_marker); + + std::string output = temp_marker.ss->str(); + EXPECT(migraphx::contains(output, "Mock marker instruction start:@literal")); + EXPECT(migraphx::contains(output, "Mock marker instruction start:ref::op")); + EXPECT(migraphx::contains(output, "Mock marker instruction stop.")); + EXPECT(migraphx::contains(output, "Mock marker program start.")); + EXPECT(migraphx::contains(output, "Mock marker program stop.")); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/matcher.cpp b/test/matcher.cpp index 6e5a240d286a4c92095aa54833c3551112c535c1..330cbea260a1186dfedc2d8f3c795995941e52e4 100644 --- a/test/matcher.cpp +++ b/test/matcher.cpp @@ -7,513 +7,500 @@ namespace match = migraphx::match; MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); } -template -migraphx::match::matcher_result find_match(migraphx::program& p, M&& m) -{ - migraphx::match::matcher_result result; - for(auto ins : migraphx::iterator_for(p)) - { - result = migraphx::match::match_instruction(p, ins, m); - if(result.result != p.end()) - return result; - } - return result; -} - void match1() { - migraphx::program p; - auto l = p.add_literal(1); + migraphx::module mm; + auto l = mm.add_literal(1); auto m = match::standard_shape(); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == l}); } TEST_CASE(match_name1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum"); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_name2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("min"); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_name3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_arg1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_arg2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_arg3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_arg4) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto pass = p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + auto pass = mm.add_instruction(pass_op{}, sum); auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == pass}); } TEST_CASE(match_arg5) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_arg6) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::arg(0)(match::name("@literal"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_arg7) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_arg8) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal"))), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_nargs1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::nargs(2)); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_nargs2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::nargs(2), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_nargs3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::all_of(match::nargs(2))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_args1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_args2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")), match::standard_shape()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_args3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_args4) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum2}); } TEST_CASE(match_args5) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")), match::standard_shape()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_args6) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto pass = p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + auto pass = mm.add_instruction(pass_op{}, sum); auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == pass}); } TEST_CASE(match_args7) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto pass = p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + auto pass = mm.add_instruction(pass_op{}, sum); auto m = match::name("pass")(match::args(match::name("sum")(match::args( match::name("@literal"), match::name("@literal")))), match::standard_shape()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == pass}); } TEST_CASE(match_either_args1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum2}); } TEST_CASE(match_either_args2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum2}); } TEST_CASE(match_either_args3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal"))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_either_args_any1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum1}); - EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); + EXPECT(bool{r.instructions["x"] != r.instructions["y"]}); } TEST_CASE(match_either_args_any2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")( match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum1}); - EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); + EXPECT(bool{r.instructions["x"] != r.instructions["y"]}); } TEST_CASE(match_either_args_any3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")( match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum1}); - EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); + EXPECT(bool{r.instructions["x"] != r.instructions["y"]}); } TEST_CASE(match_either_args_any4) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")( match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum2}); - EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); + EXPECT(bool{r.instructions["x"] != r.instructions["y"]}); } TEST_CASE(match_either_args_any5) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum1 = p.add_instruction(sum_op{}, one, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, two); - p.add_instruction(pass_op{}, sum2); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); auto m = match::name("sum")( match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum2}); - EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); + EXPECT(bool{r.instructions["x"] != r.instructions["y"]}); } TEST_CASE(match_all_of1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_all_of2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal")))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_all_of3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::all_of(match::all_of( match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal"))))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_lazy_any_of) { - migraphx::program p; - auto one = p.add_literal(1); - p.add_instruction(pass_op{}, one); + migraphx::module mm; + auto one = mm.add_literal(1); + mm.add_instruction(pass_op{}, one); auto m = match::any_of(match::any(), throws()); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == one}); } TEST_CASE(match_lazy_all_of) { - migraphx::program p; - auto one = p.add_literal(1); - p.add_instruction(pass_op{}, one); + migraphx::module mm; + auto one = mm.add_literal(1); + mm.add_instruction(pass_op{}, one); auto m = match::all_of(match::none(), throws()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_lazy_none_of) { - migraphx::program p; - auto one = p.add_literal(1); - p.add_instruction(pass_op{}, one); + migraphx::module mm; + auto one = mm.add_literal(1); + mm.add_instruction(pass_op{}, one); auto m = match::none_of(match::any(), throws()); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_any_of1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal")))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_any_of2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_any_of_lazy1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::any_of(match::args(match::any(), match::any()).bind("x"), match::args(match::name("sum"), match::name("sum")).bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(bool{r.instructions["x"] == sum}); @@ -522,15 +509,15 @@ TEST_CASE(match_any_of_lazy1) TEST_CASE(match_any_of_lazy2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"), match::args(match::any(), match::any()).bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(bool{r.instructions["x"] == sum}); @@ -539,15 +526,15 @@ TEST_CASE(match_any_of_lazy2) TEST_CASE(match_any_of_lazy3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::any_of(match::args(match::any(), match::any()).bind("x"), match::args(match::name("@literal"), match::name("@literal")).bind("y"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); EXPECT(migraphx::contains(r.instructions, "x")); EXPECT(bool{r.instructions["x"] == sum}); @@ -556,15 +543,15 @@ TEST_CASE(match_any_of_lazy3) TEST_CASE(match_any_of_lazy4) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::any_of( match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")), match::args(match::any().bind("x2"), match::any().bind("y2")))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); EXPECT(migraphx::contains(r.instructions, "x1")); EXPECT(migraphx::contains(r.instructions, "y1")); @@ -576,15 +563,15 @@ TEST_CASE(match_any_of_lazy4) TEST_CASE(match_any_of_lazy5) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::any_of( match::args(match::any().bind("x1"), match::any().bind("y1")), match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2")))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); EXPECT(migraphx::contains(r.instructions, "x1")); EXPECT(migraphx::contains(r.instructions, "y1")); @@ -596,183 +583,459 @@ TEST_CASE(match_any_of_lazy5) TEST_CASE(match_none_of1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")( match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == sum}); } TEST_CASE(match_none_of2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_output1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto sum = p.add_instruction(sum_op{}, minus, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto sum = mm.add_instruction(sum_op{}, minus, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("minus")(match::output(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus}); } TEST_CASE(match_output2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto sum = p.add_instruction(sum_op{}, minus, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto sum = mm.add_instruction(sum_op{}, minus, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("@literal")(match::output(match::name("sum"))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_skip_output1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto sum = p.add_instruction(sum_op{}, minus, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto sum = mm.add_instruction(sum_op{}, minus, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus}); } TEST_CASE(match_skip_output2) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto minus_pass = p.add_instruction(pass_op{}, minus); - auto sum = p.add_instruction(sum_op{}, minus_pass, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto minus_pass = mm.add_instruction(pass_op{}, minus); + auto sum = mm.add_instruction(sum_op{}, minus_pass, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus}); } TEST_CASE(match_skip_output3) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto minus_pass1 = p.add_instruction(pass_op{}, minus); - auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1); - auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2); - auto sum = p.add_instruction(sum_op{}, minus_pass3, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto minus_pass1 = mm.add_instruction(pass_op{}, minus); + auto minus_pass2 = mm.add_instruction(pass_op{}, minus_pass1); + auto minus_pass3 = mm.add_instruction(pass_op{}, minus_pass2); + auto sum = mm.add_instruction(sum_op{}, minus_pass3, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus}); } TEST_CASE(match_skip_output4) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto pass = p.add_instruction(pass_op{}, one); - auto sum = p.add_instruction(sum_op{}, pass, two); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto pass = mm.add_instruction(pass_op{}, one); + auto sum = mm.add_instruction(sum_op{}, pass, two); + mm.add_instruction(pass_op{}, sum); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == two}); } TEST_CASE(match_skip_output5) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto pass = p.add_instruction(pass_op{}, one); - auto sum1 = p.add_instruction(sum_op{}, pass, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, one); - auto sum3 = p.add_instruction(sum_op{}, sum2, two); - p.add_instruction(pass_op{}, sum3); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto pass = mm.add_instruction(pass_op{}, one); + auto sum1 = mm.add_instruction(sum_op{}, pass, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, one); + auto sum3 = mm.add_instruction(sum_op{}, sum2, two); + mm.add_instruction(pass_op{}, sum3); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); - EXPECT(bool{r.result == p.end()}); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); } TEST_CASE(match_skip_output6) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus = p.add_instruction(minus_op{}, two, one); - auto sum1 = p.add_instruction(sum_op{}, minus, two); - auto sum2 = p.add_instruction(sum_op{}, sum1, one); - auto sum3 = p.add_instruction(sum_op{}, sum2, two); - p.add_instruction(pass_op{}, sum3); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus = mm.add_instruction(minus_op{}, two, one); + auto sum1 = mm.add_instruction(sum_op{}, minus, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, one); + auto sum3 = mm.add_instruction(sum_op{}, sum2, two); + mm.add_instruction(pass_op{}, sum3); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus}); } TEST_CASE(match_skip_output7) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto minus1 = p.add_instruction(minus_op{}, two, one); - auto minus2 = p.add_instruction(minus_op{}, two, minus1); - auto sum = p.add_instruction(sum_op{}, one, minus2); - p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto minus1 = mm.add_instruction(minus_op{}, two, one); + auto minus2 = mm.add_instruction(minus_op{}, two, minus1); + auto sum = mm.add_instruction(sum_op{}, one, minus2); + mm.add_instruction(pass_op{}, sum); auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus"))); - auto r = find_match(p, m); + auto r = find_match(mm, m); EXPECT(bool{r.result == minus1}); } TEST_CASE(match_bind1) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - auto pass = p.add_instruction(pass_op{}, sum); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + auto pass = mm.add_instruction(pass_op{}, sum); auto m = match::name("pass")( match::args(match::name("sum")(match::args(match::name("@literal").bind("one"), match::name("@literal").bind("two"))) .bind("sum")), match::standard_shape()) .bind("pass"); - auto r = find_match(p, m); - EXPECT(bool{r.instructions.at("one") == one}); - EXPECT(bool{r.instructions.at("two") == two}); - EXPECT(bool{r.instructions.at("sum") == sum}); - EXPECT(bool{r.instructions.at("pass") == pass}); + auto r = find_match(mm, m); + EXPECT(bool{r.instructions["one"] == one}); + EXPECT(bool{r.instructions["two"] == two}); + EXPECT(bool{r.instructions["sum"] == sum}); + EXPECT(bool{r.instructions["pass"] == pass}); EXPECT(bool{r.result == pass}); } +TEST_CASE(match_bind_modules1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto* child = p.create_module("child"); + auto two = child->add_literal(2); + auto sum = child->add_instruction(sum_op{}, one, two); + child->add_instruction(pass_op{}, sum); + mm->add_instruction(mod_pass_op{}, {one}, {child}); + auto m = match::name("pass")( + match::args(match::name("sum")(match::args(match::name("@literal").bind("one"), + match::name("@literal").bind("two"))) + .bind("sum")), + match::standard_shape()) + .bind("pass"); + auto r = find_match(*child, m); + EXPECT(not migraphx::contains(r.instructions, "one")); + EXPECT(not migraphx::contains(r.instructions, "two")); + EXPECT(not migraphx::contains(r.instructions, "sum")); + EXPECT(not migraphx::contains(r.instructions, "pass")); + EXPECT(bool{r.result == child->end()}); +} + +TEST_CASE(match_bind_modules2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto* child = p.create_module("child"); + auto two = child->add_literal(2); + auto sum = child->add_instruction(sum_op{}, one, two); + auto pass = child->add_instruction(pass_op{}, sum); + mm->add_instruction(mod_pass_op{}, {one}, {child}); + auto m = match::name("pass")( + match::args(match::name("sum")(match::args(match::name("@literal"), + match::name("@literal").bind("two"))) + .bind("sum")), + match::standard_shape()) + .bind("pass"); + auto r = find_match(*child, m); + EXPECT(bool{r.instructions["two"] == two}); + EXPECT(bool{r.instructions["sum"] == sum}); + EXPECT(bool{r.instructions["pass"] == pass}); + EXPECT(bool{r.result == pass}); +} + +TEST_CASE(match_has_value1) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::has_value(1); + auto r = find_match(mm, m); + EXPECT(bool{r.result == one}); +} + +TEST_CASE(match_has_value2) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::has_value(2); + auto r = find_match(mm, m); + EXPECT(bool{r.result == two}); +} + +TEST_CASE(match_has_value3) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2))); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum1}); +} + +TEST_CASE(match_has_value4) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::has_value(3); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_has_value5) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3))); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_has_value6) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, two); + mm.add_instruction(pass_op{}, sum2); + auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1))); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_tree1) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree( + match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum2}); +} + +TEST_CASE(match_tree2) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree( + match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_tree3) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, three, sum1); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree( + match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum2}); +} + +TEST_CASE(match_tree4) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree(match::name("sum"), + match::has_value(1), + match::has_value(2), + match::has_value(3), + match::has_value(4)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_tree5) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_tree6) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + +TEST_CASE(match_unordered_tree1) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::unordered_tree( + match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum2}); +} + +TEST_CASE(match_unordered_tree2) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, three, sum1); + mm.add_instruction(pass_op{}, sum2); + auto m = match::unordered_tree( + match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum2}); +} + +TEST_CASE(match_unordered_tree3) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, two, one); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::unordered_tree( + match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == sum2}); +} + +TEST_CASE(match_unordered_tree4) +{ + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto three = mm.add_literal(3); + auto sum1 = mm.add_instruction(sum_op{}, one, two); + auto sum2 = mm.add_instruction(sum_op{}, sum1, three); + mm.add_instruction(pass_op{}, sum2); + auto m = match::unordered_tree( + match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1)); + auto r = find_match(mm, m); + EXPECT(bool{r.result == mm.end()}); +} + struct match_find_sum { migraphx::instruction_ref ins; auto matcher() const { return match::name("sum"); } - void apply(migraphx::program&, const match::matcher_result& r) const + void apply(migraphx::module&, const match::matcher_result& r) const { EXPECT(bool{r.result == ins}); } @@ -783,7 +1046,7 @@ struct match_find_literal migraphx::instruction_ref ins; auto matcher() const { return match::name("@literal"); } - void apply(migraphx::program&, const match::matcher_result& r) const + void apply(migraphx::module&, const match::matcher_result& r) const { EXPECT(bool{r.result != ins}); EXPECT(r.result->name() == "@literal"); @@ -792,12 +1055,12 @@ struct match_find_literal TEST_CASE(match_finder) { - migraphx::program p; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, one, two); - p.add_instruction(pass_op{}, sum); - match::find_matches(p, match_find_sum{sum}, match_find_literal{sum}); + migraphx::module mm; + auto one = mm.add_literal(1); + auto two = mm.add_literal(2); + auto sum = mm.add_instruction(sum_op{}, one, two); + mm.add_instruction(pass_op{}, sum); + match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum}); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/memory_coloring_test.cpp b/test/memory_coloring_test.cpp old mode 100644 new mode 100755 index 04350b59a57fa433d9eb93eef626949ddd32d124..b7be41ca8f4d6c01d1add52b9432a88dc75db6ad --- a/test/memory_coloring_test.cpp +++ b/test/memory_coloring_test.cpp @@ -3,12 +3,13 @@ #include #include #include +#include #include #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::memory_coloring{"allocate", true}}); + migraphx::run_passes(m, {migraphx::memory_coloring{"allocate", true}}); } struct allocate @@ -35,578 +36,3734 @@ struct allocate } }; -migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape& s) +migraphx::instruction_ref add_alloc(migraphx::module& m, const migraphx::shape& s) { - return p.add_instruction(allocate{s}); + return m.add_instruction(allocate{s}); } -bool no_allocate(const migraphx::program& p) +bool no_allocate(const migraphx::module& m) { - return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; }); + return std::none_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "allocate"; }); +} + +bool is_overlap(std::pair x, std::pair y) +{ + return std::max(x.first, y.first) < std::min(x.second, y.second); +} + +std::pair get_load_interval(migraphx::instruction_ref a) +{ + auto v = a->get_operator().to_value(); + auto offset = v.at("offset").to(); + auto s = migraphx::from_value(v.at("shape")); + return {offset, offset + s.bytes()}; +} + +bool is_overlap_load(migraphx::instruction_ref a, migraphx::instruction_ref b) +{ + return is_overlap(get_load_interval(a), get_load_interval(b)); +} + +bool is_disjoint(const std::vector& inss) +{ + for(auto ins1 : inss) + { + for(auto ins2 : inss) + { + if(ins1 == ins2) + continue; + if(is_overlap_load(ins1, ins2)) + return false; + } + } + return true; } TEST_CASE(test1) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); + CHECK(is_disjoint({a1, a2})); } TEST_CASE(test2) { - migraphx::program p; - auto input = p.add_parameter("input", migraphx::shape{migraphx::shape::float_type, {16}}); + migraphx::module m; + + auto input = m.add_parameter("input", migraphx::shape{migraphx::shape::float_type, {16}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {128}}); - auto p1 = p.add_instruction(pass_op{}, a1, input); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 672); - CHECK(no_allocate(p)); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {128}}); + auto m1 = m.add_instruction(pass_op{}, a1, input); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 672); + CHECK(no_allocate(m)); } TEST_CASE(test3) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {128}}); - auto p1 = p.add_instruction(pass_op{}, p2, a1); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, p3, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 672); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {128}}); + auto m1 = m.add_instruction(pass_op{}, m2, a1); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, p3, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 672); + CHECK(no_allocate(m)); } TEST_CASE(test4) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {128}}); - auto p1 = p.add_instruction(pass_op{}, p2, a1); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, p3, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 672); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {128}}); + auto m1 = m.add_instruction(pass_op{}, m2, a1); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, p3, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 672); + CHECK(no_allocate(m)); } TEST_CASE(test5) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test6) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, p3, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 352); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, p3, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 352); + CHECK(no_allocate(m)); } TEST_CASE(test7) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, p3, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 224); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, p3, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 224); + CHECK(no_allocate(m)); } TEST_CASE(test8) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {192}}); - p.add_instruction(pass_op{}, p3, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 960); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {192}}); + m.add_instruction(pass_op{}, p3, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 960); + CHECK(no_allocate(m)); } TEST_CASE(test9) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, p3, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 96); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto p3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, p3, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 96); + CHECK(no_allocate(m)); } TEST_CASE(test10) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 32); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 32); + CHECK(no_allocate(m)); } TEST_CASE(test11) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.add_instruction(pass_op{}, a3, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 224); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.add_instruction(pass_op{}, a3, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 224); + CHECK(no_allocate(m)); } TEST_CASE(test12) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.add_instruction(pass_op{}, a3, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 352); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.add_instruction(pass_op{}, a3, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 352); + CHECK(no_allocate(m)); } TEST_CASE(test13) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.add_instruction(pass_op{}, a3, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 224); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.add_instruction(pass_op{}, a3, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 224); + CHECK(no_allocate(m)); } TEST_CASE(test14) { - migraphx::program p; - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.add_instruction(pass_op{}, a3, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 224); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.add_instruction(pass_op{}, a3, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 224); + CHECK(no_allocate(m)); } TEST_CASE(test15) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a3, p1, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 352); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a3, m1, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 352); + CHECK(no_allocate(m)); } TEST_CASE(test16) { - migraphx::program p; - auto a1 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}})); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}})); - auto p2 = p.add_instruction(pass_op{}, a2); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a3, p1, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 160); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}})); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}})); + auto m2 = m.add_instruction(pass_op{}, a2); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a3, m1, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 160); + CHECK(no_allocate(m)); } TEST_CASE(test17) { - migraphx::program p; - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a1 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}})); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}})); - auto p2 = p.add_instruction(pass_op{}, a2); - p.add_instruction(pass_op{}, a3, p1, p2); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 160); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {8}})); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {40}})); + auto m2 = m.add_instruction(pass_op{}, a2); + m.add_instruction(pass_op{}, a3, m1, m2); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 160); + CHECK(no_allocate(m)); } TEST_CASE(test18) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto p2 = p.add_instruction(pass_op{}, a1, p1); - auto p3 = p.add_instruction(pass_op{}, p2, p1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a2, p1, p2, p3); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto m2 = m.add_instruction(pass_op{}, a1, m1); + auto p3 = m.add_instruction(pass_op{}, m2, m1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a2, m1, m2, p3); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test19) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a3, p2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 352); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a3, m2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 352); + CHECK(no_allocate(m)); } TEST_CASE(test20) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {32}}); - p.add_instruction(pass_op{}, a4, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 384); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto m1 = m.add_instruction(pass_op{}, a1, a2, a3); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {32}}); + m.add_instruction(pass_op{}, a4, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 384); + CHECK(no_allocate(m)); } TEST_CASE(test21) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a4, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 288); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto m1 = m.add_instruction(pass_op{}, a1, a2, a3); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a4, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 288); + CHECK(no_allocate(m)); } TEST_CASE(test22) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a4, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 288); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1, a2, a3); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a4, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 288); + CHECK(no_allocate(m)); } TEST_CASE(test23) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a4, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 288); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto m1 = m.add_instruction(pass_op{}, a1, a2, a3); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a4, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 288); + CHECK(no_allocate(m)); } TEST_CASE(test24) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {32}}); - auto p1 = p.add_instruction(pass_op{}, a1, a2, a3); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a4, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 384); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {32}}); + auto m1 = m.add_instruction(pass_op{}, a1, a2, a3); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a4, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 384); + CHECK(no_allocate(m)); } TEST_CASE(test25) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(nop{}); - auto p1 = p.add_instruction(pass_op{}, a1); - p.add_instruction(nop{}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(nop{}); + auto m1 = m.add_instruction(pass_op{}, a1); + m.add_instruction(nop{}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test26) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(nop{}, a1); - auto p1 = p.add_instruction(pass_op{}, a1); - p.add_instruction(nop{}, a1, p1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(nop{}, a1); + auto m1 = m.add_instruction(pass_op{}, a1); + m.add_instruction(nop{}, a1, m1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test27) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(nop{}, a2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(nop{}, a2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test28) { - migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.add_instruction(pass_op{}, p2, output); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto output = m.add_parameter("output", {migraphx::shape::float_type, {8}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.add_instruction(pass_op{}, m2, output); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test29) { - migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.move_instruction(output, p2); - p.add_instruction(pass_op{}, p2, output); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + auto output = m.add_parameter("output", {migraphx::shape::float_type, {8}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.move_instruction(output, m2); + m.add_instruction(pass_op{}, m2, output); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test30) { - migraphx::program p; - auto output = p.add_parameter("x", {migraphx::shape::float_type, {8}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a2, p1); - p.move_instruction(output, p2); - p.add_instruction(pass_op{}, p2, output); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto output = m.add_parameter("x", {migraphx::shape::float_type, {8}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a2, m1); + m.move_instruction(output, m2); + m.add_instruction(pass_op{}, m2, output); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test31) { - migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {8}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a1); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.move_instruction(output, a2); - p.add_instruction(pass_op{}, a2, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto output = m.add_parameter("output", {migraphx::shape::float_type, {8}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a1); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.move_instruction(output, a2); + m.add_instruction(pass_op{}, a2, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test32) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); - auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a5, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 352); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1, a3); + auto a5 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a5, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 352); + CHECK(no_allocate(m)); } TEST_CASE(test33) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); - auto a5 = add_alloc(p, {migraphx::shape::float_type, {40}}); - p.add_instruction(pass_op{}, a5, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 192); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1, a3); + auto a5 = add_alloc(m, {migraphx::shape::float_type, {40}}); + m.add_instruction(pass_op{}, a5, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 192); + CHECK(no_allocate(m)); } TEST_CASE(test34) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); - auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a5, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 480); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1, a3); + auto a5 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a5, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 480); + CHECK(no_allocate(m)); } TEST_CASE(test35) { - migraphx::program p; - auto a1 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {8}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1, a3); - auto a5 = add_alloc(p, {migraphx::shape::float_type, {8}}); - p.add_instruction(pass_op{}, a5, p1); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 224); - CHECK(no_allocate(p)); + migraphx::module m; + + auto a1 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {8}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1, a3); + auto a5 = add_alloc(m, {migraphx::shape::float_type, {8}}); + m.add_instruction(pass_op{}, a5, m1); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 224); + CHECK(no_allocate(m)); } TEST_CASE(test36) { - migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {20}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a3, p1); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p3 = p.add_instruction(pass_op{}, a4, p2); - p.add_instruction(pass_op{}, output, p3); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 320); - CHECK(no_allocate(p)); + migraphx::module m; + + auto output = m.add_parameter("output", {migraphx::shape::float_type, {20}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a3, m1); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto p3 = m.add_instruction(pass_op{}, a4, m2); + m.add_instruction(pass_op{}, output, p3); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 320); + CHECK(no_allocate(m)); } TEST_CASE(test37) { - migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {20}}); - auto a1 = add_alloc(p, {migraphx::shape::float_type, {4}}); - auto a2 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p1 = p.add_instruction(pass_op{}, a2, a1); - auto a3 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p2 = p.add_instruction(pass_op{}, a3, p1); - auto a4 = add_alloc(p, {migraphx::shape::float_type, {40}}); - auto p3 = p.add_instruction(pass_op{}, a4, p2); - p.add_instruction(pass_op{}, output, p3); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 320); - CHECK(no_allocate(p)); + migraphx::module m; + + auto output = m.add_parameter("output", {migraphx::shape::float_type, {20}}); + auto a1 = add_alloc(m, {migraphx::shape::float_type, {4}}); + auto a2 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m1 = m.add_instruction(pass_op{}, a2, a1); + auto a3 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto m2 = m.add_instruction(pass_op{}, a3, m1); + auto a4 = add_alloc(m, {migraphx::shape::float_type, {40}}); + auto p3 = m.add_instruction(pass_op{}, a4, m2); + m.add_instruction(pass_op{}, output, p3); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 320); + CHECK(no_allocate(m)); } TEST_CASE(test38) +{ + migraphx::module m; + + auto output = m.add_parameter("output", {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto m29 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto p30 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 112, 112}}); + auto p31 = m.add_instruction(pass_op{}, p30, m29); + auto p32 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 112, 112}}); + auto p37 = m.add_instruction(pass_op{}, p32, p31); + auto p38 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 112, 112}}); + auto p39 = m.add_instruction(pass_op{}, p38, p37); + auto p40 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p41 = m.add_instruction(pass_op{}, p40, p39); + auto p42 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto p43 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p44 = m.add_instruction(pass_op{}, p43, p41, p42); + auto p45 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p50 = m.add_instruction(pass_op{}, p45, p44); + auto p51 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p52 = m.add_instruction(pass_op{}, p51, p50); + auto p53 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto p54 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p55 = m.add_instruction(pass_op{}, p54, p52, p53); + auto p56 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p61 = m.add_instruction(pass_op{}, p56, p55); + auto p62 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p63 = m.add_instruction(pass_op{}, p62, p61, p41); + auto p64 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto p65 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p66 = m.add_instruction(pass_op{}, p65, p63, p64); + auto p67 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p72 = m.add_instruction(pass_op{}, p67, p66); + auto p73 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p74 = m.add_instruction(pass_op{}, p73, p72); + auto p75 = add_alloc(m, {migraphx::shape::float_type, {0}}); + auto p76 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p77 = m.add_instruction(pass_op{}, p76, p74, p75); + auto p78 = add_alloc(m, {migraphx::shape::float_type, {1, 64, 56, 56}}); + auto p83 = m.add_instruction(pass_op{}, p78, p77); + m.add_instruction(pass_op{}, output, p83, p63); + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528 + CHECK(no_allocate(m)); +} + +TEST_CASE(test39) { migraphx::program p; - auto output = p.add_parameter("output", {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p29 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p30 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}}); - auto p31 = p.add_instruction(pass_op{}, p30, p29); - auto p32 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}}); - auto p37 = p.add_instruction(pass_op{}, p32, p31); - auto p38 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 112, 112}}); - auto p39 = p.add_instruction(pass_op{}, p38, p37); - auto p40 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p41 = p.add_instruction(pass_op{}, p40, p39); - auto p42 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p43 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p44 = p.add_instruction(pass_op{}, p43, p41, p42); - auto p45 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p50 = p.add_instruction(pass_op{}, p45, p44); - auto p51 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p52 = p.add_instruction(pass_op{}, p51, p50); - auto p53 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p54 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p55 = p.add_instruction(pass_op{}, p54, p52, p53); - auto p56 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p61 = p.add_instruction(pass_op{}, p56, p55); - auto p62 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p63 = p.add_instruction(pass_op{}, p62, p61, p41); - auto p64 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p65 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p66 = p.add_instruction(pass_op{}, p65, p63, p64); - auto p67 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p72 = p.add_instruction(pass_op{}, p67, p66); - auto p73 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p74 = p.add_instruction(pass_op{}, p73, p72); - auto p75 = add_alloc(p, {migraphx::shape::float_type, {0}}); - auto p76 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p77 = p.add_instruction(pass_op{}, p76, p74, p75); - auto p78 = add_alloc(p, {migraphx::shape::float_type, {1, 64, 56, 56}}); - auto p83 = p.add_instruction(pass_op{}, p78, p77); - p.add_instruction(pass_op{}, output, p83, p63); - run_pass(p); - CHECK(p.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528 - CHECK(no_allocate(p)); + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = add_alloc(*mm, cond_s); + auto output = mm->add_parameter("output", {migraphx::shape::float_type, {20}}); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; + std::vector data1 = {0.384804, -1.77948, -0.453775, 0.477438, -1.06333, -1.12893}; + auto l1 = mm->add_literal(migraphx::literal(ds, data1)); + std::vector data2 = {-0.258047, 0.360394, 0.536804, -0.577762, 1.0217, 1.02442}; + auto l2 = mm->add_literal(migraphx::literal(ds, data2)); + + auto* then_mod = p.create_module("If_0_if"); + auto i1 = add_alloc(*then_mod, ds); + auto a1 = then_mod->add_instruction(pass_op{}, i1, l1); + then_mod->add_return({a1, output}); + + auto* else_mod = p.create_module("If_0_else"); + auto i2 = add_alloc(*else_mod, ds); + auto a2 = else_mod->add_instruction(pass_op{}, i2, l2); + else_mod->add_return({a2, output}); + + auto ret = mm->add_instruction(mod_pass_op{}, {cond}, {then_mod, else_mod}); + mm->add_return({ret, output}); + + auto sub_modules = p.get_modules(); + std::reverse(sub_modules.begin(), sub_modules.end()); + for(auto& smod : sub_modules) + { + run_pass(*smod); + } + + CHECK(mm->get_parameter_shape("scratch").bytes() == 4); + CHECK(then_mod->get_parameter_shape("scratch").bytes() == 24); + CHECK(else_mod->get_parameter_shape("scratch").bytes() == 24); + CHECK(no_allocate(*mm)); + CHECK(no_allocate(*then_mod)); + CHECK(no_allocate(*else_mod)); +} + +// NOLINTNEXTLINE +TEST_CASE(rnn_dom) +{ + migraphx::module m; + + auto mx0 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}}); + auto mx1 = m.add_instruction(pass_op{}); + auto mr = m.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {1, 15, 5}}); + auto mx2 = m.add_instruction(pass_op{}, mr); + auto mx3 = m.add_instruction(pass_op{}, mx2); + auto mx4 = m.add_instruction(pass_op{}, mx3); + m.add_instruction(pass_op{}); + auto mx6 = m.add_instruction(pass_op{}, mx0, mx1, mx4); + m.add_instruction(pass_op{}); + auto mx8 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}}); + m.add_instruction(pass_op{}, mx8, mx1, mx0, mx6); + auto mseq = m.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {3, 2, 8}}); + auto mx10 = m.add_instruction(pass_op{}, mseq); + auto mx11 = m.add_instruction(pass_op{}, mx10); + auto mw = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 15, 8}}); + auto mx12 = m.add_instruction(pass_op{}, mw); + auto mx13 = m.add_instruction(pass_op{}, mx12); + m.add_instruction(pass_op{}); + auto mx15 = m.add_instruction(pass_op{}, mx8, mx11, mx13); + m.add_instruction(pass_op{}, mx15, mx1, mx0, mx6); + m.add_instruction(pass_op{}); + auto mx18 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, mx18, mx6, mx15, mx0, mx1, mx8); + auto mx20 = m.add_instruction(pass_op{}, mx6); + m.add_instruction(pass_op{}, mx20, mx8, mx15, mx18); + auto mx22 = m.add_instruction(pass_op{}, mx15); + m.add_instruction(pass_op{}, mx22, mx1, mx0, mx20, mx6, mx18); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx27 = m.add_instruction(pass_op{}, mx18, mx22, mx20); + m.add_instruction(pass_op{}, mx27, mx15, mx8, mx6, mx20, mx1, mx22, mx0); + m.add_instruction(pass_op{}); + auto mx30 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, mx30, mx20, mx22, mx1, mx15, mx8, mx6, mx27, mx0, mx18); + auto mx32 = m.add_instruction(pass_op{}, mx15); + m.add_instruction(pass_op{}, mx32, mx20, mx30, mx0, mx18, mx1, mx27, mx6); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx36 = m.add_instruction(pass_op{}, mx30, mx32); + m.add_instruction(pass_op{}, mx36, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18); + auto mx38 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, mx38, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18); + auto mx40 = m.add_instruction(pass_op{}, mx38, mx36); + m.add_instruction(pass_op{}, mx40, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18); + m.add_instruction(pass_op{}); + auto mx43 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, mx43, mx15, mx32, mx27, mx30, mx18, mx8, mx40, mx36, mx22, mx38); + auto mx45 = m.add_instruction(pass_op{}, mx6); + m.add_instruction(pass_op{}, mx45, mx32, mx27, mx30, mx18, mx40, mx36, mx22, mx8, mx15, mx38); + auto mx47 = m.add_instruction(pass_op{}, mx15); + m.add_instruction( + pass_op{}, mx47, mx30, mx18, mx43, mx6, mx1, mx45, mx0, mx27, mx36, mx20, mx40, mx38); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx51 = m.add_instruction(pass_op{}, mx43, mx47, mx45); + m.add_instruction( + pass_op{}, mx51, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38); + auto mx53 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction( + pass_op{}, mx53, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38); + auto mx55 = m.add_instruction(pass_op{}, mx53, mx51, mx1); + m.add_instruction( + pass_op{}, mx55, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38); + auto mx57 = m.add_instruction(pass_op{}, mx3); + m.add_instruction(pass_op{}); + auto mx59 = m.add_instruction(pass_op{}, mx40, mx55, mx57, mx40); + m.add_instruction( + pass_op{}, mx59, mx15, mx8, mx38, mx18, mx30, mx27, mx47, mx32, mx40, mx36, mx22); + auto mx61 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx61, + mx30, + mx15, + mx1, + mx51, + mx20, + mx59, + mx32, + mx45, + mx22, + mx8, + mx47, + mx40, + mx53, + mx6, + mx55, + mx0, + mx43, + mx38, + mx36); + m.add_instruction(pass_op{}); + auto mx64 = m.add_instruction(pass_op{}, mx61, mx27, mx1); + m.add_instruction(pass_op{}, + mx64, + mx30, + mx15, + mx1, + mx51, + mx20, + mx59, + mx32, + mx45, + mx22, + mx8, + mx47, + mx40, + mx53, + mx6, + mx55, + mx0, + mx43, + mx38, + mx36); + m.add_instruction(pass_op{}); + auto mx67 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx67, + mx18, + mx6, + mx1, + mx51, + mx20, + mx59, + mx27, + mx55, + mx43, + mx38, + mx0, + mx61, + mx45, + mx36, + mx40, + mx53, + mx64, + mx30); + auto mx69 = m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}, + mx69, + mx18, + mx6, + mx1, + mx51, + mx20, + mx59, + mx27, + mx55, + mx43, + mx38, + mx0, + mx61, + mx45, + mx36, + mx40, + mx53, + mx64, + mx30); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx73 = m.add_instruction(pass_op{}, mx67, mx69, mx27); + m.add_instruction(pass_op{}, + mx73, + mx18, + mx6, + mx1, + mx51, + mx20, + mx59, + mx27, + mx55, + mx43, + mx38, + mx0, + mx61, + mx45, + mx36, + mx40, + mx53, + mx64, + mx30); + m.add_instruction(pass_op{}); + auto mx76 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx76, + mx64, + mx30, + mx18, + mx40, + mx8, + mx61, + mx38, + mx69, + mx67, + mx73, + mx27, + mx47, + mx32, + mx36, + mx15, + mx22); + m.add_instruction(pass_op{}); + auto mx79 = m.add_instruction(pass_op{}, mx76, mx59); + m.add_instruction(pass_op{}, + mx79, + mx64, + mx30, + mx18, + mx40, + mx8, + mx61, + mx38, + mx69, + mx67, + mx73, + mx27, + mx47, + mx32, + mx36, + mx15, + mx22); + auto mx81 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx81, + mx36, + mx32, + mx27, + mx47, + mx18, + mx30, + mx73, + mx67, + mx22, + mx15, + mx61, + mx8, + mx64, + mx40, + mx69, + mx38); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx85 = m.add_instruction(pass_op{}, mx81, mx73, mx79, mx64); + m.add_instruction(pass_op{}, + mx85, + mx36, + mx32, + mx27, + mx47, + mx18, + mx30, + mx73, + mx67, + mx22, + mx15, + mx61, + mx8, + mx64, + mx40, + mx69, + mx38); + m.add_instruction(pass_op{}); + auto mx88 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}}); + m.add_instruction(pass_op{}, + mx88, + mx36, + mx32, + mx27, + mx47, + mx18, + mx30, + mx73, + mx67, + mx22, + mx15, + mx61, + mx8, + mx64, + mx40, + mx69, + mx38); + auto mx90 = m.add_instruction(pass_op{}, mx88, mx85, mx4); + m.add_instruction(pass_op{}, + mx90, + mx36, + mx32, + mx27, + mx47, + mx18, + mx30, + mx73, + mx67, + mx22, + mx15, + mx61, + mx8, + mx64, + mx40, + mx69, + mx38); + m.add_instruction(pass_op{}); + auto mx93 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}}); + m.add_instruction(pass_op{}, + mx93, + mx51, + mx88, + mx20, + mx64, + mx43, + mx61, + mx53, + mx81, + mx47, + mx6, + mx45, + mx0, + mx55, + mx18, + mx76, + mx1, + mx79, + mx85, + mx90, + mx8, + mx69, + mx67, + mx73, + mx32, + mx59, + mx22, + mx15, + mx27); + auto mx95 = m.add_instruction(pass_op{}, mseq); + auto mx96 = m.add_instruction(pass_op{}, mx95); + m.add_instruction(pass_op{}); + auto mx98 = m.add_instruction(pass_op{}, mx93, mx96, mx13); + m.add_instruction(pass_op{}, + mx98, + mx51, + mx88, + mx20, + mx64, + mx43, + mx61, + mx53, + mx81, + mx47, + mx6, + mx45, + mx0, + mx55, + mx18, + mx76, + mx1, + mx79, + mx85, + mx90, + mx8, + mx69, + mx67, + mx73, + mx32, + mx59, + mx22, + mx15, + mx27); + m.add_instruction(pass_op{}); + auto mx101 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx101, + mx43, + mx40, + mx53, + mx59, + mx51, + mx6, + mx61, + mx81, + mx38, + mx45, + mx20, + mx0, + mx76, + mx55, + mx18, + mx85, + mx1, + mx93, + mx79, + mx90, + mx27, + mx88, + mx64, + mx30, + mx98, + mx36); + auto mx103 = m.add_instruction(pass_op{}, mx90); + m.add_instruction(pass_op{}, + mx103, + mx64, + mx101, + mx15, + mx67, + mx73, + mx18, + mx40, + mx8, + mx47, + mx98, + mx27, + mx32, + mx61, + mx22, + mx93, + mx69, + mx36, + mx38, + mx30); + auto mx105 = m.add_instruction(pass_op{}, mx98); + m.add_instruction(pass_op{}, + mx105, + mx43, + mx88, + mx53, + mx64, + mx59, + mx6, + mx76, + mx61, + mx81, + mx47, + mx103, + mx22, + mx45, + mx0, + mx55, + mx18, + mx85, + mx51, + mx20, + mx1, + mx79, + mx90, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx32, + mx27); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx110 = m.add_instruction(pass_op{}, mx101, mx105, mx103); + m.add_instruction(pass_op{}, + mx110, + mx88, + mx40, + mx93, + mx59, + mx43, + mx61, + mx53, + mx81, + mx103, + mx6, + mx45, + mx0, + mx55, + mx18, + mx64, + mx20, + mx76, + mx1, + mx79, + mx38, + mx85, + mx90, + mx27, + mx30, + mx105, + mx98, + mx51, + mx36); + m.add_instruction(pass_op{}); + auto mx113 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx113, + mx59, + mx20, + mx51, + mx1, + mx79, + mx90, + mx55, + mx85, + mx76, + mx81, + mx47, + mx6, + mx38, + mx88, + mx43, + mx40, + mx0, + mx45, + mx53, + mx93, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx32, + mx110, + mx22, + mx103, + mx30, + mx36, + mx98, + mx105); + auto mx115 = m.add_instruction(pass_op{}, mx98); + m.add_instruction(pass_op{}, + mx115, + mx59, + mx20, + mx51, + mx1, + mx79, + mx90, + mx55, + mx18, + mx85, + mx76, + mx61, + mx81, + mx47, + mx6, + mx88, + mx43, + mx0, + mx45, + mx53, + mx64, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx113, + mx32, + mx110, + mx22, + mx103, + mx27); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx119 = m.add_instruction(pass_op{}, mx113, mx115); + m.add_instruction(pass_op{}, + mx119, + mx59, + mx20, + mx51, + mx1, + mx79, + mx90, + mx55, + mx85, + mx76, + mx81, + mx47, + mx6, + mx38, + mx88, + mx43, + mx40, + mx0, + mx45, + mx53, + mx93, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx32, + mx110, + mx22, + mx103, + mx30, + mx36, + mx115, + mx98, + mx105); + auto mx121 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx121, + mx59, + mx20, + mx51, + mx1, + mx79, + mx90, + mx55, + mx85, + mx76, + mx81, + mx47, + mx6, + mx38, + mx88, + mx43, + mx40, + mx0, + mx45, + mx53, + mx93, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx32, + mx110, + mx22, + mx103, + mx30, + mx36, + mx115, + mx98, + mx105); + auto mx123 = m.add_instruction(pass_op{}, mx121, mx119); + m.add_instruction(pass_op{}, + mx123, + mx59, + mx20, + mx51, + mx1, + mx79, + mx90, + mx55, + mx85, + mx76, + mx81, + mx47, + mx6, + mx38, + mx88, + mx43, + mx40, + mx0, + mx45, + mx53, + mx93, + mx8, + mx101, + mx15, + mx69, + mx67, + mx73, + mx32, + mx110, + mx22, + mx103, + mx30, + mx36, + mx115, + mx98, + mx105); + m.add_instruction(pass_op{}); + auto mx126 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx126, + mx115, + mx113, + mx8, + mx67, + mx61, + mx73, + mx18, + mx123, + mx119, + mx32, + mx15, + mx36, + mx110, + mx27, + mx101, + mx22, + mx98, + mx47, + mx40, + mx93, + mx38, + mx69, + mx121, + mx64, + mx30, + mx105); + auto mx128 = m.add_instruction(pass_op{}, mx90); + m.add_instruction(pass_op{}, + mx128, + mx93, + mx98, + mx8, + mx67, + mx73, + mx18, + mx123, + mx61, + mx40, + mx47, + mx27, + mx32, + mx101, + mx22, + mx15, + mx110, + mx36, + mx119, + mx38, + mx64, + mx30, + mx69, + mx121, + mx113, + mx115, + mx105); + auto mx130 = m.add_instruction(pass_op{}, mx98); + m.add_instruction(pass_op{}, + mx130, + mx119, + mx64, + mx22, + mx110, + mx126, + mx128, + mx121, + mx113, + mx67, + mx90, + mx69, + mx15, + mx20, + mx8, + mx27, + mx51, + mx85, + mx79, + mx123, + mx103, + mx18, + mx55, + mx32, + mx0, + mx45, + mx61, + mx53, + mx76, + mx6, + mx47, + mx59, + mx73, + mx81, + mx88, + mx1, + mx43, + mx101); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx134 = m.add_instruction(pass_op{}, mx126, mx130, mx128); + m.add_instruction(pass_op{}, + mx134, + mx130, + mx8, + mx67, + mx61, + mx73, + mx18, + mx123, + mx119, + mx32, + mx15, + mx36, + mx110, + mx27, + mx101, + mx22, + mx113, + mx115, + mx98, + mx47, + mx40, + mx93, + mx38, + mx69, + mx121, + mx64, + mx30, + mx105); + auto mx136 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx136, + mx130, + mx8, + mx67, + mx61, + mx73, + mx18, + mx123, + mx119, + mx32, + mx15, + mx36, + mx110, + mx27, + mx101, + mx22, + mx113, + mx115, + mx98, + mx47, + mx40, + mx93, + mx38, + mx69, + mx121, + mx64, + mx30, + mx105); + auto mx138 = m.add_instruction(pass_op{}, mx136, mx134, mx85); + m.add_instruction(pass_op{}, + mx138, + mx130, + mx8, + mx67, + mx61, + mx73, + mx18, + mx123, + mx119, + mx32, + mx15, + mx36, + mx110, + mx27, + mx101, + mx22, + mx113, + mx115, + mx98, + mx47, + mx40, + mx93, + mx38, + mx69, + mx121, + mx64, + mx30, + mx105); + m.add_instruction(pass_op{}); + auto mx141 = m.add_instruction(pass_op{}, mx123, mx138, mx57, mx123); + m.add_instruction(pass_op{}, + mx141, + mx113, + mx115, + mx130, + mx105, + mx38, + mx93, + mx61, + mx98, + mx27, + mx64, + mx30, + mx119, + mx121, + mx69, + mx8, + mx67, + mx40, + mx47, + mx32, + mx101, + mx22, + mx36, + mx110, + mx15, + mx73, + mx18, + mx123); + auto mx143 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx143, + mx8, + mx73, + mx121, + mx67, + mx101, + mx110, + mx69, + mx15, + mx138, + mx88, + mx43, + mx79, + mx53, + mx61, + mx45, + mx18, + mx0, + mx6, + mx27, + mx22, + mx134, + mx32, + mx1, + mx119, + mx59, + mx85, + mx103, + mx126, + mx64, + mx128, + mx55, + mx76, + mx47, + mx81, + mx90, + mx136, + mx51, + mx141, + mx20, + mx113, + mx123); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx147 = m.add_instruction(pass_op{}, mx143, mx69, mx110); + m.add_instruction(pass_op{}, + mx147, + mx8, + mx73, + mx121, + mx67, + mx101, + mx110, + mx69, + mx15, + mx138, + mx88, + mx43, + mx79, + mx53, + mx61, + mx45, + mx18, + mx0, + mx6, + mx27, + mx22, + mx134, + mx32, + mx1, + mx119, + mx59, + mx85, + mx103, + mx126, + mx64, + mx128, + mx55, + mx76, + mx47, + mx81, + mx90, + mx136, + mx51, + mx141, + mx20, + mx113, + mx123); + m.add_instruction(pass_op{}); + auto mx150 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx150, + mx30, + mx121, + mx115, + mx98, + mx130, + mx85, + mx88, + mx90, + mx79, + mx1, + mx93, + mx64, + mx18, + mx53, + mx61, + mx38, + mx27, + mx147, + mx0, + mx6, + mx51, + mx40, + mx134, + mx43, + mx119, + mx59, + mx45, + mx76, + mx128, + mx81, + mx136, + mx55, + mx138, + mx123, + mx126, + mx141, + mx103, + mx20, + mx105, + mx113, + mx143, + mx36); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx154 = m.add_instruction(pass_op{}, mx150, mx110, mx85); + m.add_instruction(pass_op{}, + mx154, + mx30, + mx121, + mx115, + mx98, + mx130, + mx85, + mx88, + mx90, + mx79, + mx1, + mx93, + mx64, + mx18, + mx53, + mx61, + mx38, + mx27, + mx147, + mx0, + mx6, + mx51, + mx40, + mx134, + mx43, + mx119, + mx59, + mx45, + mx76, + mx128, + mx81, + mx136, + mx55, + mx138, + mx123, + mx126, + mx141, + mx103, + mx20, + mx105, + mx113, + mx143, + mx36); + m.add_instruction(pass_op{}); + auto mx157 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx157, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx61, + mx98, + mx40, + mx27, + mx121, + mx30, + mx154, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx15, + mx18, + mx123, + mx143); + m.add_instruction(pass_op{}); + auto mx160 = m.add_instruction(pass_op{}, mx157, mx141); + m.add_instruction(pass_op{}, + mx160, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx61, + mx98, + mx40, + mx27, + mx121, + mx30, + mx154, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx15, + mx18, + mx123, + mx143); + auto mx162 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx162, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx61, + mx98, + mx40, + mx27, + mx121, + mx30, + mx154, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx15, + mx18, + mx123, + mx143); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx166 = m.add_instruction(pass_op{}, mx162, mx147, mx160, mx154); + m.add_instruction(pass_op{}, + mx166, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx61, + mx98, + mx40, + mx27, + mx121, + mx30, + mx154, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx15, + mx18, + mx123, + mx143); + m.add_instruction(pass_op{}); + auto mx169 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}}); + m.add_instruction(pass_op{}, + mx169, + mx154, + mx90, + mx88, + mx79, + mx126, + mx15, + mx103, + mx22, + mx134, + mx166, + mx30, + mx73, + mx20, + mx128, + mx160, + mx8, + mx45, + mx0, + mx6, + mx157, + mx53, + mx136, + mx93, + mx47, + mx81, + mx141, + mx85, + mx110, + mx59, + mx1, + mx162, + mx101, + mx36, + mx38, + mx76, + mx143, + mx67, + mx147, + mx150, + mx138, + mx115, + mx105, + mx51, + mx69, + mx40, + mx32, + mx43, + mx55, + mx130, + mx98); + auto mx171 = m.add_instruction(pass_op{}, mseq); + auto mx172 = m.add_instruction(pass_op{}, mx171); + m.add_instruction(pass_op{}); + auto mx174 = m.add_instruction(pass_op{}, mx169, mx172, mx13); + m.add_instruction(pass_op{}, + mx174, + mx154, + mx90, + mx88, + mx79, + mx126, + mx15, + mx103, + mx22, + mx134, + mx166, + mx30, + mx73, + mx20, + mx128, + mx160, + mx8, + mx45, + mx0, + mx6, + mx157, + mx53, + mx136, + mx93, + mx47, + mx81, + mx141, + mx85, + mx110, + mx59, + mx1, + mx162, + mx101, + mx36, + mx38, + mx76, + mx143, + mx67, + mx147, + mx150, + mx138, + mx115, + mx105, + mx51, + mx69, + mx40, + mx32, + mx43, + mx55, + mx130, + mx98); + m.add_instruction(pass_op{}); + auto mx177 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}}); + m.add_instruction(pass_op{}, + mx177, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx154, + mx61, + mx98, + mx40, + mx27, + mx174, + mx121, + mx30, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx169, + mx15, + mx18, + mx123, + mx143); + m.add_instruction(pass_op{}); + auto mx180 = m.add_instruction(pass_op{}, mx177, mx166, mx4); + m.add_instruction(pass_op{}, + mx180, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx154, + mx61, + mx98, + mx40, + mx27, + mx174, + mx121, + mx30, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx169, + mx15, + mx18, + mx123, + mx143); + m.add_instruction(pass_op{}); + auto mx183 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx183, + mx67, + mx90, + mx150, + mx138, + mx88, + mx79, + mx126, + mx15, + mx103, + mx22, + mx134, + mx180, + mx166, + mx174, + mx73, + mx20, + mx154, + mx32, + mx43, + mx55, + mx157, + mx18, + mx0, + mx113, + mx6, + mx76, + mx53, + mx61, + mx177, + mx136, + mx81, + mx141, + mx85, + mx110, + mx64, + mx45, + mx8, + mx169, + mx59, + mx1, + mx162, + mx101, + mx119, + mx51, + mx69, + mx128, + mx160, + mx27, + mx47, + mx123, + mx121); + auto mx185 = m.add_instruction(pass_op{}, mx180); + m.add_instruction(pass_op{}, + mx185, + mx101, + mx8, + mx115, + mx130, + mx105, + mx38, + mx147, + mx93, + mx64, + mx154, + mx61, + mx98, + mx40, + mx27, + mx183, + mx174, + mx121, + mx30, + mx113, + mx73, + mx119, + mx36, + mx150, + mx69, + mx67, + mx47, + mx110, + mx32, + mx22, + mx169, + mx15, + mx18, + mx123, + mx143); + auto mx187 = m.add_instruction(pass_op{}, mx174); + m.add_instruction(pass_op{}, + mx187, + mx150, + mx128, + mx67, + mx15, + mx88, + mx43, + mx79, + mx126, + mx103, + mx22, + mx90, + mx180, + mx183, + mx166, + mx141, + mx30, + mx20, + mx59, + mx55, + mx38, + mx160, + mx0, + mx32, + mx85, + mx6, + mx76, + mx157, + mx45, + mx162, + mx138, + mx154, + mx53, + mx177, + mx136, + mx51, + mx47, + mx81, + mx93, + mx73, + mx8, + mx110, + mx101, + mx69, + mx185, + mx36, + mx143, + mx147, + mx134, + mx1, + mx130, + mx115, + mx105, + mx40, + mx98); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx192 = m.add_instruction(pass_op{}, mx183, mx187, mx185); + m.add_instruction(pass_op{}, + mx192, + mx150, + mx128, + mx67, + mx187, + mx15, + mx88, + mx43, + mx79, + mx126, + mx103, + mx64, + mx22, + mx90, + mx180, + mx141, + mx20, + mx59, + mx134, + mx1, + mx55, + mx113, + mx160, + mx0, + mx32, + mx85, + mx6, + mx76, + mx157, + mx45, + mx162, + mx138, + mx154, + mx53, + mx61, + mx177, + mx174, + mx136, + mx119, + mx185, + mx51, + mx47, + mx81, + mx73, + mx8, + mx110, + mx18, + mx169, + mx101, + mx69, + mx27, + mx123, + mx166, + mx121); + m.add_instruction(pass_op{}); + auto mx195 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx195, + mx115, + mx105, + mx98, + mx123, + mx27, + mx126, + mx103, + mx64, + mx183, + mx174, + mx136, + mx177, + mx141, + mx51, + mx93, + mx113, + mx38, + mx160, + mx55, + mx30, + mx61, + mx138, + mx53, + mx76, + mx85, + mx6, + mx20, + mx59, + mx0, + mx40, + mx43, + mx88, + mx79, + mx180, + mx90, + mx187, + mx81, + mx128, + mx157, + mx45, + mx162, + mx134, + mx1, + mx130, + mx147, + mx166, + mx121, + mx18, + mx169, + mx143, + mx119, + mx36, + mx185, + mx192); + auto mx197 = m.add_instruction(pass_op{}, mx174); + m.add_instruction(pass_op{}, + mx197, + mx128, + mx150, + mx101, + mx69, + mx126, + mx103, + mx22, + mx166, + mx183, + mx136, + mx177, + mx141, + mx30, + mx73, + mx93, + mx38, + mx160, + mx55, + mx76, + mx32, + mx85, + mx6, + mx20, + mx59, + mx0, + mx43, + mx15, + mx88, + mx79, + mx180, + mx90, + mx67, + mx81, + mx138, + mx154, + mx53, + mx157, + mx45, + mx162, + mx51, + mx47, + mx195, + mx110, + mx8, + mx143, + mx147, + mx134, + mx1, + mx130, + mx115, + mx105, + mx40, + mx98, + mx36, + mx185, + mx192); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx201 = m.add_instruction(pass_op{}, mx195, mx197); + m.add_instruction(pass_op{}, + mx201, + mx115, + mx105, + mx98, + mx123, + mx27, + mx126, + mx103, + mx64, + mx183, + mx174, + mx136, + mx177, + mx141, + mx51, + mx93, + mx113, + mx38, + mx160, + mx55, + mx30, + mx61, + mx138, + mx53, + mx76, + mx85, + mx6, + mx20, + mx59, + mx0, + mx40, + mx43, + mx197, + mx88, + mx79, + mx180, + mx90, + mx187, + mx81, + mx128, + mx157, + mx45, + mx162, + mx134, + mx1, + mx130, + mx147, + mx166, + mx121, + mx18, + mx169, + mx143, + mx119, + mx36, + mx185, + mx192); + auto mx203 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx203, + mx115, + mx105, + mx98, + mx123, + mx27, + mx126, + mx103, + mx64, + mx183, + mx174, + mx136, + mx177, + mx141, + mx51, + mx93, + mx113, + mx38, + mx160, + mx55, + mx30, + mx61, + mx138, + mx53, + mx76, + mx85, + mx6, + mx20, + mx59, + mx0, + mx40, + mx43, + mx197, + mx88, + mx79, + mx180, + mx90, + mx187, + mx81, + mx128, + mx157, + mx45, + mx162, + mx134, + mx1, + mx130, + mx147, + mx166, + mx121, + mx18, + mx169, + mx143, + mx119, + mx36, + mx185, + mx192); + auto mx205 = m.add_instruction(pass_op{}, mx203, mx201); + m.add_instruction(pass_op{}, + mx205, + mx115, + mx105, + mx98, + mx123, + mx27, + mx126, + mx103, + mx64, + mx183, + mx174, + mx136, + mx177, + mx141, + mx51, + mx93, + mx113, + mx38, + mx160, + mx55, + mx30, + mx61, + mx138, + mx53, + mx76, + mx85, + mx6, + mx20, + mx59, + mx0, + mx40, + mx43, + mx197, + mx88, + mx79, + mx180, + mx90, + mx187, + mx81, + mx128, + mx157, + mx45, + mx162, + mx134, + mx1, + mx130, + mx147, + mx166, + mx121, + mx18, + mx169, + mx143, + mx119, + mx36, + mx185, + mx192); + m.add_instruction(pass_op{}); + auto mx208 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx208, + mx30, + mx40, + mx64, + mx93, + mx18, + mx98, + mx115, + mx143, + mx38, + mx147, + mx183, + mx197, + mx150, + mx119, + mx32, + mx8, + mx105, + mx101, + mx110, + mx195, + mx47, + mx27, + mx22, + mx205, + mx121, + mx67, + mx187, + mx113, + mx73, + mx201, + mx130, + mx203, + mx169, + mx69, + mx15, + mx154, + mx61, + mx174, + mx123, + mx36, + mx192); + auto mx210 = m.add_instruction(pass_op{}, mx180); + m.add_instruction(pass_op{}, + mx210, + mx143, + mx115, + mx18, + mx93, + mx150, + mx47, + mx187, + mx15, + mx169, + mx69, + mx205, + mx32, + mx119, + mx113, + mx73, + mx201, + mx30, + mx67, + mx121, + mx22, + mx27, + mx40, + mx98, + mx174, + mx61, + mx154, + mx64, + mx147, + mx38, + mx203, + mx130, + mx8, + mx110, + mx105, + mx101, + mx195, + mx183, + mx197, + mx123, + mx36, + mx192); + auto mx212 = m.add_instruction(pass_op{}, mx174); + m.add_instruction(pass_op{}, + mx212, + mx32, + mx67, + mx90, + mx15, + mx138, + mx126, + mx103, + mx38, + mx136, + mx180, + mx141, + mx51, + mx30, + mx22, + mx201, + mx59, + mx134, + mx154, + mx150, + mx1, + mx160, + mx45, + mx6, + mx76, + mx88, + mx53, + mx47, + mx183, + mx81, + mx157, + mx93, + mx79, + mx85, + mx0, + mx210, + mx73, + mx8, + mx110, + mx20, + mx69, + mx177, + mx36, + mx143, + mx162, + mx147, + mx130, + mx115, + mx55, + mx105, + mx40, + mx98, + mx208, + mx203, + mx128, + mx205, + mx195, + mx101, + mx185, + mx43, + mx166, + mx192); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx216 = m.add_instruction(pass_op{}, mx208, mx212, mx210); + m.add_instruction(pass_op{}, + mx216, + mx121, + mx30, + mx64, + mx93, + mx123, + mx143, + mx119, + mx36, + mx150, + mx8, + mx101, + mx169, + mx147, + mx110, + mx27, + mx61, + mx40, + mx205, + mx115, + mx32, + mx69, + mx67, + mx98, + mx187, + mx195, + mx73, + mx105, + mx183, + mx197, + mx22, + mx113, + mx201, + mx47, + mx130, + mx154, + mx15, + mx212, + mx18, + mx174, + mx38, + mx203, + mx192); + auto mx218 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx218, + mx121, + mx30, + mx64, + mx93, + mx123, + mx143, + mx119, + mx36, + mx150, + mx8, + mx101, + mx169, + mx147, + mx110, + mx27, + mx61, + mx40, + mx205, + mx115, + mx32, + mx69, + mx67, + mx98, + mx187, + mx195, + mx73, + mx105, + mx183, + mx197, + mx22, + mx113, + mx201, + mx47, + mx130, + mx154, + mx15, + mx212, + mx18, + mx174, + mx38, + mx203, + mx192); + auto mx220 = m.add_instruction(pass_op{}, mx218, mx216, mx166); + m.add_instruction(pass_op{}, + mx220, + mx121, + mx30, + mx64, + mx93, + mx123, + mx143, + mx119, + mx36, + mx150, + mx8, + mx101, + mx169, + mx147, + mx110, + mx27, + mx61, + mx40, + mx205, + mx115, + mx32, + mx69, + mx67, + mx98, + mx187, + mx195, + mx73, + mx105, + mx183, + mx197, + mx22, + mx113, + mx201, + mx47, + mx130, + mx154, + mx15, + mx212, + mx18, + mx174, + mx38, + mx203, + mx192); + m.add_instruction(pass_op{}); + auto mx223 = m.add_instruction(pass_op{}, mx205, mx220, mx57, mx205); + m.add_instruction(pass_op{}, + mx223, + mx38, + mx192, + mx203, + mx130, + mx47, + mx143, + mx123, + mx169, + mx121, + mx147, + mx110, + mx27, + mx36, + mx150, + mx119, + mx101, + mx8, + mx64, + mx61, + mx115, + mx32, + mx69, + mx67, + mx98, + mx187, + mx195, + mx73, + mx105, + mx183, + mx197, + mx22, + mx113, + mx201, + mx174, + mx18, + mx93, + mx205, + mx40, + mx30, + mx154, + mx15, + mx212); + auto mx225 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx225, + mx45, + mx59, + mx76, + mx90, + mx218, + mx67, + mx126, + mx103, + mx136, + mx138, + mx15, + mx32, + mx1, + mx160, + mx150, + mx110, + mx51, + mx30, + mx6, + mx157, + mx93, + mx79, + mx85, + mx88, + mx53, + mx154, + mx134, + mx141, + mx180, + mx38, + mx81, + mx223, + mx183, + mx220, + mx210, + mx0, + mx208, + mx20, + mx69, + mx73, + mx185, + mx101, + mx201, + mx22, + mx203, + mx47, + mx128, + mx205, + mx195, + mx8, + mx177, + mx36, + mx55, + mx216, + mx105, + mx115, + mx130, + mx40, + mx98, + mx43, + mx166, + mx192, + mx162, + mx147, + mx143); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx229 = m.add_instruction(pass_op{}, mx225, mx69, mx192); + m.add_instruction(pass_op{}, + mx229, + mx45, + mx59, + mx76, + mx90, + mx218, + mx67, + mx126, + mx103, + mx136, + mx138, + mx15, + mx32, + mx1, + mx160, + mx150, + mx110, + mx51, + mx30, + mx6, + mx157, + mx93, + mx79, + mx85, + mx88, + mx53, + mx154, + mx134, + mx141, + mx180, + mx38, + mx81, + mx223, + mx183, + mx220, + mx210, + mx0, + mx208, + mx20, + mx69, + mx73, + mx185, + mx101, + mx201, + mx22, + mx203, + mx47, + mx128, + mx205, + mx195, + mx8, + mx177, + mx36, + mx55, + mx216, + mx105, + mx115, + mx130, + mx40, + mx98, + mx43, + mx166, + mx192, + mx162, + mx147, + mx143); + m.add_instruction(pass_op{}); + auto mx232 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx232, + mx160, + mx154, + mx76, + mx43, + mx67, + mx55, + mx187, + mx88, + mx126, + mx197, + mx225, + mx136, + mx59, + mx64, + mx15, + mx212, + mx128, + mx32, + mx218, + mx150, + mx216, + mx110, + mx169, + mx103, + mx113, + mx141, + mx79, + mx223, + mx90, + mx6, + mx18, + mx138, + mx210, + mx85, + mx53, + mx61, + mx45, + mx134, + mx119, + mx180, + mx166, + mx20, + mx0, + mx177, + mx81, + mx208, + mx157, + mx185, + mx1, + mx69, + mx201, + mx174, + mx101, + mx51, + mx22, + mx162, + mx220, + mx203, + mx47, + mx195, + mx73, + mx27, + mx205, + mx229, + mx8, + mx123, + mx121); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx236 = m.add_instruction(pass_op{}, mx232, mx192, mx166); + m.add_instruction(pass_op{}, + mx236, + mx160, + mx154, + mx76, + mx43, + mx67, + mx55, + mx187, + mx88, + mx126, + mx197, + mx225, + mx136, + mx59, + mx64, + mx15, + mx212, + mx128, + mx32, + mx218, + mx150, + mx216, + mx110, + mx169, + mx103, + mx113, + mx141, + mx79, + mx223, + mx90, + mx6, + mx18, + mx138, + mx210, + mx85, + mx53, + mx61, + mx45, + mx134, + mx119, + mx180, + mx166, + mx20, + mx0, + mx177, + mx81, + mx208, + mx157, + mx185, + mx1, + mx69, + mx201, + mx174, + mx101, + mx51, + mx22, + mx162, + mx220, + mx203, + mx47, + mx195, + mx73, + mx27, + mx205, + mx229, + mx8, + mx123, + mx121); + m.add_instruction(pass_op{}); + auto mx239 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}, + mx239, + mx38, + mx192, + mx232, + mx203, + mx229, + mx183, + mx154, + mx201, + mx113, + mx174, + mx110, + mx197, + mx36, + mx115, + mx150, + mx98, + mx130, + mx32, + mx101, + mx169, + mx8, + mx64, + mx27, + mx225, + mx22, + mx147, + mx67, + mx205, + mx73, + mx61, + mx105, + mx18, + mx47, + mx123, + mx93, + mx195, + mx119, + mx69, + mx40, + mx187, + mx30, + mx15, + mx143, + mx236, + mx121, + mx212); + m.add_instruction(pass_op{}); + auto mx242 = m.add_instruction(pass_op{}, mx239, mx223); + m.add_instruction(pass_op{}, + mx242, + mx38, + mx192, + mx232, + mx203, + mx229, + mx183, + mx154, + mx201, + mx113, + mx174, + mx110, + mx197, + mx36, + mx115, + mx150, + mx98, + mx130, + mx32, + mx101, + mx169, + mx8, + mx64, + mx27, + mx225, + mx22, + mx147, + mx67, + mx205, + mx73, + mx61, + mx105, + mx18, + mx47, + mx123, + mx93, + mx195, + mx119, + mx69, + mx40, + mx187, + mx30, + mx15, + mx143, + mx236, + mx121, + mx212); + auto mx244 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}}); + m.add_instruction(pass_op{}); + m.add_instruction(pass_op{}); + auto mx247 = m.add_instruction(pass_op{}, mx244, mx229, mx242, mx236); + auto moutput = + m.add_parameter("output", migraphx::shape{migraphx::shape::float_type, {3, 1, 2, 5}}); + auto mx248 = m.add_instruction(pass_op{}, mx247); + auto mx249 = m.add_instruction(pass_op{}, mx166); + auto mx250 = m.add_instruction(pass_op{}, mx85); + m.add_instruction(pass_op{}, moutput, mx250, mx249, mx248); + + run_pass(m); + CHECK(m.get_parameter_shape("scratch").bytes() == 1600); + CHECK(no_allocate(m)); + CHECK(is_disjoint({mx0, mx8})); + CHECK(is_disjoint({mx0, mx8})); + CHECK(is_disjoint({mx0, mx18, mx8})); + CHECK(is_disjoint({mx0, mx18, mx8})); + CHECK(is_disjoint({mx0, mx18, mx8})); + CHECK(is_disjoint({mx0, mx18, mx8})); + CHECK(is_disjoint({mx0, mx18, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx8})); + CHECK(is_disjoint({mx30, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx8})); + CHECK(is_disjoint({mx0, mx18, mx38, mx8})); + CHECK(is_disjoint({mx30, mx38})); + CHECK(is_disjoint({mx0, mx18, mx38, mx8})); + CHECK(is_disjoint({mx18, mx30, mx38, mx43, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx8})); + CHECK(is_disjoint({mx0, mx43, mx8})); + CHECK(is_disjoint({mx18, mx30, mx38, mx43, mx8})); + CHECK(is_disjoint({mx18, mx30, mx38, mx53, mx8})); + CHECK(is_disjoint({mx43, mx53})); + CHECK(is_disjoint({mx18, mx30, mx38, mx53, mx8})); + CHECK(is_disjoint({mx38, mx53})); + CHECK(is_disjoint({mx18, mx30, mx38, mx8})); + CHECK(is_disjoint({mx0, mx30, mx38, mx43, mx53, mx61, mx8})); + CHECK(is_disjoint({mx18, mx61})); + CHECK(is_disjoint({mx0, mx30, mx38, mx43, mx53, mx61, mx8})); + CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61, mx67})); + CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61})); + CHECK(is_disjoint({mx18, mx67})); + CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61, mx67})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx76, mx8})); + CHECK(is_disjoint({mx38, mx76})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx76, mx8})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx81})); + CHECK(is_disjoint({mx61, mx67, mx76, mx81})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx81})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx88})); + CHECK(is_disjoint({mx81, mx88})); + CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx88})); + CHECK(is_disjoint({mx0, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx0, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx0, mx101, mx18, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK(is_disjoint({mx101, mx18, mx30, mx38, mx61, mx67, mx8, mx88, mx93})); + CHECK( + is_disjoint({mx0, mx101, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx101, mx88, mx93})); + CHECK(is_disjoint({mx0, mx101, mx18, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK(is_disjoint( + {mx0, mx101, mx113, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint( + {mx0, mx101, mx113, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx113, mx93})); + CHECK(is_disjoint( + {mx0, mx101, mx113, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint( + {mx0, mx101, mx121, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx113, mx121})); + CHECK(is_disjoint( + {mx0, mx101, mx121, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx101, mx113, mx121, mx126, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx101, mx113, mx121, mx18, mx30, mx38, mx61, mx67, mx8, mx88, mx93})); + CHECK(is_disjoint({mx0, + mx101, + mx113, + mx121, + mx126, + mx18, + mx38, + mx43, + mx53, + mx61, + mx67, + mx76, + mx8, + mx81, + mx88, + mx93})); + CHECK(is_disjoint({mx126, mx88, mx93})); + CHECK(is_disjoint({mx101, mx113, mx121, mx126, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx101, mx113, mx121, mx136, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx126, mx136, mx81})); + CHECK(is_disjoint({mx101, mx113, mx121, mx136, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx121, mx136})); + CHECK(is_disjoint({mx101, mx113, mx121, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx0, + mx101, + mx113, + mx121, + mx126, + mx136, + mx143, + mx18, + mx38, + mx43, + mx53, + mx61, + mx67, + mx76, + mx8, + mx81, + mx88})); + CHECK(is_disjoint({mx101, mx143})); + CHECK(is_disjoint({mx0, + mx101, + mx113, + mx121, + mx126, + mx136, + mx143, + mx18, + mx38, + mx43, + mx53, + mx61, + mx67, + mx76, + mx8, + mx81, + mx88})); + CHECK(is_disjoint({mx0, + mx113, + mx121, + mx126, + mx136, + mx143, + mx150, + mx18, + mx30, + mx38, + mx43, + mx53, + mx61, + mx76, + mx81, + mx88, + mx93})); + CHECK(is_disjoint({mx101, mx150, mx81})); + CHECK(is_disjoint({mx0, + mx113, + mx121, + mx126, + mx136, + mx143, + mx150, + mx18, + mx30, + mx38, + mx43, + mx53, + mx61, + mx76, + mx81, + mx88, + mx93})); + CHECK(is_disjoint( + {mx101, mx113, mx121, mx143, mx150, mx157, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx121, mx157})); + CHECK(is_disjoint( + {mx101, mx113, mx121, mx143, mx150, mx157, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint( + {mx101, mx113, mx121, mx143, mx150, mx162, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx143, mx150, mx157, mx162})); + CHECK(is_disjoint( + {mx101, mx113, mx121, mx143, mx150, mx162, mx18, mx30, mx38, mx61, mx67, mx8, mx93})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169, + mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169, + mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx177, + mx18, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx162, mx177})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx177, + mx18, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, mx169, mx177, + mx18, mx183, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx177, + mx18, + mx183, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK( + is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169, mx177, + mx183, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx169, mx177, mx183})); + CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, mx169, mx177, + mx18, mx183, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88})); + CHECK( + is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18, + mx183, mx195, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, + mx162, mx169, mx177, mx183, mx195, mx30, mx38, mx43, + mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx169, mx195})); + CHECK( + is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18, + mx183, mx195, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK( + is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18, + mx183, mx203, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK(is_disjoint({mx195, mx203})); + CHECK( + is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18, + mx183, mx203, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx208, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx177, + mx18, + mx183, + mx195, + mx203, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, + mx169, mx177, mx183, mx195, mx203, mx208, mx30, mx38, mx43, + mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx169, mx177, mx208})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx208, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx218, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx162, mx208, mx218})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx218, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx203, mx218})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, + mx177, mx183, mx195, mx203, mx208, mx218, mx225, mx30, mx38, + mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx183, mx225})); + CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, + mx177, mx183, mx195, mx203, mx208, mx218, mx225, mx30, mx38, + mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93})); + CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, + mx169, mx177, mx18, mx195, mx203, mx208, mx218, mx225, mx232, + mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88})); + CHECK(is_disjoint({mx162, mx183, mx232})); + CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, + mx169, mx177, mx18, mx195, mx203, mx208, mx218, mx225, mx232, + mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx225, + mx232, + mx239, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx203, mx239})); + CHECK(is_disjoint({mx101, + mx113, + mx121, + mx143, + mx150, + mx169, + mx18, + mx183, + mx195, + mx203, + mx225, + mx232, + mx239, + mx30, + mx38, + mx61, + mx67, + mx8, + mx93})); + CHECK(is_disjoint({mx225, mx232, mx239, mx244})); + CHECK(is_disjoint({mx162, mx244, mx81})); } TEST_CASE(literal_test) { migraphx::program p; + auto* mm = p.get_main_module(); + auto lit = generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_literal(lit); - run_pass(p); - auto result = p.eval({}); + mm->add_literal(lit); + run_pass(*mm); + auto result = p.eval({}).back(); CHECK(lit == result); } diff --git a/test/module_test.cpp b/test/module_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3f687522cb9645c100aa24d3f6c6d773947104b --- /dev/null +++ b/test/module_test.cpp @@ -0,0 +1,329 @@ +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" +#include + +#include + +migraphx::program create_program() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::int64_type}); + auto y = mm->add_parameter("y", {migraphx::shape::int64_type}); + + auto sum = mm->add_instruction(sum_op{}, x, y); + auto one = mm->add_literal(1); + mm->add_instruction(sum_op{}, sum, one); + + return p; +} + +TEST_CASE(calc_implict_deps) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + auto cond = mm->add_parameter("cond", cond_s); + auto x1 = mm->add_parameter("x1", xs); + auto x2 = mm->add_parameter("x2", xs); + auto y2 = mm->add_parameter("y2", ys); + + auto* then_mod = p.create_module("If_5_if"); + auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx); + then_mod->add_return({a1, l1}); + + auto* then_mod1 = p.create_module("If_6_if"); + auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay)); + auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx); + then_mod1->add_return({a11, l11}); + + auto* else_mod1 = p.create_module("If_6_else"); + auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax)); + auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly); + else_mod1->add_return({l21, a21}); + + auto* else_mod = p.create_module("If_5_else"); + auto l2 = else_mod->add_literal(migraphx::literal(ys, datay)); + auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1}); + auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2); + else_mod->add_return({a3, l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + auto implicit_deps = mm->calc_implicit_deps(); + EXPECT(migraphx::contains(implicit_deps, ret)); + EXPECT(migraphx::contains(implicit_deps.at(ret), x1)); + EXPECT(migraphx::contains(implicit_deps.at(ret), x2)); + EXPECT(migraphx::contains(implicit_deps.at(ret), y2)); +} + +TEST_CASE(module_annotate) +{ + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + auto* mm1 = p1.get_main_module(); + auto* mm2 = p2.get_main_module(); + EXPECT(*mm1 == *mm2); + + std::stringstream ss1; + mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; }); + + std::stringstream ss2; + mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; }); + + EXPECT(ss1.str() == ss2.str()); +} + +TEST_CASE(module_ins_clear) +{ + migraphx::program p1 = create_program(); + migraphx::program p2; + + p2 = p1; + + EXPECT(p1 == p2); +} + +TEST_CASE(module_name) +{ + migraphx::module m1("name"); + EXPECT(m1.name() == "name"); + + auto m2 = m1; // NOLINT + EXPECT(m2.name() == "name"); + migraphx::module m3; + m3 = m1; + EXPECT(m3.name() == "name"); +} + +TEST_CASE(module_name_main) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + EXPECT(mm->name() == "main"); +} + +TEST_CASE(module_print_cpp) +{ + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + auto* mm1 = p1.get_main_module(); + auto* mm2 = p2.get_main_module(); + + std::stringstream ss1; + mm1->print_cpp(ss1); + + std::stringstream ss2; + mm2->print_cpp(ss2); + + EXPECT(ss1.str() == ss2.str()); +} + +TEST_CASE(module_print_graph) +{ + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + auto* mm1 = p1.get_main_module(); + auto* mm2 = p2.get_main_module(); + + std::stringstream ss1; + mm1->print_graph(ss1, true); + + std::stringstream ss2; + mm2->print_graph(ss2, true); + + EXPECT(ss1.str() == ss2.str()); +} + +TEST_CASE(program_module_assign) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", sd); + + std::vector one(sd.elements(), 1); + std::vector two(sd.elements(), 2); + + auto* then_smod = p.create_module("then_smod"); + auto l1 = then_smod->add_literal(migraphx::literal{sd, one}); + auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1); + then_smod->add_return({r1}); + + auto* else_smod = p.create_module("else_smod"); + auto l2 = else_smod->add_literal(migraphx::literal{sd, two}); + auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2); + else_smod->add_return({r2}); + + migraphx::shape s_cond{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_parameter("cond", s_cond); + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod}); + mm->add_return({ret}); + + migraphx::program p1 = p; + + EXPECT(p == p1); +} + +TEST_CASE(program_module_replace) +{ + auto create_program = [](bool use_if) { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", sd); + + std::vector one(sd.elements(), 1); + std::vector two(sd.elements(), 2); + + auto* then_smod = p.create_module("then_smod"); + auto l1 = then_smod->add_literal(migraphx::literal{sd, one}); + auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1); + then_smod->add_return({r1}); + + auto* else_smod = p.create_module("else_smod"); + auto l2 = else_smod->add_literal(migraphx::literal{sd, two}); + auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2); + else_smod->add_return({r2}); + + migraphx::shape s_cond{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_parameter("cond", s_cond); + + migraphx::instruction_ref ret{}; + + if(use_if) + { + ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod}); + } + else + { + ret = mm->add_instruction(mod_pass_op{}, {cond}, {then_smod, else_smod}); + } + + mm->add_return({ret}); + + return p; + }; + + migraphx::program p1 = create_program(false); + migraphx::program p2 = create_program(true); + EXPECT(p1 != p2); + + auto* m1 = p1.get_main_module(); + auto ins_pass = std::prev(std::prev(m1->end())); + const auto& inputs = ins_pass->inputs(); + const auto& mod_inputs = ins_pass->module_inputs(); + m1->replace_instruction(ins_pass, migraphx::make_op("if"), inputs, mod_inputs); + + EXPECT(p1 == p2); +} + +TEST_CASE(submodule_copy) +{ + migraphx::module mm("main"); + auto x = mm.add_parameter("x", {migraphx::shape::int64_type}); + + migraphx::module sm("sub"); + sm.add_instruction(migraphx::make_op("sin"), x); + + mm.add_instruction(migraphx::make_op("if"), {x}, {&sm, &sm}); + + auto mm2 = mm; + + EXPECT(mm == mm2); + EXPECT(mm.get_sub_modules() == mm2.get_sub_modules()); +} + +TEST_CASE(parameter_name_order) +{ + migraphx::shape s{migraphx::shape::int32_type, {1}}; + migraphx::module mm("main"); + auto x1 = mm.add_parameter("x1", s); + auto x2 = mm.add_parameter("x2", s); + auto x3 = mm.add_parameter("x3", s); + auto x4 = mm.add_parameter("x4", s); + + std::vector param_names = {"x1", "x2", "x3", "x4"}; + auto sum1 = mm.add_instruction(migraphx::make_op("add"), x1, x2); + auto sum2 = mm.add_instruction(migraphx::make_op("add"), x3, x4); + auto r = mm.add_instruction(migraphx::make_op("mul"), sum1, sum2); + mm.add_return({r}); + + auto names = mm.get_parameter_names(); + EXPECT(param_names == names); + + auto m1 = mm; + auto names1 = m1.get_parameter_names(); + EXPECT(param_names == names1); +} + +struct check_for_pass_op +{ + bool* found = nullptr; + std::string name() const { return "check_for_pass_op"; } + void apply(migraphx::module& m) const + { + *found |= std::any_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "pass"; }); + } +}; + +TEST_CASE(module_bypass) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto* sub = p.create_module("sub"); + sub->set_bypass(); + sub->add_instruction(pass_op{}); + mm->add_instruction(mod_pass_op{}, {}, {sub}); + bool found = false; + migraphx::run_passes(p, {check_for_pass_op{&found}}); + EXPECT(not found); +} + +TEST_CASE(module_without_bypass) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto* sub = p.create_module("sub"); + sub->add_instruction(pass_op{}); + mm->add_instruction(mod_pass_op{}, {}, {sub}); + bool found = false; + migraphx::run_passes(p, {check_for_pass_op{&found}}); + EXPECT(found); +} + +TEST_CASE(multiple_module_dependency) +{ + // Test when an instruction from a submodule depends on previous module + migraphx::program p; + auto* mm = p.get_main_module(); + auto* sub = p.create_module("sub"); + auto l1 = mm->add_literal(migraphx::literal(3)); + // second same literal to make sure instruction_ref is being compared, rather than the + // instructions + sub->add_literal(migraphx::literal(3)); + sub->add_instruction(sum_op{}, l1, l1); + EXPECT((sub->validate() == sub->end())); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/msgpack.cpp b/test/msgpack.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1caf7fc70d1991884c2e11222806d848aa533909 --- /dev/null +++ b/test/msgpack.cpp @@ -0,0 +1,127 @@ +#include +#include +#include +#include +#include "test.hpp" + +template +std::vector msgpack_buffer(const T& src) +{ + std::stringstream buffer; + msgpack::pack(buffer, src); + buffer.seekg(0); + std::string str = buffer.str(); + return std::vector(str.data(), str.data() + str.size()); // NOLINT +} + +TEST_CASE(test_msgpack_empty_value) +{ + migraphx::value v; + auto buffer = migraphx::to_msgpack(v); + auto mp = migraphx::from_msgpack(buffer); + EXPECT(mp == v); + EXPECT(v.is_null()); + EXPECT(mp.is_null()); +} + +TEST_CASE(test_msgpack_int) +{ + migraphx::value v = 3; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(3)); + EXPECT(migraphx::from_msgpack(buffer).to() == v.to()); +} + +TEST_CASE(test_msgpack_int_negative) +{ + migraphx::value v = -3; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(-3)); + EXPECT(migraphx::from_msgpack(buffer).to() == v.to()); +} + +TEST_CASE(test_msgpack_bool) +{ + migraphx::value v = true; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(true)); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_float) +{ + migraphx::value v = 3.0; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(3.0)); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_string) +{ + migraphx::value v = "abc"; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer("abc")); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_array) +{ + migraphx::value v = {1, 2, 3}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(std::vector{1, 2, 3})); + EXPECT(migraphx::from_msgpack(buffer).to_vector() == v.to_vector()); +} + +TEST_CASE(test_msgpack_empty_array) +{ + migraphx::value v = migraphx::value::array{}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(std::vector{})); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_object) +{ + migraphx::value v = {{"one", 1.0}, {"three", 3.0}, {"two", 2.0}}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(std::map{ + {"one", 1.0}, {"three", 3.0}, {"two", 2.0}})); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_empty_object) +{ + migraphx::value v = migraphx::value::object{}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(std::vector{})); + auto u = migraphx::from_msgpack(buffer); + // This is not equal since an empty object becomes an empty array + EXPECT(u != v); + EXPECT(u.is_array()); + EXPECT(u.size() == 0); +} + +struct foo +{ + double a; + std::string b; + MSGPACK_DEFINE_MAP(a, b); +}; + +TEST_CASE(test_msgpack_object_class) +{ + migraphx::value v = {{"a", 1.0}, {"b", "abc"}}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(foo{1.0, "abc"})); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +TEST_CASE(test_msgpack_array_class) +{ + migraphx::value v = {{{"a", 1.0}, {"b", "abc"}}, {{"a", 3.0}, {"b", "xyz"}}}; + auto buffer = migraphx::to_msgpack(v); + EXPECT(buffer == msgpack_buffer(std::vector{foo{1.0, "abc"}, foo{3.0, "xyz"}})); + EXPECT(migraphx::from_msgpack(buffer) == v); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/normalize_ops_test.cpp b/test/normalize_ops_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00bc83ec3ddfeef21d2d481b0254b7222a1b1015 --- /dev/null +++ b/test/normalize_ops_test.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +struct normalize_test_op +{ + std::vector axes = {}; + + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.axes, "axes")); + } + + migraphx::value attributes() const + { + migraphx::value normalize; + normalize["axes"] = migraphx::value::array{migraphx::op::normalize_attribute::clip_max, + migraphx::op::normalize_attribute::clip_min}; + return {{"normalize_axes", normalize}}; + } + + std::string name() const { return "normalize_ops_test::test_op"; } + migraphx::shape normalize_compute_shape(std::vector inputs) const + { + return inputs[0]; + } + migraphx::argument compute(migraphx::context&, + const migraphx::shape& output_shape, + const std::vector&) const + { + return {output_shape}; + } +}; + +void run_pass(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::normalize_ops{}, migraphx::dead_code_elimination{}}); +} + +migraphx::module create_gather(int64_t axis) +{ + migraphx::module m; + migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape si{migraphx::shape::int64_type, {2, 3}}; + auto di = m.add_parameter("data", sd); + auto ii = m.add_parameter("ind", si); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii); + m.add_return({r}); + + return m; +} + +TEST_CASE(gather_test) +{ + + auto m1 = create_gather(-3); + auto m2 = create_gather(0); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(gather_test_1) +{ + auto m1 = create_gather(1); + auto m2 = create_gather(1); + run_pass(m1); + + EXPECT(m1 == m2); +} + +migraphx::module create_padded_op(const std::vector& pad_vals) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; + auto si = m.add_parameter("data", s); + auto r = m.add_instruction(migraphx::make_op("pooling", {{"padding", pad_vals}}), si); + m.add_return({r}); + + return m; +} + +TEST_CASE(padding_attr_test) +{ + migraphx::module m1 = create_padded_op({0, 1}); + migraphx::module m2 = create_padded_op({0, 1, 0, 1}); + run_pass(m1); + + EXPECT(m1 == m2); +} + +migraphx::module create_reduce_mean(const std::vector& axes) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; + auto si = m.add_parameter("data", s); + auto r = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si); + m.add_return({r}); + + return m; +} + +TEST_CASE(reduce_mean_test) +{ + migraphx::module m1 = create_reduce_mean({0, 1, -1}); + migraphx::module m2 = create_reduce_mean({0, 1, 3}); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(reduce_mean_test_1) +{ + migraphx::module m1 = create_reduce_mean({0, 1, 2}); + migraphx::module m2 = create_reduce_mean({0, 1, 2}); + run_pass(m1); + + EXPECT(m1 == m2); +} + +migraphx::module create_slice(const std::vector& axes, + const std::vector& starts, + const std::vector& ends) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; + auto si = m.add_parameter("data", s); + auto r = m.add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si); + m.add_return({r}); + + return m; +} + +TEST_CASE(slice_test) +{ + migraphx::module m1 = create_slice({0, 1, -1}, {-5, 1, -3}, {2, 2, 8}); + migraphx::module m2 = create_slice({0, 1, 3}, {0, 1, 2}, {2, 2, 5}); + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(slice_test_1) +{ + migraphx::module m1 = create_slice({0, 1, 3}, {0, 1, -3}, {1, 2, 5}); + migraphx::module m2 = create_slice({0, 1, 3}, {0, 1, 2}, {1, 2, 5}); + run_pass(m1); + + EXPECT(m1 == m2); +} + +migraphx::module create_test_op(const std::vector& axes) +{ + migraphx::module m; + migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; + auto di = m.add_parameter("data", sd); + auto r = m.add_instruction(normalize_test_op{axes}, di); + m.add_return({r}); + + return m; +} + +TEST_CASE(test_op) +{ + std::vector axes1 = {-4, 5}; + auto m1 = create_test_op(axes1); + + std::vector axes2 = {1, 2}; + auto m2 = create_test_op(axes2); + + run_pass(m1); + EXPECT(m1 == m2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/onnx/acosh_test.onnx b/test/onnx/acosh_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8eec14e3044354ba3895960f542cddbd18340507 --- /dev/null +++ b/test/onnx/acosh_test.onnx @@ -0,0 +1,15 @@ + +acosh_test:= + +xy"Acosh +acosh_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/add_scalar_test.onnx b/test/onnx/add_scalar_test.onnx index 43cd0686474d3fb6c1aaded98e0f22ca43dbe4be..64d007147c274f21da5752b0e5e29072e3ab3f39 100644 Binary files a/test/onnx/add_scalar_test.onnx and b/test/onnx/add_scalar_test.onnx differ diff --git a/test/onnx/asinh_test.onnx b/test/onnx/asinh_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..008c0158be6d62923513e8a2e53946d30375a191 --- /dev/null +++ b/test/onnx/asinh_test.onnx @@ -0,0 +1,15 @@ + +asinh_test:= + +xy"Asinh +asinh_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/atanh_test.onnx b/test/onnx/atanh_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..952bba76fb3dab2aba260e90ea6518906347d44e --- /dev/null +++ b/test/onnx/atanh_test.onnx @@ -0,0 +1,15 @@ + +atanh_test:= + +xy"Atanh +atanh_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/averagepool_1d_test.onnx b/test/onnx/averagepool_1d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa020e7e6aa96df8885b0b07a8454c8d3c2ff851 --- /dev/null +++ b/test/onnx/averagepool_1d_test.onnx @@ -0,0 +1,14 @@ +averagepool_1d_test:q +( +01" AveragePool* + kernel_shape@ averagepool_1d_testZ +0 + + + +b +1 + + + +B \ No newline at end of file diff --git a/test/onnx/averagepool_3d_test.onnx b/test/onnx/averagepool_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..453c74f0ae4cc187d7598117caaddd2338a19c6a --- /dev/null +++ b/test/onnx/averagepool_3d_test.onnx @@ -0,0 +1,18 @@ +averagepool_3d_test:… +, +01" AveragePool* + kernel_shape@@@ averagepool_3d_testZ +0 + + + + + +b +1 + + + + + +B \ No newline at end of file diff --git a/test/onnx/averagepool_notset_test.onnx b/test/onnx/averagepool_notset_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..61cf146457b66fa9def351613e38ff51dccf10a1 Binary files /dev/null and b/test/onnx/averagepool_notset_test.onnx differ diff --git a/test/onnx/averagepool_nt_cip_test.onnx b/test/onnx/averagepool_nt_cip_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0a7d48c515de9afe0532c7a952f3ebafedd2bd67 Binary files /dev/null and b/test/onnx/averagepool_nt_cip_test.onnx differ diff --git a/test/onnx/averagepool_same_lower_test.onnx b/test/onnx/averagepool_same_lower_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f81cb278241ad990c0a4958bf9751bcf5c07137f --- /dev/null +++ b/test/onnx/averagepool_same_lower_test.onnx @@ -0,0 +1,18 @@ +averagepool_same_lower_test:ž +E +xy" AveragePool* +auto_pad" +SAME_LOWER * + kernel_shape@@ averagepool_same_lower_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/averagepool_same_upper_test.onnx b/test/onnx/averagepool_same_upper_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..063de98d9ee20ff09973d33874570fc3ecf00b69 --- /dev/null +++ b/test/onnx/averagepool_same_upper_test.onnx @@ -0,0 +1,18 @@ +averagepool_same_upper_test:ž +E +xy" AveragePool* +auto_pad" +SAME_UPPER * + kernel_shape@@ averagepool_same_upper_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/averagepool_sl_cip_test.onnx b/test/onnx/averagepool_sl_cip_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fad2d91ee412fd7407792ff4500f36fdea954257 --- /dev/null +++ b/test/onnx/averagepool_sl_cip_test.onnx @@ -0,0 +1,19 @@ +averagepool_sl_cip_test:´ +_ +xy" AveragePool* +auto_pad" +SAME_LOWER * +count_include_pad * + kernel_shape@@ averagepool_sl_cip_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/batchnorm_1d_test.onnx b/test/onnx/batchnorm_1d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a6a3e8271f2a33647598fbcf2028a6a2691ec25f --- /dev/null +++ b/test/onnx/batchnorm_1d_test.onnx @@ -0,0 +1,35 @@ +batchnorm_1d_test:Ø +M +0 +1 +2 +3 +45"BatchNormalization* +epsilon½7†5 * +momentumfff? batchnorm_1d_testZ +0 + + + +Z +1 + + +Z +2 + + +Z +3 + + +Z +4 + + +b +5 + + + +B \ No newline at end of file diff --git a/test/onnx/batchnorm_3d_test.onnx b/test/onnx/batchnorm_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7ac13200f115c8ef1b786171abca462997d3376a --- /dev/null +++ b/test/onnx/batchnorm_3d_test.onnx @@ -0,0 +1,39 @@ +batchnorm_3d_test:è +M +0 +1 +2 +3 +45"BatchNormalization* +epsilon½7†5 * +momentumfff? batchnorm_3d_testZ +0 + + + + + +Z +1 + + +Z +2 + + +Z +3 + + +Z +4 + + +b +5 + + + + + +B \ No newline at end of file diff --git a/test/onnx/celu_alpha_test.onnx b/test/onnx/celu_alpha_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c660d43d8ea7f467f84443c441c7878fa0ea75a6 --- /dev/null +++ b/test/onnx/celu_alpha_test.onnx @@ -0,0 +1,12 @@ +celu_alpha_test:R + +xy"Celu* +alphaÍÌL? celu_alpha_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/celu_default_test.onnx b/test/onnx/celu_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..36fe4fcd8c27eba8bbfcc8ea31cfc5648de499c9 --- /dev/null +++ b/test/onnx/celu_default_test.onnx @@ -0,0 +1,11 @@ +celu_default_test:K + +xy"Celucelu_default_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/celu_verify_test.onnx b/test/onnx/celu_verify_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..875068bb4429534912edd5e1fbd7363c8ea57c3d Binary files /dev/null and b/test/onnx/celu_verify_test.onnx differ diff --git a/test/onnx/celu_wrong_type_test.onnx b/test/onnx/celu_wrong_type_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8725353a644949099b0fa030b6b7e34feec561f1 --- /dev/null +++ b/test/onnx/celu_wrong_type_test.onnx @@ -0,0 +1,13 @@ +celu_wrong_type_test:N + +xy"Celucelu_wrong_type_testZ +x +  + + +b +y +  + + +B \ No newline at end of file diff --git a/test/onnx/celu_zero_alpha_test.onnx b/test/onnx/celu_zero_alpha_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ca0b8792dae1bf0c828aa86cfc738b94ca3fb5ec Binary files /dev/null and b/test/onnx/celu_zero_alpha_test.onnx differ diff --git a/test/onnx/clip_test.onnx b/test/onnx/clip_test.onnx index 22757c3f971c8e07445334f0313b91119dd922a4..e1fa51c4997b8cabd2737268ade6a87fb3cdfc39 100644 Binary files a/test/onnx/clip_test.onnx and b/test/onnx/clip_test.onnx differ diff --git a/test/onnx/clip_test_args_type_mismatch.onnx b/test/onnx/clip_test_args_type_mismatch.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b90332ac969b8a4b90fa6e895c9a119a83a072c3 Binary files /dev/null and b/test/onnx/clip_test_args_type_mismatch.onnx differ diff --git a/test/onnx/clip_test_op11.onnx b/test/onnx/clip_test_op11.onnx new file mode 100644 index 0000000000000000000000000000000000000000..34c3c0a8293fe2781be554a814f5b18dc5e7bf65 Binary files /dev/null and b/test/onnx/clip_test_op11.onnx differ diff --git a/test/onnx/clip_test_op11_max_only.onnx b/test/onnx/clip_test_op11_max_only.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4e104b77e7a872a90e4ecae8ee486fbd94bade45 Binary files /dev/null and b/test/onnx/clip_test_op11_max_only.onnx differ diff --git a/test/onnx/clip_test_op11_min_only.onnx b/test/onnx/clip_test_op11_min_only.onnx new file mode 100644 index 0000000000000000000000000000000000000000..43156fd87de0a760609bcf6883e5980ffcc8061c Binary files /dev/null and b/test/onnx/clip_test_op11_min_only.onnx differ diff --git a/test/onnx/clip_test_op11_no_args.onnx b/test/onnx/clip_test_op11_no_args.onnx new file mode 100644 index 0000000000000000000000000000000000000000..55fb1b21e59bf05f4e43bfefae30cb0daeb96330 --- /dev/null +++ b/test/onnx/clip_test_op11_no_args.onnx @@ -0,0 +1,11 @@ +clip_test_op11_no_args:H + +01"Clipclip_test_op11_no_argsZ +0 + + +b +1 + + +B \ No newline at end of file diff --git a/test/onnx/clip_test_op11_no_args1.onnx b/test/onnx/clip_test_op11_no_args1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4248a4c945bd3912774fc92d0db00be3c5c143d0 Binary files /dev/null and b/test/onnx/clip_test_op11_no_args1.onnx differ diff --git a/test/onnx/const_of_shape_empty_input_test.onnx b/test/onnx/const_of_shape_empty_input_test.onnx index 18e1ec9bae8d528e0fdd2764e4aa80bbc8e8f3a3..a014f9603af1f109cc5c40bb78c8bad7692278bb 100644 Binary files a/test/onnx/const_of_shape_empty_input_test.onnx and b/test/onnx/const_of_shape_empty_input_test.onnx differ diff --git a/test/onnx/constant_fill_input_as_shape_test.onnx b/test/onnx/constant_fill_input_as_shape_test.onnx index 7f252ba2071090e898a5c6216faf9cce0a2dbf94..91f5e12060847ff7ed6fdb6c6dbfb8f2623fad6b 100644 Binary files a/test/onnx/constant_fill_input_as_shape_test.onnx and b/test/onnx/constant_fill_input_as_shape_test.onnx differ diff --git a/test/onnx/constant_scalar_test.onnx b/test/onnx/constant_scalar_test.onnx index 51bb0531d59f635fb7af52a69e8b86539180d77d..79d3fcfdf3d385db48ca7746667aa48e083ed42c 100644 --- a/test/onnx/constant_scalar_test.onnx +++ b/test/onnx/constant_scalar_test.onnx @@ -1,6 +1,6 @@ -constant-scalar-example:R +constant_scalar_test:Y 00"Constant*! -value**B const_tensor  test-constantb +value**B const_tensor constant_scalar_testb 0  diff --git a/test/onnx/conv.weight b/test/onnx/conv.weight new file mode 100644 index 0000000000000000000000000000000000000000..36d6bbd10939df8b9422873e83e859133d24007f Binary files /dev/null and b/test/onnx/conv.weight differ diff --git a/test/onnx/conv_1d_test.onnx b/test/onnx/conv_1d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a5568bb9addc0365517c285865e837527385f34d --- /dev/null +++ b/test/onnx/conv_1d_test.onnx @@ -0,0 +1,19 @@ + conv_1d_test:j + +0 +12"Conv conv_1d_testZ +0 + + + +Z +1 + + + +b +2 + + + +B \ No newline at end of file diff --git a/test/onnx/conv_3d_test.onnx b/test/onnx/conv_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f67a191355297b13e449eb38e08072ae3c452d08 --- /dev/null +++ b/test/onnx/conv_3d_test.onnx @@ -0,0 +1,25 @@ + conv_3d_test:‚ + +0 +12"Conv conv_3d_testZ +0 + + + + + +Z +1 + + + + + +b +2 + + + + + +B \ No newline at end of file diff --git a/test/onnx/conv_attr_fail_test.onnx b/test/onnx/conv_attr_fail_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3bc63d73ff3df0e2f6261c5a8f1fbc0462a8aeaf --- /dev/null +++ b/test/onnx/conv_attr_fail_test.onnx @@ -0,0 +1,20 @@ +conv_attr_fail_test:ƒ +! +0 +12"Conv* +strides@@ conv_attr_fail_testZ +0 + + + +Z +1 + + + +b +2 + + + +B \ No newline at end of file diff --git a/test/onnx/conv_autopad_same_test.onnx b/test/onnx/conv_autopad_same_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f87135334f44310cb9c5cc5828802d8e8bc3c0fa --- /dev/null +++ b/test/onnx/conv_autopad_same_test.onnx @@ -0,0 +1,25 @@ +conv_autopad_same_test:» +J +0 +12"Conv* +auto_pad"SAME * + dilations@@ * +strides@@ conv_autopad_same_testZ +0 + + + + + Z +1 + + + + +b +2 + + + + + B \ No newline at end of file diff --git a/test/onnx/conv_bias_test.onnx b/test/onnx/conv_bias_test.onnx index dea7dd93cb20f1de55fe8f9970dbcb47a3621c7b..7a0ce5333598bc8090f1961f833c5fb2b36d281c 100644 --- a/test/onnx/conv_bias_test.onnx +++ b/test/onnx/conv_bias_test.onnx @@ -1,10 +1,10 @@ - conv-example:­ +conv_bias_test:² 8 0 1 23"Conv* dilations@@ * -strides@@  test_convZ +strides@@ conv_bias_testZ 0   diff --git a/test/onnx/convinteger_bias_test.onnx b/test/onnx/convinteger_bias_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..518c98b9dc4bb82a381145f674052c0ba5f52cc0 --- /dev/null +++ b/test/onnx/convinteger_bias_test.onnx @@ -0,0 +1,29 @@ +convinteger_bias_test:À +? +0 +1 +23" ConvInteger* + dilations@@ * +strides@@ convinteger_bias_testZ +0 + + + + + Z +1 + + + + +Z +2 + + +b +3 + + + + +B \ No newline at end of file diff --git a/test/onnx/deconv_bias_test.onnx b/test/onnx/deconv_bias_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3cfa26213f1313f90a1a9ce2d0ae9f23c483416e --- /dev/null +++ b/test/onnx/deconv_bias_test.onnx @@ -0,0 +1,27 @@ +deconv_bias_test:ž +" +x +w +byconv1" ConvTransposedeconv_bias_testZ +x + + + + +Z +w + + + + +Z +b + + +b +y + + + + +B diff --git a/test/onnx/deconv_input_pads_asymm_1d_test.onnx b/test/onnx/deconv_input_pads_asymm_1d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..44e8a8f7038d9538fc3f653002889f50795b9363 Binary files /dev/null and b/test/onnx/deconv_input_pads_asymm_1d_test.onnx differ diff --git a/test/onnx/deconv_input_pads_asymm_test.onnx b/test/onnx/deconv_input_pads_asymm_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b98af97e7fc76162380fcf91d8685f3db105ca77 Binary files /dev/null and b/test/onnx/deconv_input_pads_asymm_test.onnx differ diff --git a/test/onnx/deconv_input_pads_strides_test.onnx b/test/onnx/deconv_input_pads_strides_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d1ce6c16fded83b7dd4e358a55b296fd8cfdebf8 --- /dev/null +++ b/test/onnx/deconv_input_pads_strides_test.onnx @@ -0,0 +1,24 @@ +deconv_input_pads_strides_test:¶ += +x +wy" ConvTranspose* +pads@@@@ * +strides@@ deconv_input_pads_strides_testZ +x + + + + +Z +w + + + + +b +y + + + + +B diff --git a/test/onnx/deconv_input_pads_test.onnx b/test/onnx/deconv_input_pads_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..73cc010143c84524c636fdd5ed6474be29703d8a --- /dev/null +++ b/test/onnx/deconv_input_pads_test.onnx @@ -0,0 +1,24 @@ +deconv_input_pads_test:® += +x +wy" ConvTranspose* +pads@@@@ * +strides@@ deconv_input_pads_testZ +x + + + + +Z +w + + + + +b +y + + + + +B diff --git a/test/onnx/deconv_output_padding_3d_test.onnx b/test/onnx/deconv_output_padding_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9130dc6dc1c18ffd00eabf25f5f309d6f77a72c9 --- /dev/null +++ b/test/onnx/deconv_output_padding_3d_test.onnx @@ -0,0 +1,28 @@ +deconv_output_padding_3d_test:Ë +G +x +wy" ConvTranspose* +output_padding@@@ * +strides@@@ deconv_output_padding_3d_testZ +x + + + + + +Z +w + + + + + +b +y + + + + + + +B \ No newline at end of file diff --git a/test/onnx/deconv_output_padding_test.onnx b/test/onnx/deconv_output_padding_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9379d5e82d56b307212d2b2c815ccfa4e8554989 --- /dev/null +++ b/test/onnx/deconv_output_padding_test.onnx @@ -0,0 +1,25 @@ +deconv_output_padding_test:¸ +C +x +wy" ConvTranspose* +output_padding@@ * +strides@@ deconv_output_padding_testZ +x + + + + +Z +w + + + + +b +y + + + + + +B diff --git a/test/onnx/deconv_output_shape_3d_test.onnx b/test/onnx/deconv_output_shape_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..db6eb77fdea1ef5159bca785831148c0e6186acb --- /dev/null +++ b/test/onnx/deconv_output_shape_3d_test.onnx @@ -0,0 +1,29 @@ +deconv_output_shape_3d_test:Ç +E +x +wy" ConvTranspose* + output_shape@ +@@ * +strides@@@ deconv_output_shape_3d_testZ +x + + + + + +Z +w + + + + + +b +y + + + + + + +B \ No newline at end of file diff --git a/test/onnx/deconv_output_shape_test.onnx b/test/onnx/deconv_output_shape_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa59ee045f925877c62c942aceb4a53c932fefc5 --- /dev/null +++ b/test/onnx/deconv_output_shape_test.onnx @@ -0,0 +1,26 @@ +deconv_output_shape_test:´ +A +x +wy" ConvTranspose* + output_shape@ +@ * +strides@@ deconv_output_shape_testZ +x + + + + +Z +w + + + + +b +y + + + + + +B diff --git a/test/onnx/deconv_test.onnx b/test/onnx/deconv_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..01564f94108e8edb52e811caedc942feac37350a --- /dev/null +++ b/test/onnx/deconv_test.onnx @@ -0,0 +1,22 @@ + deconv_test:… + +x +wyconv1" ConvTranspose deconv_testZ +x + + + + +Z +w + + + + +b +y + + + + +B diff --git a/test/onnx/depthtospace_crd_test.onnx b/test/onnx/depthtospace_crd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f36bd9a31e47dcbff2780f682c0123ad8d9d3154 --- /dev/null +++ b/test/onnx/depthtospace_crd_test.onnx @@ -0,0 +1,19 @@ +depthtospace_crd_test:‰ +6 +xy" DepthToSpace* + blocksize * +mode"CRD depthtospace_crd_testZ +x + + + + +b +y + + + + + + +B \ No newline at end of file diff --git a/test/onnx/depthtospace_simple_test.onnx b/test/onnx/depthtospace_simple_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f2e1cc96528c84d2e6dda475a593a1c60107ad25 --- /dev/null +++ b/test/onnx/depthtospace_simple_test.onnx @@ -0,0 +1,17 @@ +depthtospace_simple_test:Œ +6 +xy" DepthToSpace* + blocksize * +mode"DCR depthtospace_simple_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/depthtospace_test.onnx b/test/onnx/depthtospace_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..39db437d79581ee85076df07eb11a66e40e0f30a --- /dev/null +++ b/test/onnx/depthtospace_test.onnx @@ -0,0 +1,19 @@ +depthtospace_test:… +6 +xy" DepthToSpace* + blocksize * +mode"DCR depthtospace_testZ +x + + + + +b +y + + + + + + +B \ No newline at end of file diff --git a/test/onnx/dequantizelinear_axis_test.onnx b/test/onnx/dequantizelinear_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..334a47a608c87651d47a2a39ca3645f7015cb9ca --- /dev/null +++ b/test/onnx/dequantizelinear_axis_test.onnx @@ -0,0 +1,26 @@ +dequantizelinear_axis_test:© +- +0 +1 +2out"DequantizeLinear* +axis dequantizelinear_axis_testZ +0 + + + + +Z +1 + + +Z +2 + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/dequantizelinear_neg_axis_test.onnx b/test/onnx/dequantizelinear_neg_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3ba88f50fd1f11e3eb6669fac11329c9ed93ff91 --- /dev/null +++ b/test/onnx/dequantizelinear_neg_axis_test.onnx @@ -0,0 +1,26 @@ +dequantizelinear_neg_axis_test:¶ +6 +0 +1 +2out"DequantizeLinear* +axisþÿÿÿÿÿÿÿÿ dequantizelinear_neg_axis_testZ +0 + + + + +Z +1 + + +Z +2 + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/dequantizelinear_test.onnx b/test/onnx/dequantizelinear_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cf48719dbd7537f56c379bf2b5ee9d94dd6a38a6 --- /dev/null +++ b/test/onnx/dequantizelinear_test.onnx @@ -0,0 +1,16 @@ +dequantizelinear_test:k + +0 +1out"DequantizeLineardequantizelinear_testZ +0 + + +Z +1 + + +b +out + + +B \ No newline at end of file diff --git a/test/onnx/dequantizelinear_zero_point_test.onnx b/test/onnx/dequantizelinear_zero_point_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7ab24fcf19b0e2defe8e8eea7c97373c09fdd917 --- /dev/null +++ b/test/onnx/dequantizelinear_zero_point_test.onnx @@ -0,0 +1,21 @@ + dequantizelinear_zero_point_test:Š + +0 +1 +2out"DequantizeLinear dequantizelinear_zero_point_testZ +0 + + +Z +1 + + +Z +2 + + +b +out + + +B \ No newline at end of file diff --git a/test/onnx/embedding_bag_offset_test.onnx b/test/onnx/embedding_bag_offset_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c61ce12f74897622e23e6b3fdc7ab65d9e274526 Binary files /dev/null and b/test/onnx/embedding_bag_offset_test.onnx differ diff --git a/test/onnx/embedding_bag_test.onnx b/test/onnx/embedding_bag_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b88e9f6f7489a58ab19168a6006a812059a1bf1c Binary files /dev/null and b/test/onnx/embedding_bag_test.onnx differ diff --git a/test/onnx/equal_bool_test.onnx b/test/onnx/equal_bool_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6cac1987ff0c6c341167a9994191ec926aee228f --- /dev/null +++ b/test/onnx/equal_bool_test.onnx @@ -0,0 +1,19 @@ +equal_bool_test:ƒ + +x1bx1"Cast* +to   + +bx1 +x2y"Equalequal_bool_testZ +x1 +  + +Z +x2 +   + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/equal_test.onnx b/test/onnx/equal_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..bebb942d775fbce9b68efc3a433f54a5e8c1caaf Binary files /dev/null and b/test/onnx/equal_test.onnx differ diff --git a/test/onnx/ext_path/conv.weight b/test/onnx/ext_path/conv.weight new file mode 100644 index 0000000000000000000000000000000000000000..36d6bbd10939df8b9422873e83e859133d24007f Binary files /dev/null and b/test/onnx/ext_path/conv.weight differ diff --git a/test/onnx/ext_path/external_data_test.onnx b/test/onnx/ext_path/external_data_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..640940edcbe17b2919c925b767923b163cee2948 Binary files /dev/null and b/test/onnx/ext_path/external_data_test.onnx differ diff --git a/test/onnx/external_data_test.onnx b/test/onnx/external_data_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..640940edcbe17b2919c925b767923b163cee2948 Binary files /dev/null and b/test/onnx/external_data_test.onnx differ diff --git a/test/onnx/eyelike_default_test.onnx b/test/onnx/eyelike_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b130fb6d331141697fa42e0e254de2a1eba82c5b --- /dev/null +++ b/test/onnx/eyelike_default_test.onnx @@ -0,0 +1,11 @@ +eyelike_default_test:U + +T1T2"EyeLikeeyelike_default_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_double_test.onnx b/test/onnx/eyelike_double_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..158a1b40a9ccf8a641ae424faa62845469e7e409 --- /dev/null +++ b/test/onnx/eyelike_double_test.onnx @@ -0,0 +1,11 @@ +eyelike_double_test:T + +T1T2"EyeLikeeyelike_double_testZ +T1 +   + +b +T2 +   + +B \ No newline at end of file diff --git a/test/onnx/eyelike_half_test.onnx b/test/onnx/eyelike_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3fe9b83ae4a6ca76408aa60141d86de2602aaf9c --- /dev/null +++ b/test/onnx/eyelike_half_test.onnx @@ -0,0 +1,13 @@ +eyelike_half_test:R + +T1T2"EyeLikeeyelike_half_testZ +T1 +  + + +b +T2 +  + + +B \ No newline at end of file diff --git a/test/onnx/eyelike_k_outofbounds_neg_test.onnx b/test/onnx/eyelike_k_outofbounds_neg_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..984487903c6795fc46e876fe291a54c328776c22 --- /dev/null +++ b/test/onnx/eyelike_k_outofbounds_neg_test.onnx @@ -0,0 +1,12 @@ +eyelike_k_outofbounds_neg_test:r +$ +T1T2"EyeLike* +kþÿÿÿÿÿÿÿÿ eyelike_k_outofbounds_neg_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_k_outofbounds_pos_test.onnx b/test/onnx/eyelike_k_outofbounds_pos_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4d887e1d88caa0ecc4aacb42d989ddc9a43f3cd3 --- /dev/null +++ b/test/onnx/eyelike_k_outofbounds_pos_test.onnx @@ -0,0 +1,12 @@ +eyelike_k_outofbounds_pos_test:i + +T1T2"EyeLike* +k eyelike_k_outofbounds_pos_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_k_test.onnx b/test/onnx/eyelike_k_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..16d2c0964762d94f82e77d15820114bc625ea431 --- /dev/null +++ b/test/onnx/eyelike_k_test.onnx @@ -0,0 +1,12 @@ +eyelike_k_test:Y + +T1T2"EyeLike* +k eyelike_k_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_not_rank2_test.onnx b/test/onnx/eyelike_not_rank2_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..512fb4858261e2f415a54806fd6dd7fa56eeb4fb --- /dev/null +++ b/test/onnx/eyelike_not_rank2_test.onnx @@ -0,0 +1,12 @@ +eyelike_not_rank2_test:[ + +T1T2"EyeLikeeyelike_not_rank2_testZ +T1 + + + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_set_dtype_test.onnx b/test/onnx/eyelike_set_dtype_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b9974ec585ecab1123ea1be220d3832eb0e35d04 --- /dev/null +++ b/test/onnx/eyelike_set_dtype_test.onnx @@ -0,0 +1,12 @@ +eyelike_set_dtype_test:e + +T1T2"EyeLike* +dtype  eyelike_set_dtype_testZ +T1 +  + +b +T2 +   + +B \ No newline at end of file diff --git a/test/onnx/eyelike_verify_negk_test.onnx b/test/onnx/eyelike_verify_negk_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..51c67ab408674be8ab1811033f1db9c3c491c973 --- /dev/null +++ b/test/onnx/eyelike_verify_negk_test.onnx @@ -0,0 +1,12 @@ +eyelike_verify_negk_test:l +$ +T1T2"EyeLike* +kþÿÿÿÿÿÿÿÿ eyelike_verify_negk_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/eyelike_verify_test.onnx b/test/onnx/eyelike_verify_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8e12120a8f9c8cb0a3ede1961989c8553103901c --- /dev/null +++ b/test/onnx/eyelike_verify_test.onnx @@ -0,0 +1,12 @@ +eyelike_verify_test:^ + +T1T2"EyeLike* +k eyelike_verify_testZ +T1 +  + +b +T2 +  + +B \ No newline at end of file diff --git a/test/onnx/flatten_nonstd_test.onnx b/test/onnx/flatten_nonstd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e729aed00bbff28c2ab78b906ebd6d7e1c032f8d Binary files /dev/null and b/test/onnx/flatten_nonstd_test.onnx differ diff --git a/test/onnx/gather_elements_axis0_test.onnx b/test/onnx/gather_elements_axis0_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e61d8d17795e713d4cb6284b826bb20b076b3410 Binary files /dev/null and b/test/onnx/gather_elements_axis0_test.onnx differ diff --git a/test/onnx/gather_elements_axis1_test.onnx b/test/onnx/gather_elements_axis1_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..de325d1a6807ba94d656f35c681361503560a80f --- /dev/null +++ b/test/onnx/gather_elements_axis1_test.onnx @@ -0,0 +1,17 @@ +gather_elements_axis1_test:• +/ +data +indicesy"GatherElements* +axis gather_elements_axis1_testZ +data +  + +Z +indices +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/gathernd_batch_dims_test.onnx b/test/onnx/gathernd_batch_dims_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1b488f13d062fa6908d2411270d13d9c2e8057bc --- /dev/null +++ b/test/onnx/gathernd_batch_dims_test.onnx @@ -0,0 +1,19 @@ +gathernd_batch_dims_test:— +/ +data +indicesy"GatherND* + +batch_dims gathernd_batch_dims_testZ +data + + + +Z +indices +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/gathernd_test.onnx b/test/onnx/gathernd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8d6afc78bee4eb43585e1eefc03902e628eb75d8 --- /dev/null +++ b/test/onnx/gathernd_test.onnx @@ -0,0 +1,16 @@ + gathernd_test:q + +data +indicesy"GatherND gathernd_testZ +data +  + +Z +indices +  + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/gemm_ex_test.onnx b/test/onnx/gemm_ex_test.onnx index c280b5fbad20381504dcce529fe50ccbed9e1996..d3f128563f7258af64cd71b6c2bca082e30105c8 100644 Binary files a/test/onnx/gemm_ex_test.onnx and b/test/onnx/gemm_ex_test.onnx differ diff --git a/test/onnx/gemm_half_test.onnx b/test/onnx/gemm_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..54cb1d9af78a0cd2d9d9050a9ccecddf3e5ef2eb Binary files /dev/null and b/test/onnx/gemm_half_test.onnx differ diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py old mode 100644 new mode 100755 index 2c519da2eee4361a77c8ac9aa7660dc7c7095f6e..454a034ec3217f3b0f430c60e43b9976c476e5cf --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -1,8 +1,10 @@ +# This script generates onnx files for MIGraphX onnx operator tests. +# To generate an individual onnx file, you can use the following +# command: python -c "import gen_onnx; gen_onnx.{test_name}_test()" import numpy as np import onnx from onnx import helper -from onnx import numpy_helper -from onnx import AttributeProto, TensorProto, GraphProto +from onnx import TensorProto def onnx_test(op_test): @@ -38,6 +40,20 @@ def acos_test(): return ([node], [x], [y]) +@onnx_test +def acosh_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Acosh', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + @onnx_test def add_bcast_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) @@ -78,14 +94,13 @@ def add_fp16_test(): @onnx_test def add_scalar_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, []) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) + x = helper.make_tensor_value_info('0', TensorProto.UINT8, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.UINT8, []) + z = helper.make_tensor_value_info('2', TensorProto.UINT8, [2, 3, 4, 5]) node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2']) - return ([node], [x, y], [z], - [helper.make_tensor('1', TensorProto.FLOAT, [], [1])]) + return ([node], [x, y], [z]) @onnx_test @@ -130,6 +145,20 @@ def asin_test(): return ([node], [x], [y]) +@onnx_test +def asinh_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Asinh', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + @onnx_test def atan_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) @@ -144,6 +173,160 @@ def atan_test(): return ([node], [x], [y]) +@onnx_test +def atanh_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Atanh', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def averagepool_1d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5]) + out = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3]) + + node = onnx.helper.make_node('AveragePool', + inputs=['0'], + outputs=['1'], + kernel_shape=[3]) + + return ([node], [x], [out]) + + +@onnx_test +def averagepool_3d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5, 5, 5]) + out = helper.make_tensor_value_info('1', TensorProto.FLOAT, + [1, 3, 3, 3, 3]) + + node = onnx.helper.make_node('AveragePool', + inputs=['0'], + outputs=['1'], + kernel_shape=[3, 3, 3]) + + return ([node], [x], [out]) + + +@onnx_test +def averagepool_notset_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1]) + + node = onnx.helper.make_node('AveragePool', + inputs=['x'], + outputs=['y'], + kernel_shape=[6, 6], + strides=[2, 2], + pads=[0, 0, 1, 1], + auto_pad='NOTSET') + + return ([node], [x], [y]) + + +@onnx_test +def averagepool_nt_cip_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1]) + + node = onnx.helper.make_node('AveragePool', + inputs=['x'], + outputs=['y'], + kernel_shape=[6, 6], + strides=[2, 2], + pads=[0, 0, 1, 1], + auto_pad='NOTSET', + count_include_pad=1) + + return ([node], [x], [y]) + + +@onnx_test +def averagepool_same_lower_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('AveragePool', + inputs=['x'], + outputs=['y'], + kernel_shape=[2, 2], + auto_pad='SAME_LOWER') + + return ([node], [x], [y]) + + +@onnx_test +def averagepool_sl_cip_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('AveragePool', + inputs=['x'], + outputs=['y'], + kernel_shape=[2, 2], + auto_pad='SAME_LOWER', + count_include_pad=1) + + return ([node], [x], [y]) + + +@onnx_test +def averagepool_same_upper_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('AveragePool', + inputs=['x'], + outputs=['y'], + kernel_shape=[2, 2], + auto_pad='SAME_UPPER') + + return ([node], [x], [y]) + + +@onnx_test +def batchnorm_1d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5]) + scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + mean = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + var = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3]) + out = helper.make_tensor_value_info('5', TensorProto.FLOAT, [1, 3, 5]) + + node = onnx.helper.make_node('BatchNormalization', + inputs=['0', '1', '2', '3', '4'], + outputs=['5'], + epsilon=1e-6, + momentum=0.9) + + return ([node], [x, scale, bias, mean, var], [out]) + + +@onnx_test +def batchnorm_3d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5, 5, 5]) + scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + mean = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + var = helper.make_tensor_value_info('4', TensorProto.FLOAT, [3]) + out = helper.make_tensor_value_info('5', TensorProto.FLOAT, + [1, 3, 5, 5, 5]) + + node = onnx.helper.make_node('BatchNormalization', + inputs=['0', '1', '2', '3', '4'], + outputs=['5'], + epsilon=1e-6, + momentum=0.9) + + return ([node], [x, scale, bias, mean, var], [out]) + + @onnx_test def cast_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [10]) @@ -168,6 +351,65 @@ def ceil_test(): return ([node], [x], [y]) +@onnx_test +def celu_alpha_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Celu', + inputs=['x'], + outputs=['y'], + alpha=0.8) + + return ([node], [x], [y]) + + +@onnx_test +def celu_default_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('Celu', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def celu_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('Celu', + inputs=['x'], + outputs=['y'], + alpha=0.5) + + return ([node], [x], [y]) + + +@onnx_test +def celu_wrong_type_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 3]) + + node = onnx.helper.make_node('Celu', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def celu_zero_alpha_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('Celu', + inputs=['x'], + outputs=['y'], + alpha=0.0) + + return ([node], [x], [y]) + + @onnx_test def clip_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) @@ -182,6 +424,83 @@ def clip_test(): return ([node], [x], [y]) +@onnx_test +def clip_test_op11(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0]) + max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [6.0]) + + node = onnx.helper.make_node('Clip', + inputs=['0', 'min', 'max'], + outputs=['1']) + + return ([node], [x], [y], [min_val, max_val]) + + +@onnx_test +def clip_test_op11_max_only(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [0.0]) + + node = onnx.helper.make_node('Clip', + inputs=['0', '', 'max'], + outputs=['1']) + + return ([node], [x], [y], [max_val]) + + +@onnx_test +def clip_test_op11_min_only(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0]) + + node = onnx.helper.make_node('Clip', inputs=['0', 'min'], outputs=['1']) + + return ([node], [x], [y], [min_val]) + + +@onnx_test +def clip_test_op11_no_args(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Clip', inputs=['0'], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def clip_test_op11_no_args1(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('Clip', inputs=['0', '', ''], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def clip_test_args_type_mismatch(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 3]) + + min_val = helper.make_tensor('min', TensorProto.FLOAT, [1, 3], + [1.5, 2.5, 3.5]) + max_val = helper.make_tensor('max', TensorProto.INT64, [3, 1], [2, 3, 4]) + + node = onnx.helper.make_node('Clip', + inputs=['0', 'min', 'max'], + outputs=['1']) + + return ([node], [x], [y], [min_val, max_val]) + + @onnx_test def concat_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3]) @@ -238,7 +557,6 @@ def constant_fill_test(): @onnx_test def constant_fill_input_as_shape_test(): np_shape = np.array([2, 3]) - shape = helper.make_tensor_value_info('shape', TensorProto.INT32, [2]) value = helper.make_tensor_value_info('value', TensorProto.FLOAT, [2, 3]) ts_shape = helper.make_tensor(name='shape_tensor', @@ -289,7 +607,6 @@ def constant_scalar_test(): def const_of_shape_empty_input_test(): tensor_val = onnx.helper.make_tensor('value', onnx.TensorProto.INT64, [1], [10]) - shape_val = np.array([2, 3, 4]).astype(np.int64) empty_val = np.array([]).astype(np.int64) empty_ts = helper.make_tensor(name='empty_tensor', data_type=TensorProto.INT32, @@ -389,6 +706,43 @@ def const_of_shape_no_value_attr_test(): return ([shape_const, node], [], [y]) +@onnx_test +def conv_1d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3]) + out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 3]) + + node = onnx.helper.make_node('Conv', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [out]) + + +@onnx_test +def conv_3d_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5, 5, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3, 3, 3]) + out = helper.make_tensor_value_info('2', TensorProto.FLOAT, + [1, 1, 3, 3, 3]) + + node = onnx.helper.make_node('Conv', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [out]) + + +@onnx_test +def conv_attr_fail_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3]) + out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 3]) + + node = onnx.helper.make_node('Conv', + inputs=['0', '1'], + strides=[1, 1], + outputs=['2']) + + return ([node], [x, y], [out]) + + @onnx_test def conv_autopad_fail_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32]) @@ -406,6 +760,22 @@ def conv_autopad_fail_test(): return ([node], [x, y], [out]) +@onnx_test +def conv_autopad_same_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3, 3]) + out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 32, 32]) + + node = onnx.helper.make_node('Conv', + inputs=['0', '1'], + outputs=['2'], + dilations=[1, 1], + strides=[1, 1], + auto_pad='SAME') + + return ([node], [x, y], [out]) + + @onnx_test def conv_bias_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32]) @@ -528,6 +898,22 @@ def conv_relu_maxpool_x2_test(): return ([node1, node2, node3, node4, node5, node6], [x, y, z, m, n], [out]) +@onnx_test +def convinteger_bias_test(): + x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32]) + y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5]) + z = helper.make_tensor_value_info('2', TensorProto.INT32, [1]) + out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28]) + + node = onnx.helper.make_node('ConvInteger', + inputs=['0', '1', '2'], + outputs=['3'], + dilations=[1, 1], + strides=[1, 1]) + + return ([node], [x, y, z], [out]) + + @onnx_test def cos_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) @@ -557,264 +943,263 @@ def cosh_test(): @onnx_test -def dropout_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 2, 2]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 2, 2]) - - node = onnx.helper.make_node( - 'Dropout', - inputs=['0'], - outputs=['1'], - ) +def deconv_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 1, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('ConvTranspose', + name='conv1', + inputs=['x', 'w'], + outputs=['y']) - return ([node], [x], [y]) + return ([node], [x, w], [y]) @onnx_test -def elu_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) +def deconv_bias_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 1, 3, 3]) + b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('ConvTranspose', + name='conv1', + inputs=['x', 'w', 'b'], + outputs=['y']) - node = onnx.helper.make_node('Elu', - inputs=['0'], - outputs=['1'], - alpha=0.01) - - return ([node], [x], [y]) + return ([node], [x, w, b], [y]) @onnx_test -def erf_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10, 15]) +def deconv_input_pads_strides_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 7, 5]) - node = onnx.helper.make_node( - 'Erf', - inputs=['x'], - outputs=['y'], - ) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2], + pads=[1, 1, 1, 1]) - return ([node], [x], [y]) + return ([node], [x, w], [y]) @onnx_test -def exp_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) +def deconv_input_pads_asymm_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 8, 6]) - node = onnx.helper.make_node( - 'Exp', - inputs=['x'], - outputs=['y'], - ) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2], + pads=[0, 0, 1, 1]) - return ([node], [x], [y]) + return ([node], [x, w], [y]) @onnx_test -def expand_test(): - shape_val = np.array([2, 3, 4, 5]).astype(np.int64) - shape_ts = helper.make_tensor(name='shape_tensor', - data_type=TensorProto.INT32, - dims=shape_val.shape, - vals=shape_val.flatten().astype(int)) - shape_const = helper.make_node( - 'Constant', - inputs=[], - outputs=['shape'], - value=shape_ts, - ) - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5]) +def deconv_input_pads_asymm_1d_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 6]) - node = onnx.helper.make_node('Expand', - inputs=['x', 'shape'], - outputs=['y']) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[2], + pads=[0, 1], + dilations=[1]) - return ([shape_const, node], [x], [y]) + return ([node], [x, w], [y]) @onnx_test -def flatten_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [6, 20]) - y2 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 60]) +def deconv_output_padding_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8]) - node = onnx.helper.make_node('Flatten', - inputs=['0'], - axis=2, - outputs=['2']) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2], + output_padding=[1, 1]) - node2 = onnx.helper.make_node('Flatten', inputs=['0'], outputs=['3']) + return ([node], [x, w], [y]) - return ([node, node2], [x], [y, y2]) + +@onnx_test +def deconv_output_padding_3d_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8, 8]) + + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2, 2], + output_padding=[1, 1, 1]) + + return ([node], [x, w], [y]) @onnx_test -def floor_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) +def deconv_output_shape_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8]) - node = onnx.helper.make_node( - 'Floor', - inputs=['x'], - outputs=['y'], - ) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2], + output_shape=[10, 8]) - return ([node], [x], [y]) + return ([node], [x, w], [y]) @onnx_test -def gather_test(): - x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) - i = helper.make_tensor_value_info('indices', TensorProto.INT32, - [2, 3, 4, 5]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5]) +def deconv_output_shape_3d_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 10, 8, 8]) - node = onnx.helper.make_node( - 'Gather', - inputs=['data', 'indices'], - outputs=['y'], - axis=1, - ) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2, 2], + output_shape=[10, 8, 8]) - return ([node], [x, i], [y]) + return ([node], [x, w], [y]) @onnx_test -def gemm_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [11, 5]) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, []) - a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [7, 11]) +def deconv_stride_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 3, 3]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 2, 3, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 7, 3]) - node = onnx.helper.make_node('Gemm', - inputs=['0', '1', '2'], - outputs=['3'], - alpha=2.0, - beta=2.0, - transA=1, - transB=1) + node = onnx.helper.make_node('ConvTranspose', + inputs=['x', 'w'], + outputs=['y'], + strides=[3, 2]) - return ([node], [x, y, z], [a]) + return ([node], [x, w], [y]) @onnx_test -def gemm_ex_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7]) - m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) +def depthtospace_test(): - node = onnx.helper.make_node('Gemm', - inputs=['1', '2', '3'], + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 8, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 10, 10]) + + node = onnx.helper.make_node('DepthToSpace', + inputs=['x'], outputs=['y'], - alpha=0.5, - beta=0.8, - transA=1) + blocksize=2, + mode='DCR') - return ([node], [m1, m2, m3], [y]) + return ([node], [x], [y]) @onnx_test -def gemm_ex_brcst_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7]) - m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 1]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) +def depthtospace_simple_test(): - node = onnx.helper.make_node('Gemm', - inputs=['1', '2', '3'], + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 8, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 4, 6]) + + node = onnx.helper.make_node('DepthToSpace', + inputs=['x'], outputs=['y'], - alpha=0.5, - beta=0.8, - transA=1) + blocksize=2, + mode='DCR') - return ([node], [m1, m2, m3], [y]) + return ([node], [x], [y]) @onnx_test -def globalavgpool_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1]) +def depthtospace_crd_test(): - node = onnx.helper.make_node( - 'GlobalAveragePool', - inputs=['0'], - outputs=['1'], - ) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 8, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 10, 10]) + + node = onnx.helper.make_node('DepthToSpace', + inputs=['x'], + outputs=['y'], + blocksize=2, + mode='CRD') return ([node], [x], [y]) @onnx_test -def globalmaxpool_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1]) +def spacetodepth_test(): - node = onnx.helper.make_node( - 'GlobalMaxPool', - inputs=['0'], - outputs=['1'], - ) + x = helper.make_tensor_value_info('x', TensorProto.float, [2, 2, 10, 10]) + y = helper.make_tensor_value_info('y', TensorProto.float, [2, 8, 5, 5]) + + node = onnx.helper.make_node('spacetodepth', + inputs=['x'], + outputs=['y'], + blocksize=2) return ([node], [x], [y]) @onnx_test -def group_conv_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 4, 16, 16]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 1, 3, 3]) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 4, 14, 14]) +def spacetodepth_simple_test(): - node = onnx.helper.make_node( - 'Conv', - inputs=['0', '1'], - group=4, - outputs=['2'], - ) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 4, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 8, 2, 3]) - return ([node], [x, y], [z]) + node = onnx.helper.make_node('SpaceToDepth', + inputs=['x'], + outputs=['y'], + blocksize=2) + + return ([node], [x], [y]) @onnx_test -def imagescaler_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 16, 16]) +def spacetodepth_invalid_blocksize_test(): - node = onnx.helper.make_node('ImageScaler', - inputs=['0'], - outputs=['1'], - bias=[0.01, 0.02, 0.03], - scale=0.5) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 4, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 8, 2, 3]) + + node = onnx.helper.make_node('SpaceToDepth', + inputs=['x'], + outputs=['y'], + blocksize=0.3) return ([node], [x], [y]) @onnx_test -def implicit_add_bcast_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4, 1]) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) +def spacetodepth_nondivisibility_test(): - node = onnx.helper.make_node( - 'Add', - inputs=['0', '1'], - outputs=['2'], - ) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 8, 2, 2]) - return ([node], [x, y], [z]) + node = onnx.helper.make_node('SpaceToDepth', + inputs=['x'], + outputs=['y'], + blocksize=2) + + return ([node], [x], [y]) @onnx_test -def implicit_pow_bcast_test(): - arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4, 1]) - arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, - [2, 3, 4, 5]) +def dequantizelinear_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [5]) node = onnx.helper.make_node( - 'Pow', + 'DequantizeLinear', inputs=['0', '1'], outputs=['out'], ) @@ -823,47 +1208,66 @@ def implicit_pow_bcast_test(): @onnx_test -def implicit_sub_bcast_test(): - arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5]) - arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, - [2, 3, 4, 5]) +def dequantizelinear_zero_point_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) + arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, [5]) node = onnx.helper.make_node( - 'Sub', - inputs=['0', '1'], + 'DequantizeLinear', + inputs=['0', '1', '2'], outputs=['out'], ) - return ([node], [arg0, arg1], [arg_out]) + return ([node], [arg0, arg1, arg2], [arg_out]) + + +def make_dequantizelinear_axis_graph(axis): + arg0 = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 1, 5, 1]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [5]) + arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, + [1, 1, 5, 1]) + + node = onnx.helper.make_node('DequantizeLinear', + inputs=['0', '1', '2'], + outputs=['out'], + axis=axis) + + return ([node], [arg0, arg1, arg2], [arg_out]) @onnx_test -def initializer_not_an_input(): - values = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) - w = helper.make_tensor(name='w', - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(np.float)) +def dequantizelinear_axis_test(): + return make_dequantizelinear_axis_graph(2) + + +@onnx_test +def dequantizelinear_neg_axis_test(): + return make_dequantizelinear_axis_graph(-2) - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5, 2]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 4]) + +@onnx_test +def dropout_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 2, 2]) node = onnx.helper.make_node( - 'Gemm', - inputs=['x', 'w'], - outputs=['y'], + 'Dropout', + inputs=['0'], + outputs=['1'], ) - return ([node], [x], [y], [w]) + return ([node], [x], [y]) @onnx_test -def leaky_relu_test(): +def elu_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) - node = onnx.helper.make_node('LeakyRelu', + node = onnx.helper.make_node('Elu', inputs=['0'], outputs=['1'], alpha=0.01) @@ -872,300 +1276,2879 @@ def leaky_relu_test(): @onnx_test -def log_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) +def embedding_bag_test(): - node = onnx.helper.make_node( - 'Log', - inputs=['x'], - outputs=['y'], - ) + index_val = np.array([1, 0, 2]) + offset_val = np.array([0]) - return ([node], [x], [y]) + index_tensor = helper.make_tensor(name='index_val', + data_type=TensorProto.INT32, + dims=index_val.shape, + vals=index_val.astype(np.int32)) + index = onnx.helper.make_node('Constant', + inputs=[], + outputs=['index'], + value=index_tensor) -@onnx_test -def logsoftmax_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + offset_tensor = helper.make_tensor(name='offset_val', + data_type=TensorProto.INT32, + dims=offset_val.reshape(()).shape, + vals=offset_val.astype(np.int32)) - node = onnx.helper.make_node('LogSoftmax', - inputs=['x'], - outputs=['y'], - axis=1) + offset = onnx.helper.make_node('Constant', + inputs=[], + outputs=['offset'], + value=offset_tensor) - return ([node], [x], [y]) + weight = helper.make_tensor_value_info('weight', TensorProto.FLOAT, [4, 2]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [1, 2]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [1, 2]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [1, 2]) -@onnx_test -def lrn_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 28, 24, 24]) + node1 = onnx.helper.make_node('ATen', + inputs=['weight', 'index', 'offset'], + outputs=['y1'], + mode=0, + operator='embedding_bag') - node = onnx.helper.make_node('LRN', - inputs=['0'], - size=5, - alpha=0.0001, - beta=0.75, - bias=1.0, - outputs=['1']) + node2 = onnx.helper.make_node('ATen', + inputs=['weight', 'index', 'offset'], + outputs=['y2'], + mode=1, + operator='embedding_bag') - return ([node], [x], [y]) + node3 = onnx.helper.make_node('ATen', + inputs=['weight', 'index', 'offset'], + outputs=['y3'], + mode=2, + operator='embedding_bag') + + return ([index, offset, node1, node2, node3], [weight], [y1, y2, y3]) @onnx_test -def matmul_bmbm_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 2, 1, 7, 8]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 2, 3, 6, 8]) +def embedding_bag_offset_test(): - node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], - outputs=['y'], - ) + index_val = np.array([1, 0]) + offset_val = np.array([0, 1]) - return ([node], [m1, m2], [y]) + index_tensor = helper.make_tensor(name='index_val', + data_type=TensorProto.INT32, + dims=index_val.shape, + vals=index_val.astype(np.int32)) + + index = onnx.helper.make_node('Constant', + inputs=[], + outputs=['index'], + value=index_tensor) + + offset_tensor = helper.make_tensor(name='offset_val', + data_type=TensorProto.INT32, + dims=offset_val.shape, + vals=offset_val.astype(np.int32)) + + offset = onnx.helper.make_node('Constant', + inputs=[], + outputs=['offset'], + value=offset_tensor) + + weight = helper.make_tensor_value_info('weight', TensorProto.FLOAT, [2, 3]) + + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('ATen', + inputs=['weight', 'index', 'offset'], + outputs=['y'], + mode=0, + operator='embedding_bag') + + return ([index, offset, node], [weight], [y]) @onnx_test -def matmul_bmv_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 6]) +def equal_test(): + ax1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + x1 = helper.make_tensor("x1", + data_type=TensorProto.FLOAT, + dims=(2, 3), + vals=ax1.astype(np.float32)) + + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], + 'Equal', + inputs=['x1', 'x2'], outputs=['y'], ) - return ([node], [m1, m2], [y]) + return ([node], [x2], [y], [x1]) @onnx_test -def matmul_mv_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [6, 7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [6]) +def equal_bool_test(): - node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.BOOL, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node1 = onnx.helper.make_node('Cast', inputs=['x1'], outputs=['bx1'], to=9) + + node2 = onnx.helper.make_node( + 'Equal', + inputs=['bx1', 'x2'], outputs=['y'], ) - return ([node], [m1, m2], [y]) + return ([node1, node2], [x1, x2], [y]) @onnx_test -def matmul_vbm_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 7, 8]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 8]) +def erf_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10, 15]) node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], + 'Erf', + inputs=['x'], outputs=['y'], ) - return ([node], [m1, m2], [y]) + return ([node], [x], [y]) @onnx_test -def matmul_vm_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, 8]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8]) +def exp_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], + 'Exp', + inputs=['x'], outputs=['y'], ) - return ([node], [m1, m2], [y]) + return ([node], [x], [y]) @onnx_test -def matmul_vv_test(): - m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) - m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) - - node = onnx.helper.make_node( - 'MatMul', - inputs=['1', '2'], - outputs=['y'], +def expand_test(): + shape_val = np.array([2, 3, 4, 5]).astype(np.int64) + shape_ts = helper.make_tensor(name='shape_tensor', + data_type=TensorProto.INT32, + dims=shape_val.shape, + vals=shape_val.flatten().astype(int)) + shape_const = helper.make_node( + 'Constant', + inputs=[], + outputs=['shape'], + value=shape_ts, ) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5]) - return ([node], [m1, m2], [y]) + node = onnx.helper.make_node('Expand', + inputs=['x', 'shape'], + outputs=['y']) + + return ([shape_const, node], [x], [y]) @onnx_test -def max_test(): - a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) - b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) - c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) - y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) +def eyelike_default_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) node = onnx.helper.make_node( - 'Max', - inputs=['0', '1', '2'], - outputs=['3'], + 'EyeLike', + inputs=['T1'], + outputs=['T2'], ) - - return ([node], [a, b, c], [y]) + return ([node], [T1], [T2]) @onnx_test -def min_test(): - a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) - b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) - c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) - y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) +def eyelike_double_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.DOUBLE, [6, 15]) + T2 = helper.make_tensor_value_info('T2', TensorProto.DOUBLE, [6, 15]) node = onnx.helper.make_node( - 'Min', - inputs=['0', '1', '2'], - outputs=['3'], + 'EyeLike', + inputs=['T1'], + outputs=['T2'], ) - - return ([node], [a, b, c], [y]) + return ([node], [T1], [T2]) @onnx_test -def no_pad_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 2]) - - node = onnx.helper.make_node('Pad', - inputs=['0'], - pads=[0, 0, 0, 0], - outputs=['1']) +def eyelike_half_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT16, [8, 8]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT16, [8, 8]) - return ([node], [x], [y]) + node = onnx.helper.make_node( + 'EyeLike', + inputs=['T1'], + outputs=['T2'], + ) + return ([node], [T1], [T2]) @onnx_test -def pad_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 4]) +def eyelike_k_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) + node = onnx.helper.make_node('EyeLike', inputs=['T1'], outputs=['T2'], k=1) + return ([node], [T1], [T2]) - node = onnx.helper.make_node('Pad', - inputs=['0'], - pads=[1, 1, 1, 1], - outputs=['1']) - return ([node], [x], [y]) +@onnx_test +def eyelike_k_outofbounds_neg_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [2, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [2, 4]) + node = onnx.helper.make_node('EyeLike', + inputs=['T1'], + outputs=['T2'], + k=-2) + return ([node], [T1], [T2]) @onnx_test -def pow_test(): - arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) - arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5]) - arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, - [2, 3, 4, 5]) +def eyelike_k_outofbounds_pos_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) + node = onnx.helper.make_node('EyeLike', inputs=['T1'], outputs=['T2'], k=4) + return ([node], [T1], [T2]) + +@onnx_test +def eyelike_not_rank2_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4, 2]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) node = onnx.helper.make_node( - 'Pow', - inputs=['0', '1'], - outputs=['out'], + 'EyeLike', + inputs=['T1'], + outputs=['T2'], ) - - return ([node], [arg0, arg1], [arg_out]) + return ([node], [T1], [T2]) @onnx_test -def reducemax_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) - axes = [2] - - node = onnx.helper.make_node('ReduceMax', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=0) - - return ([node], [x], [y]) +def eyelike_verify_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) + node = onnx.helper.make_node('EyeLike', inputs=['T1'], outputs=['T2'], k=1) + return ([node], [T1], [T2]) @onnx_test -def reducemean_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4]) - axes = [2, 3] +def eyelike_verify_negk_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.FLOAT, [3, 4]) + node = onnx.helper.make_node('EyeLike', + inputs=['T1'], + outputs=['T2'], + k=-2) + return ([node], [T1], [T2]) - node = onnx.helper.make_node('ReduceMean', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=0) - return ([node], [x], [y]) +@onnx_test +def eyelike_set_dtype_test(): + T1 = helper.make_tensor_value_info('T1', TensorProto.FLOAT, [3, 4]) + T2 = helper.make_tensor_value_info('T2', TensorProto.DOUBLE, [3, 4]) + node = onnx.helper.make_node('EyeLike', + inputs=['T1'], + outputs=['T2'], + dtype=TensorProto.DOUBLE) + return ([node], [T1], [T2]) @onnx_test -def reducemean_keepdims_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) - axes = [2] +def flatten_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [6, 20]) + y2 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 60]) - node = onnx.helper.make_node('ReduceMean', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=1) + node = onnx.helper.make_node('Flatten', + inputs=['0'], + axis=2, + outputs=['2']) - return ([node], [x], [y]) + node2 = onnx.helper.make_node('Flatten', inputs=['0'], outputs=['3']) + + return ([node, node2], [x], [y, y2]) @onnx_test -def reducemin_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 1, 5, 1]) - axes = [1, 3] +def flatten_nonstd_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 5, 4]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [6, 20]) + y2 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 60]) - node = onnx.helper.make_node('ReduceMin', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=1) + trans = helper.make_node( + 'Transpose', + inputs=['0'], + outputs=['tx'], + perm=[0, 1, 3, 2], + ) - return ([node], [x], [y]) + node = onnx.helper.make_node('Flatten', + inputs=['tx'], + axis=2, + outputs=['2']) + + node2 = onnx.helper.make_node('Flatten', inputs=['tx'], outputs=['3']) + + return ([trans, node, node2], [x], [y, y2]) @onnx_test -def reducesum_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) - axes = [2] +def floor_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) - node = onnx.helper.make_node('ReduceSum', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=1) + node = onnx.helper.make_node( + 'Floor', + inputs=['x'], + outputs=['y'], + ) return ([node], [x], [y]) @onnx_test -def reducesum_multiaxis_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) - axes = [2, 3] +def gather_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, + [2, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5]) - node = onnx.helper.make_node('ReduceSum', - inputs=['x'], - outputs=['y'], - axes=axes, - keepdims=0) + node = onnx.helper.make_node( + 'Gather', + inputs=['data', 'indices'], + outputs=['y'], + axis=1, + ) - return ([node], [x], [y]) + return ([node], [x, i], [y]) + + +@onnx_test +def gather_elements_axis0_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node( + 'GatherElements', + inputs=['data', 'indices'], + outputs=['y'], + axis=0, + ) + + return ([node], [x, i], [y]) + + +@onnx_test +def gather_elements_axis1_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node( + 'GatherElements', + inputs=['data', 'indices'], + outputs=['y'], + axis=1, + ) + + return ([node], [x, i], [y]) + + +@onnx_test +def gathernd_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2]) + i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2]) + + node = onnx.helper.make_node('GatherND', + inputs=['data', 'indices'], + outputs=['y']) + + return ([node], [x, i], [y]) + + +@onnx_test +def gathernd_batch_dims_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) + i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2]) + + node = onnx.helper.make_node( + 'GatherND', + inputs=['data', 'indices'], + outputs=['y'], + batch_dims=1, + ) + + return ([node], [x, i], [y]) + + +@onnx_test +def gemm_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [11, 5]) + z = helper.make_tensor_value_info('2', TensorProto.FLOAT, []) + a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [7, 11]) + + node = onnx.helper.make_node('Gemm', + inputs=['0', '1', '2'], + outputs=['3'], + alpha=2.0, + beta=2.0, + transA=1, + transB=1) + + return ([node], [x, y, z], [a]) + + +@onnx_test +def gemm_ex_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 8, 6]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 8, 7]) + m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) + + node = onnx.helper.make_node('Gemm', + inputs=['1', '2', '3'], + outputs=['y'], + alpha=0.5, + beta=0.8, + transA=1) + + return ([node], [m1, m2, m3], [y]) + + +@onnx_test +def gemm_ex_brcst_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7]) + m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) + + node = onnx.helper.make_node('Gemm', + inputs=['1', '2', '3'], + outputs=['y'], + alpha=0.5, + beta=0.8, + transA=1) + + return ([node], [m1, m2, m3], [y]) + + +@onnx_test +def gemm_half_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7]) + m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7]) + + node = onnx.helper.make_node('Gemm', + inputs=['1', '2', '3'], + outputs=['y'], + alpha=0.5, + beta=0.8, + transA=1) + + return ([node], [m1, m2, m3], [y]) + + +@onnx_test +def globalavgpool_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1]) + + node = onnx.helper.make_node( + 'GlobalAveragePool', + inputs=['0'], + outputs=['1'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def globallppool_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1]) + + node = onnx.helper.make_node( + 'GlobalLpPool', + inputs=['0'], + outputs=['1'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def globalmaxpool_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 1, 1]) + + node = onnx.helper.make_node( + 'GlobalMaxPool', + inputs=['0'], + outputs=['1'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def greater_test(): + ax1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + x1 = helper.make_tensor("x1", + data_type=TensorProto.FLOAT, + dims=(2, 3), + vals=ax1.astype(np.float32)) + + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node( + 'Greater', + inputs=['x1', 'x2'], + outputs=['y'], + ) + + return ([node], [x2], [y], [x1]) + + +@onnx_test +def greater_bool_test(): + + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.BOOL, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node1 = onnx.helper.make_node('Cast', inputs=['x1'], outputs=['bx1'], to=9) + + node2 = onnx.helper.make_node( + 'Greater', + inputs=['bx1', 'x2'], + outputs=['y'], + ) + + return ([node1, node2], [x1, x2], [y]) + + +@onnx_test +def greaterorequal_test(): + + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node( + 'GreaterOrEqual', + inputs=['x1', 'x2'], + outputs=['y'], + ) + + return ([node], [x1, x2], [y]) + + +@onnx_test +def group_conv_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 4, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 1, 3, 3]) + z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 4, 14, 14]) + + node = onnx.helper.make_node( + 'Conv', + inputs=['0', '1'], + group=4, + outputs=['2'], + ) + + return ([node], [x, y], [z]) + + +@onnx_test +def hardsigmoid_default_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 4, 5]) + + node = onnx.helper.make_node('HardSigmoid', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def hardsigmoid_double_test(): + x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.DOUBLE, [1, 3, 4, 5]) + + node = onnx.helper.make_node('HardSigmoid', + inputs=['x'], + outputs=['y'], + alpha=0.3, + beta=0.7) + + return ([node], [x], [y]) + + +@onnx_test +def hardsigmoid_half_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 3, 4, 5]) + + node = onnx.helper.make_node('HardSigmoid', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def hardsigmoid_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5]) + + node = onnx.helper.make_node('HardSigmoid', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def hardswish_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5]) + + node = onnx.helper.make_node('HardSwish', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def if_else_test(): + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) + + then_out = onnx.helper.make_tensor_value_info('then_out', + onnx.TensorProto.FLOAT, + [2, 3]) + else_out = onnx.helper.make_tensor_value_info('else_out', + onnx.TensorProto.FLOAT, + [2, 3]) + + xt = np.ones((2, 3)).astype(np.float) + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + yt = np.random.randn(2, 3).astype(np.float) + yt_tensor = helper.make_tensor(name='yt', + data_type=TensorProto.FLOAT, + dims=yt.shape, + vals=yt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['then_out']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'yt'], + outputs=['else_out']) + + then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], + [then_out]) + + else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], + [else_out]) + + cond = np.array([0]).astype(np.bool) + cond_tensor = helper.make_tensor(name="cond", + data_type=TensorProto.BOOL, + dims=cond.shape, + vals=cond.astype(bool)) + res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['res'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor]) + + +@onnx_test +def if_literal_test(): + then_out = onnx.helper.make_tensor_value_info('then_out', + onnx.TensorProto.FLOAT, [5]) + else_out = onnx.helper.make_tensor_value_info('else_out', + onnx.TensorProto.FLOAT, [5]) + + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) + y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + z = np.array([]).astype(np.float32) + + then_const_node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['then_out'], + value=onnx.numpy_helper.from_array(x)) + + else_const_node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['else_out'], + value=onnx.numpy_helper.from_array(y)) + + empty_const_node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['empty_out'], + value=onnx.numpy_helper.from_array(z)) + + then_body = onnx.helper.make_graph([then_const_node, empty_const_node], + 'then_body', [], [then_out]) + + else_body = onnx.helper.make_graph([else_const_node, empty_const_node], + 'else_body', [], [else_out]) + + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, []) + ret = onnx.helper.make_tensor_value_info('ret', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['ret'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [cond_input], [ret]) + + +@onnx_test +def if_param_excp_test(): + then_out = onnx.helper.make_tensor_value_info('then_out', + onnx.TensorProto.FLOAT, + [2, 3]) + else_out = onnx.helper.make_tensor_value_info('else_out', + onnx.TensorProto.FLOAT, + [2, 3]) + + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 4]) + + yt = np.random.randn(2, 4).astype(np.float) + xt = np.random.randn(2, 3).astype(np.float) + + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + yt_tensor = helper.make_tensor(name='yt', + data_type=TensorProto.FLOAT, + dims=yt.shape, + vals=yt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['then_out']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'yt'], + outputs=['else_out']) + + then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], + [then_out], [xt_tensor]) + + else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], + [else_out], [yt_tensor]) + + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, []) + ret = onnx.helper.make_tensor_value_info('ret', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['ret'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [cond_input, x, y], [ret]) + + +@onnx_test +def if_param_excp1_test(): + then_out = onnx.helper.make_tensor_value_info('sub_out', + onnx.TensorProto.FLOAT, + [2, 3]) + + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + + xt = np.random.randn(2, 3).astype(np.float) + + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['sub_out']) + + sub_body = onnx.helper.make_graph([then_add_node], 'sub_body', [], + [then_out], [xt_tensor]) + + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, [2]) + ret = onnx.helper.make_tensor_value_info('ret', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['ret'], + then_branch=sub_body, + else_branch=sub_body) + + return ([node], [cond_input, x], [ret]) + + +@onnx_test +def if_param_test(): + then_out = onnx.helper.make_tensor_value_info('then_out', + onnx.TensorProto.FLOAT, + [2, 3]) + else_out = onnx.helper.make_tensor_value_info('else_out', + onnx.TensorProto.FLOAT, + [2, 3]) + + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) + + yt = np.random.randn(2, 3).astype(np.float) + xt = np.random.randn(2, 3).astype(np.float) + + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + yt_tensor = helper.make_tensor(name='yt', + data_type=TensorProto.FLOAT, + dims=yt.shape, + vals=yt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['then_out']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'yt'], + outputs=['else_out']) + + then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], + [then_out], [xt_tensor]) + + else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], + [else_out], [yt_tensor]) + + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, []) + ret = onnx.helper.make_tensor_value_info('ret', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['ret'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [cond_input, x, y], [ret]) + + +@onnx_test +def if_pl_test(): + out_x = onnx.helper.make_tensor_value_info('out_x', onnx.TensorProto.FLOAT, + [2, 3]) + out_l_x = onnx.helper.make_tensor_value_info('out_l_x', + onnx.TensorProto.FLOAT, + [2, 3]) + out_y = onnx.helper.make_tensor_value_info('out_y', onnx.TensorProto.FLOAT, + [3, 3]) + out_l_y = onnx.helper.make_tensor_value_info('out_l_y', + onnx.TensorProto.FLOAT, + [3, 3]) + + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [3, 3]) + + xt = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + yt = np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]).astype(np.float32) + + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + yt_tensor = helper.make_tensor(name='yt', + data_type=TensorProto.FLOAT, + dims=yt.shape, + vals=yt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['out_x']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'yt'], + outputs=['out_y']) + + then_const_node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['out_l_y'], + value=onnx.numpy_helper.from_array(yt)) + + else_const_node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['out_l_x'], + value=onnx.numpy_helper.from_array(xt)) + + then_body = onnx.helper.make_graph([then_add_node, then_const_node], + 'then_body', [], [out_x, out_l_y]) + + else_body = onnx.helper.make_graph([else_mul_node, else_const_node], + 'else_body', [], [out_l_x, out_y]) + + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, []) + ret = onnx.helper.make_tensor_value_info('ret', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['ret'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [cond_input, x, y], [ret], [xt_tensor, yt_tensor]) + + +@onnx_test +def if_then_test(): + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) + + then_out = onnx.helper.make_tensor_value_info('then_out', + onnx.TensorProto.FLOAT, + [2, 3]) + else_out = onnx.helper.make_tensor_value_info('else_out', + onnx.TensorProto.FLOAT, + [2, 3]) + + xt = np.ones((2, 3)).astype(np.float) + xt_tensor = helper.make_tensor(name='xt', + data_type=TensorProto.FLOAT, + dims=xt.shape, + vals=xt.flatten().astype(np.float32)) + + yt = np.random.randn(2, 3).astype(np.float) + yt_tensor = helper.make_tensor(name='yt', + data_type=TensorProto.FLOAT, + dims=yt.shape, + vals=yt.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'xt'], + outputs=['then_out']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'yt'], + outputs=['else_out']) + + then_body = onnx.helper.make_graph([then_add_node], 'then_body', [], + [then_out]) + + else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [], + [else_out]) + + cond = np.array([1]).astype(np.bool) + cond_tensor = helper.make_tensor(name="cond", + data_type=TensorProto.BOOL, + dims=cond.shape, + vals=cond.astype(bool)) + res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['res'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor]) + + +@onnx_test +def if_tuple_test(): + x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1, 4]) + y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [3, 4]) + cond_input = onnx.helper.make_tensor_value_info('cond', + onnx.TensorProto.BOOL, []) + + then_out0 = onnx.helper.make_tensor_value_info('then_out0', + onnx.TensorProto.FLOAT, + [1, 4]) + then_out1 = onnx.helper.make_tensor_value_info('then_out1', + onnx.TensorProto.FLOAT, + [3, 4]) + else_out0 = onnx.helper.make_tensor_value_info('else_out0', + onnx.TensorProto.FLOAT, + [1, 4]) + else_out1 = onnx.helper.make_tensor_value_info('else_out1', + onnx.TensorProto.FLOAT, + [3, 4]) + + one = np.ones([1]).astype(np.float) + one_tensor = helper.make_tensor(name='one', + data_type=TensorProto.FLOAT, + dims=one.shape, + vals=one.flatten().astype(np.float32)) + + two = np.array([2]).astype(np.float) + two_tensor = helper.make_tensor(name='two', + data_type=TensorProto.FLOAT, + dims=two.shape, + vals=two.flatten().astype(np.float32)) + + three = np.array([3]).astype(np.float) + three_tensor = helper.make_tensor(name='three', + data_type=TensorProto.FLOAT, + dims=three.shape, + vals=three.flatten().astype(np.float32)) + + then_add_node = onnx.helper.make_node('Add', + inputs=['x', 'one'], + outputs=['then_out0']) + then_mul_node = onnx.helper.make_node('Mul', + inputs=['y', 'two'], + outputs=['then_out1']) + + else_mul_node = onnx.helper.make_node('Mul', + inputs=['x', 'three'], + outputs=['else_out0']) + else_add_node = onnx.helper.make_node('Add', + inputs=['y', 'three'], + outputs=['else_out1']) + + then_body = onnx.helper.make_graph([then_add_node, then_mul_node], + 'then_body', [], [then_out0, then_out1]) + + else_body = onnx.helper.make_graph([else_mul_node, else_add_node], + 'else_body', [], [else_out0, else_out1]) + + res0 = onnx.helper.make_tensor_value_info('res0', TensorProto.FLOAT, []) + res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('If', + inputs=['cond'], + outputs=['res0', 'res1'], + then_branch=then_body, + else_branch=else_body) + + return ([node], [cond_input, x, + y], [res0, res1], [one_tensor, two_tensor, three_tensor]) + + +@onnx_test +def imagescaler_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 16, 16]) + + node = onnx.helper.make_node('ImageScaler', + inputs=['0'], + outputs=['1'], + bias=[0.01, 0.02, 0.03], + scale=0.5) + + return ([node], [x], [y]) + + +@onnx_test +def imagescaler_half_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 3, 16, 16]) + + node = onnx.helper.make_node('ImageScaler', + inputs=['0'], + outputs=['1'], + bias=[0.01, 0.02, 0.03], + scale=0.5) + + return ([node], [x], [y]) + + +@onnx_test +def implicit_add_bcast_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4, 1]) + z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Add', + inputs=['0', '1'], + outputs=['2'], + ) + + return ([node], [x, y], [z]) + + +@onnx_test +def implicit_pow_bcast_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4, 1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Pow', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def implicit_sub_bcast_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.UINT64, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.UINT64, [4, 5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.UINT64, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Sub', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def initializer_not_an_input(): + values = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + w = helper.make_tensor(name='w', + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(np.float)) + + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 4]) + + node = onnx.helper.make_node( + 'Gemm', + inputs=['x', 'w'], + outputs=['y'], + ) + + return ([node], [x], [y], [w]) + + +@onnx_test +def instance_norm_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3, 3]) + scale = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2]) + bias = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2]) + y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 2, 3, 3]) + + node = onnx.helper.make_node('InstanceNormalization', + inputs=['0', '1', '2'], + outputs=['3']) + + return ([node], [x, scale, bias], [y]) + + +@onnx_test +def instance_norm_val_test(): + x = np.array([[[[0, 1, 2], [3, 4, 5], [6, 7, 8]], + [[0, 1, 2], [3, 4, 5], [6, 7, 8]]]]) + scale = np.array([1, 2]) + bias = np.array([0, 1]) + + x_tensor = helper.make_tensor(name='x_tensor', + data_type=TensorProto.FLOAT, + dims=x.shape, + vals=x.flatten().astype(np.float)) + scale_tensor = helper.make_tensor(name='scale_tensor', + data_type=TensorProto.FLOAT, + dims=scale.shape, + vals=scale.flatten().astype(np.float)) + bias_tensor = helper.make_tensor(name='bias_tensor', + data_type=TensorProto.FLOAT, + dims=bias.shape, + vals=bias.flatten().astype(np.float)) + + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 3, 3]) + + node = onnx.helper.make_node( + 'InstanceNormalization', + inputs=['x_tensor', 'scale_tensor', 'bias_tensor'], + outputs=['y']) + + return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor]) + + +@onnx_test +def instance_norm_val_3d_test(): + x = np.array([[[[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]]]) + scale = np.array([1, 2]) + bias = np.array([0, 1]) + + x_tensor = helper.make_tensor(name='x_tensor', + data_type=TensorProto.FLOAT, + dims=x.shape, + vals=x.flatten().astype(np.float)) + scale_tensor = helper.make_tensor(name='scale_tensor', + data_type=TensorProto.FLOAT, + dims=scale.shape, + vals=scale.flatten().astype(np.float)) + bias_tensor = helper.make_tensor(name='bias_tensor', + data_type=TensorProto.FLOAT, + dims=bias.shape, + vals=bias.flatten().astype(np.float)) + + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 2, 2, 2, 2]) + + node = onnx.helper.make_node( + 'InstanceNormalization', + inputs=['x_tensor', 'scale_tensor', 'bias_tensor'], + outputs=['y']) + + return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor]) + + +@onnx_test +def isnan_float_test(): + t1 = helper.make_tensor_value_info('t1', TensorProto.FLOAT, [2, 3]) + t2 = helper.make_tensor_value_info('t2', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node( + 'IsNaN', + inputs=['t1'], + outputs=['t2'], + ) + return ([node], [t1], [t2]) + + +@onnx_test +def isnan_half_test(): + t1 = helper.make_tensor_value_info('t1', TensorProto.FLOAT16, [2, 3]) + t2 = helper.make_tensor_value_info('t2', TensorProto.FLOAT16, [2, 3]) + + node = onnx.helper.make_node( + 'IsNaN', + inputs=['t1'], + outputs=['t2'], + ) + return ([node], [t1], [t2]) + + +@onnx_test +def layernorm_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 1, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [5]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [5]) + axes = [2] + pow_2 = np.array([[[2, 2, 2, 2, 2]]]) + epsilon = np.array([1e-12]) + + pow_tensor = helper.make_tensor(name='pow', + data_type=TensorProto.FLOAT, + dims=pow_2.shape, + vals=pow_2.flatten().astype(np.float)) + + epsilon_tensor = helper.make_tensor(name='epsilon', + data_type=TensorProto.FLOAT, + dims=epsilon.shape, + vals=epsilon.flatten().astype( + np.float)) + + mean = onnx.helper.make_node('ReduceMean', + inputs=['0'], + outputs=['mean_out'], + axes=axes) + + sub_mean = onnx.helper.make_node('Sub', + inputs=['0', 'mean_out'], + outputs=['sub_out']) + + sub_pow = onnx.helper.make_node('Pow', + inputs=['sub_out', 'pow'], + outputs=['pow_out']) + + var = onnx.helper.make_node('ReduceMean', + inputs=['pow_out'], + outputs=['var_out'], + axes=axes) + + add = onnx.helper.make_node('Add', + inputs=['var_out', 'epsilon'], + outputs=['add_out']) + + sqrt = onnx.helper.make_node('Sqrt', + inputs=['add_out'], + outputs=['sqrt_out']) + + div = onnx.helper.make_node('Div', + inputs=['sub_out', 'sqrt_out'], + outputs=['div_out']) + + mul = onnx.helper.make_node('Mul', + inputs=['scale', 'div_out'], + outputs=['mul_out']) + + bias_add = onnx.helper.make_node('Add', + inputs=['mul_out', 'bias'], + outputs=['1']) + + return ([mean, sub_mean, sub_pow, var, add, sqrt, div, mul, + bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor]) + + +@onnx_test +def leaky_relu_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node('LeakyRelu', + inputs=['0'], + outputs=['1'], + alpha=0.01) + + return ([node], [x], [y]) + + +@onnx_test +def less_test(): + ax1 = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + x1 = helper.make_tensor("x1", + data_type=TensorProto.FLOAT, + dims=(2, 3), + vals=ax1.astype(np.float32)) + + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node( + 'Less', + inputs=['x1', 'x2'], + outputs=['y'], + ) + + return ([node], [x2], [y], [x1]) + + +@onnx_test +def less_bool_test(): + + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [2, 3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.BOOL, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node1 = onnx.helper.make_node('Cast', inputs=['x1'], outputs=['bx1'], to=9) + + node2 = onnx.helper.make_node( + 'Less', + inputs=['bx1', 'x2'], + outputs=['y'], + ) + + return ([node1, node2], [x1, x2], [y]) + + +@onnx_test +def lessorequal_test(): + + x1 = helper.make_tensor_value_info('x1', TensorProto.FLOAT, [3]) + x2 = helper.make_tensor_value_info('x2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node( + 'LessOrEqual', + inputs=['x1', 'x2'], + outputs=['y'], + ) + + return ([node], [x1, x2], [y]) + + +@onnx_test +def log_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Log', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def logical_and_bcast_test(): + x = helper.make_tensor_value_info('0', TensorProto.BOOL, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.BOOL, [4, 5]) + z = helper.make_tensor_value_info('2', TensorProto.BOOL, [2, 3, 4, 5]) + + node = onnx.helper.make_node('And', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [z]) + + +@onnx_test +def logical_or_test(): + x = helper.make_tensor_value_info('0', TensorProto.BOOL, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.BOOL, [2, 3, 4, 5]) + z = helper.make_tensor_value_info('2', TensorProto.BOOL, [2, 3, 4, 5]) + + node = onnx.helper.make_node('Or', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [z]) + + +@onnx_test +def logical_xor_bcast_test(): + x = helper.make_tensor_value_info('0', TensorProto.BOOL, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.BOOL, [4, 1]) + z = helper.make_tensor_value_info('2', TensorProto.BOOL, [2, 3, 4, 5]) + + node = onnx.helper.make_node('Xor', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [z]) + + +@onnx_test +def logsoftmax_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + + node = onnx.helper.make_node('LogSoftmax', + inputs=['x'], + outputs=['y'], + axis=1) + + return ([node], [x], [y]) + + +@onnx_test +def logsoftmax_nonstd_input_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [6, 9]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 4]) + + node0 = onnx.helper.make_node('Slice', + inputs=['0'], + axes=[0, 1], + starts=[1, 0], + ends=[4, 4], + outputs=['1']) + + node1 = onnx.helper.make_node('LogSoftmax', + inputs=['1'], + outputs=['2'], + axis=-1) + + return ([node0, node1], [x], [y]) + + +@onnx_test +def loop_default_test(): + body = helper.make_graph([ + helper.make_node("Add", ["a", "b_in"], ["my_local"]), + helper.make_node("Sub", ["a", "b_in"], ["a_sub_b_in"]), + helper.make_node("Greater", ["my_local", "a_sub_b_in"], + ["keep_going"]), + helper.make_node("Add", ["a_sub_b_in", "a_sub_b_in"], + ["user_defined_vals"]), + ], "body", [ + helper.make_tensor_value_info('iteration_num', TensorProto.INT64, []), + helper.make_tensor_value_info('keep_going_inp', TensorProto.BOOL, []), + helper.make_tensor_value_info('b_in', TensorProto.FLOAT, []) + ], [ + helper.make_tensor_value_info('keep_going', TensorProto.BOOL, []), + helper.make_tensor_value_info('a_sub_b_in', TensorProto.FLOAT, []), + helper.make_tensor_value_info('my_local', TensorProto.FLOAT, []), + helper.make_tensor_value_info('user_defined_vals', TensorProto.FLOAT, + []), + ]) + + node = helper.make_node( + "Loop", + inputs=["", "", "b"], + outputs=["b_loop", "my_local_loop", "user_defined_vals_loop"], + body=body) + + a = helper.make_tensor_value_info('a', TensorProto.FLOAT, []) + b = helper.make_tensor_value_info('b', TensorProto.FLOAT, []) + + b_loop = helper.make_tensor_value_info('b_loop', TensorProto.FLOAT, []) + uout = helper.make_tensor_value_info('user_defined_vals_loop', + TensorProto.FLOAT, [2, 1]) + + return ([node], [a, b], [b_loop, uout]) + + +@onnx_test +def loop_test(): + body = helper.make_graph([ + helper.make_node("Add", ["a", "b_in"], ["my_local"]), + helper.make_node("Sub", ["a", "b_in"], ["a_sub_b_in"]), + helper.make_node("Greater", ["my_local", "a_sub_b_in"], + ["keep_going"]), + helper.make_node("Add", ["a_sub_b_in", "a_sub_b_in"], + ["user_defined_vals"]), + ], "body", [ + helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]), + helper.make_tensor_value_info('keep_going_inp', TensorProto.BOOL, [1]), + helper.make_tensor_value_info('b_in', TensorProto.FLOAT, [1]) + ], [ + helper.make_tensor_value_info('keep_going', TensorProto.BOOL, [1]), + helper.make_tensor_value_info('a_sub_b_in', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info('my_local', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info('user_defined_vals', TensorProto.FLOAT, + [1]), + ]) + + node = helper.make_node( + "Loop", + inputs=["max_trip_count", "keep_going_cond", "b"], + outputs=["b_loop", "my_local_loop", "user_defined_vals_loop"], + body=body) + + a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [1]) + b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1]) + cond = helper.make_tensor_value_info('keep_going_cond', TensorProto.BOOL, + [1]) + iter = helper.make_tensor_value_info('max_trip_count', TensorProto.INT64, + [1]) + + b_loop = helper.make_tensor_value_info('b_loop', TensorProto.FLOAT, [1]) + uout = helper.make_tensor_value_info('user_defined_vals_loop', + TensorProto.FLOAT, [2, 1]) + + return ([node], [iter, cond, a, b], [b_loop, uout]) + + +@onnx_test +def lpnormalization_axis_error_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('LpNormalization', + inputs=['x'], + outputs=['y'], + axis=2) + return ([node], [x], [y]) + + +@onnx_test +def lpnormalization_default_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4]) + + node = onnx.helper.make_node( + 'LpNormalization', + inputs=['x'], + outputs=['y'], + axis=0, + ) + return ([node], [x], [y]) + + +@onnx_test +def lpnormalization_l1_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4]) + + node = onnx.helper.make_node( + 'LpNormalization', + inputs=['x'], + outputs=['y'], + p=1, + ) + return ([node], [x], [y]) + + +@onnx_test +def lpnormalization_l2_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4]) + + node = onnx.helper.make_node('LpNormalization', + inputs=['x'], + outputs=['y'], + p=2) + return ([node], [x], [y]) + + +@onnx_test +def lpnormalization_p_error_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3]) + + node = onnx.helper.make_node('LpNormalization', + inputs=['x'], + outputs=['y'], + p=3) + return ([node], [x], [y]) + + +@onnx_test +def lppool_l1_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 3]) + + node = onnx.helper.make_node('LpPool', + inputs=['x'], + outputs=['y'], + kernel_shape=[3], + p=1) + return ([node], [x], [y]) + + +@onnx_test +def lppool_l2_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 3]) + + node = onnx.helper.make_node('LpPool', + inputs=['x'], + outputs=['y'], + kernel_shape=[3], + p=2) + return ([node], [x], [y]) + + +@onnx_test +def lrn_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 28, 24, 24]) + + node = onnx.helper.make_node('LRN', + inputs=['0'], + size=5, + alpha=0.0001, + beta=0.75, + bias=1.0, + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def matmul_bmbm_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 2, 1, 7, 8]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 2, 3, 6, 8]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmul_bmv_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 6]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmul_mv_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [6, 7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [6]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmul_vbm_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 7, 8]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, 8]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmul_vm_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, 8]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmul_vv_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7]) + m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]) + + node = onnx.helper.make_node( + 'MatMul', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def matmulinteger_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [3, 6, 16]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [3, 16, 8]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [3, 6, 8]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + +@onnx_test +def max_test(): + a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node( + 'Max', + inputs=['0', '1', '2'], + outputs=['3'], + ) + + return ([node], [a, b, c], [y]) + + +@onnx_test +def maxpool_notset_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 1, 1]) + + node = onnx.helper.make_node('MaxPool', + inputs=['x'], + outputs=['y'], + kernel_shape=[6, 6], + strides=[2, 2], + pads=[0, 0, 1, 1], + auto_pad='NOTSET') + + return ([node], [x], [y]) + + +@onnx_test +def maxpool_same_upper_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 5, 5]) + + node = onnx.helper.make_node('MaxPool', + inputs=['x'], + outputs=['y'], + kernel_shape=[2, 2], + auto_pad='SAME_UPPER') + + return ([node], [x], [y]) + + +@onnx_test +def mean_broadcast_test(): + data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4]) + data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, + [1, 2, 3, 4]) + data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [4]) + data_3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1]) + data_4 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2, 3, 1]) + + mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, + [1, 2, 3, 4]) + + node = onnx.helper.make_node("Mean", + inputs=["0", "1", "2", "3", "4"], + outputs=["mean"]) + + return ([node], [data_0, data_1, data_2, data_3, data_4], [mean]) + + +@onnx_test +def mean_fp16_test(): + data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 2, 3]) + data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 2, 3]) + data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 2, 3]) + + mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT16, + [1, 2, 3]) + + node = onnx.helper.make_node("Mean", + inputs=["0", "1", "2"], + outputs=["mean"]) + + return ([node], [data_0, data_1, data_2], [mean]) + + +@onnx_test +def mean_invalid_broadcast_test(): + data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3]) + data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2, 3]) + data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 2, 4]) + + mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3]) + + node = onnx.helper.make_node("Mean", + inputs=["0", "1", "2"], + outputs=["mean"]) + + return ([node], [data_0, data_1, data_2], [mean]) + + +@onnx_test +def mean_single_input_test(): + data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3]) + mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3]) + + node = onnx.helper.make_node("Mean", inputs=["0"], outputs=["mean"]) + + return ([node], [data_0], [mean]) + + +@onnx_test +def mean_test(): + data = [ + helper.make_tensor_value_info(str(i), TensorProto.DOUBLE, [2, 2, 2]) + for i in range(10) + ] + data_names = [str(i) for i in range(10)] + mean = helper.make_tensor_value_info('mean', TensorProto.DOUBLE, [2, 2, 2]) + + node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"]) + + return ([node], data, [mean]) + + +@onnx_test +def mean_integral_test(): + data = [ + helper.make_tensor_value_info(str(i), TensorProto.INT32, [2, 2, 2]) + for i in range(10) + ] + data_names = [str(i) for i in range(10)] + mean = helper.make_tensor_value_info('mean', TensorProto.INT32, [2, 2, 2]) + + node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"]) + + return ([node], data, [mean]) + + +@onnx_test +def min_test(): + a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) + b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node( + 'Min', + inputs=['0', '1', '2'], + outputs=['3'], + ) + + return ([node], [a, b, c], [y]) + + +@onnx_test +def multinomial_test(): + sample_size = 10 + seed = 0.0 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT32, + [1, 10]) + + node = onnx.helper.make_node('Multinomial', + inputs=['input'], + sample_size=sample_size, + seed=seed, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def multinomial_generated_seed_test(): + sample_size = 10 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT32, + [1, 10]) + + node = onnx.helper.make_node('Multinomial', + inputs=['input'], + sample_size=sample_size, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def multinomial_dtype_error_test(): + sample_size = 10 + dtype = 0 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT64, + [1, 10]) + + node = onnx.helper.make_node('Multinomial', + inputs=['input'], + sample_size=sample_size, + dtype=dtype, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def multinomial_int64_test(): + sample_size = 10 + dtype = 7 + seed = 1.0 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT64, + [1, 10]) + + node = onnx.helper.make_node('Multinomial', + inputs=['input'], + sample_size=sample_size, + dtype=dtype, + seed=seed, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def neg_test(): + x = helper.make_tensor_value_info('0', TensorProto.INT64, [2, 3]) + y = helper.make_tensor_value_info('1', TensorProto.INT64, [2, 3]) + + node = onnx.helper.make_node('Neg', inputs=['0'], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def nms_test(): + b = helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [1, 6, 4]) + s = helper.make_tensor_value_info('scores', TensorProto.FLOAT, [1, 1, 6]) + mo = helper.make_tensor_value_info('max_output_boxes_per_class', + TensorProto.INT64, [1]) + iou = helper.make_tensor_value_info('iou_threshold', TensorProto.FLOAT, + [1]) + st = helper.make_tensor_value_info('score_threshold', TensorProto.FLOAT, + [1]) + out = helper.make_tensor_value_info('selected_indices', TensorProto.INT64, + [6, 3]) + + node = onnx.helper.make_node('NonMaxSuppression', + inputs=[ + 'boxes', 'scores', + 'max_output_boxes_per_class', + 'iou_threshold', 'score_threshold' + ], + outputs=['selected_indices'], + center_point_box=1) + + return ([node], [b, s, mo, iou, st], [out]) + + +@onnx_test +def not_test(): + x = helper.make_tensor_value_info('0', TensorProto.INT32, [4]) + y = helper.make_tensor_value_info('1', TensorProto.INT32, [4]) + + node = onnx.helper.make_node('Not', inputs=['0'], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def not_bool_test(): + x = helper.make_tensor_value_info('0', TensorProto.BOOL, [4]) + y = helper.make_tensor_value_info('1', TensorProto.BOOL, [4]) + + node = onnx.helper.make_node('Not', inputs=['0'], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def no_pad_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 2]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[0, 0, 0, 0], + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def nonzero_dynamic_test(): + x = helper.make_tensor_value_info('data', TensorProto.BOOL, [2, 2]) + y = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 3]) + + node = onnx.helper.make_node('NonZero', + inputs=['data'], + outputs=['indices']) + + return ([node], [x], [y]) + + +@onnx_test +def nonzero_test(): + data1 = np.array([[1., 0.], [1., 1.]]) + data = helper.make_tensor(name='data', + data_type=TensorProto.FLOAT, + dims=data1.shape, + vals=data1.flatten().astype(np.float)) + y = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 3]) + + node = onnx.helper.make_node('NonZero', + inputs=['data'], + outputs=['indices']) + + return ([node], [], [y], [data]) + + +@onnx_test +def nonzero_int_test(): + data1 = np.array([[1, 1, 0], [1, 0, 1]]) + data = helper.make_tensor(name='data', + data_type=TensorProto.INT16, + dims=data1.shape, + vals=data1.flatten().astype(np.int16)) + y = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 4]) + + node = onnx.helper.make_node('NonZero', + inputs=['data'], + outputs=['indices']) + + return ([node], [], [y], [data]) + + +@onnx_test +def onehot_test(): + axis_value = 0 + depth = np.array([3]) + indices = helper.make_tensor_value_info("indices", TensorProto.INT32, + [5, 2]) + values = helper.make_tensor_value_info("values", TensorProto.FLOAT16, [2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [3, 5, 2]) + + depth_tensor = helper.make_tensor(name="depth", + data_type=TensorProto.INT32, + dims=None, + vals=depth.astype(int)) + + node = onnx.helper.make_node('OneHot', + inputs=['indices', 'depth', 'values'], + outputs=['y'], + axis=axis_value) + + return ([node], [indices, values], [y], [depth_tensor]) + + +@onnx_test +def pad_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[1, 1, 1, 1], + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def pad_3arg_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 1, 2, 2]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [5, 5]) + + node = onnx.helper.make_node('Pad', + inputs=['0', 'arg_pad', 'arg_val'], + outputs=['1']) + + return ([arg_val, arg_pad, node], [x], [y]) + + +@onnx_test +def pad_reflect_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5]) + + sizes = np.array([0, 2, 0, 1]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + node = onnx.helper.make_node('Pad', + mode='reflect', + inputs=['0', 'arg_pad'], + outputs=['1']) + + return ([arg_pad, node], [x], [y]) + + +@onnx_test +def pad_reflect_multiaxis_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5]) + + sizes = np.array([0, 2, 2, 0]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + node = onnx.helper.make_node('Pad', + mode='reflect', + inputs=['0', 'arg_pad'], + outputs=['1']) + + return ([arg_pad, node], [x], [y]) + + +@onnx_test +def pow_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Pow', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def pow_fp32_i64_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.INT64, [2, 3, 4, 5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Pow', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def pow_i64_fp32_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.INT64, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.INT64, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'Pow', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def prefix_scan_sum_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2]) + axis_val = np.array([0]) + axis_tensor = helper.make_tensor(name="axis", + data_type=TensorProto.INT32, + dims=axis_val.shape, + vals=axis_val.astype(int)) + node = onnx.helper.make_node('CumSum', + inputs=['x', 'axis'], + outputs=['y'], + exclusive=1, + reverse=1) + return ([node], [x], [y], [axis_tensor]) + + +@onnx_test +def prelu_brcst_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT, + [2, 3, 4, 5]) + + node = onnx.helper.make_node( + 'PRelu', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def quantizelinear_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5]) + + node = onnx.helper.make_node( + 'QuantizeLinear', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def quantizelinear_int32_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.INT32, [5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5]) + + node = onnx.helper.make_node( + 'QuantizeLinear', + inputs=['0', '1'], + outputs=['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + +@onnx_test +def quantizelinear_zero_point_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1]) + arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [1]) + arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, [5]) + + node = onnx.helper.make_node( + 'QuantizeLinear', + inputs=['0', '1', '2'], + outputs=['out'], + ) + + return ([node], [arg0, arg1, arg2], [arg_out]) + + +def make_quantizelinear_axis_graph(axis): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 1, 5, 1]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [5]) + arg2 = helper.make_tensor_value_info('2', TensorProto.INT8, [5]) + arg_out = helper.make_tensor_value_info('out', TensorProto.INT8, + [1, 1, 5, 1]) + + node = onnx.helper.make_node('QuantizeLinear', + inputs=['0', '1', '2'], + outputs=['out'], + axis=axis) + + return ([node], [arg0, arg1, arg2], [arg_out]) + + +@onnx_test +def quantizelinear_axis_test(): + return make_quantizelinear_axis_graph(2) + + +@onnx_test +def quantizelinear_neg_axis_test(): + return make_quantizelinear_axis_graph(-2) + + +@onnx_test +def randomnormal_test(): + dtype = 11 + mean = 10.0 + scale = 1.5 + seed = 0.0 + shape = [2, 3, 4] + output = helper.make_tensor_value_info('output', TensorProto.DOUBLE, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomNormal', + inputs=[], + outputs=['output'], + dtype=dtype, + mean=mean, + scale=scale, + seed=seed, + shape=shape) + + return ([node], [], [output]) + + +@onnx_test +def randomnormal_dtype_error_test(): + dtype = 6 + shape = [2, 3, 4] + output = helper.make_tensor_value_info('output', TensorProto.INT32, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomNormal', + inputs=[], + outputs=['output'], + dtype=dtype, + shape=shape) + + return ([node], [], [output]) + + +@onnx_test +def randomnormal_generated_seed_test(): + sample_size = 10 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT32, + [1, 10]) + + node = onnx.helper.make_node('RandomNormal', + inputs=['input'], + sample_size=sample_size, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def randomnormal_shape_error_test(): + dtype = 1 + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomNormal', + inputs=[], + outputs=['output'], + dtype=dtype) + + return ([node], [], [output]) + + +@onnx_test +def randomnormallike_test(): + dtype = 10 + mean = 10.0 + scale = 1.5 + seed = 0.0 + input = helper.make_tensor_value_info('input', TensorProto.FLOAT16, + [2, 3, 4]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT16, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomNormalLike', + inputs=['input'], + outputs=['output'], + dtype=dtype, + mean=mean, + scale=scale, + seed=seed) + + return ([node], [input], [output]) + + +@onnx_test +def randomnormallike_type_error_test(): + seed = 0 + input = helper.make_tensor_value_info('input', TensorProto.INT32, + [2, 3, 4]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomNormalLike', + inputs=['input'], + outputs=['output'], + seed=seed) + + return ([node], [input], [output]) + + +@onnx_test +def randomuniform_test(): + dtype = 11 + high = 1.0 + low = 0.0 + seed = 0.0 + shape = [2, 3, 4] + output = helper.make_tensor_value_info('output', TensorProto.DOUBLE, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomUniform', + inputs=[], + outputs=['output'], + dtype=dtype, + high=high, + low=low, + seed=seed, + shape=shape) + + return ([node], [], [output]) + + +@onnx_test +def randomuniform_dtype_error_test(): + dtype = 6 + shape = [2, 3, 4] + output = helper.make_tensor_value_info('output', TensorProto.INT32, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomUniform', + inputs=[], + outputs=['output'], + dtype=dtype, + shape=shape) + + return ([node], [], [output]) + + +@onnx_test +def randomuniform_generated_seed_test(): + sample_size = 10 + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10]) + output = helper.make_tensor_value_info("output", TensorProto.INT32, + [1, 10]) + + node = onnx.helper.make_node('RandomUniform', + inputs=['input'], + sample_size=sample_size, + outputs=['output']) + + return ([node], [input], [output]) + + +@onnx_test +def randomuniform_shape_error_test(): + dtype = 1 + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomUniform', + inputs=[], + outputs=['output'], + dtype=dtype) + + return ([node], [], [output]) + + +@onnx_test +def randomuniformlike_test(): + dtype = 10 + high = 10.0 + low = 1.0 + seed = 0.0 + input = helper.make_tensor_value_info('input', TensorProto.FLOAT16, + [2, 3, 4]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT16, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomUniformLike', + inputs=['input'], + outputs=['output'], + dtype=dtype, + high=high, + low=low, + seed=seed) + + return ([node], [input], [output]) + + +@onnx_test +def randomuniformlike_type_error_test(): + seed = 0 + input = helper.make_tensor_value_info('input', TensorProto.INT32, + [2, 3, 4]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 3, 4]) + + node = onnx.helper.make_node('RandomUniformLike', + inputs=['input'], + outputs=['output'], + seed=seed) + + return ([node], [input], [output]) + + +@onnx_test +def range_test(): + + start_val = np.array([10]) + limit_val = np.array([6]) + delta_val = np.array([-3]) + + start_tensor = helper.make_tensor(name='start_val', + data_type=TensorProto.INT64, + dims=start_val.reshape(()).shape, + vals=start_val.astype(np.int64)) + start = onnx.helper.make_node('Constant', + inputs=[], + outputs=['start'], + value=start_tensor) + + limit_tensor = helper.make_tensor(name='limit_val', + data_type=TensorProto.INT64, + dims=limit_val.reshape(()).shape, + vals=limit_val.astype(np.int64)) + limit = onnx.helper.make_node('Constant', + inputs=[], + outputs=['limit'], + value=limit_tensor) + + delta_tensor = helper.make_tensor(name='delta_val', + data_type=TensorProto.INT64, + dims=delta_val.reshape(()).shape, + vals=delta_val.astype(np.int64)) + delta = onnx.helper.make_node('Constant', + inputs=[], + outputs=['delta'], + value=delta_tensor) + + node = onnx.helper.make_node('Range', + inputs=['start', 'limit', 'delta'], + outputs=['1']) + + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + return ([start, limit, delta, node], [], [y]) + + +@onnx_test +def range_float_test(): + + start_val = np.array([2]) + limit_val = np.array([11]) + delta_val = np.array([2]) + + start_tensor = helper.make_tensor(name='start_val', + data_type=TensorProto.FLOAT, + dims=start_val.reshape(()).shape, + vals=start_val.astype(np.float)) + start = onnx.helper.make_node('Constant', + inputs=[], + outputs=['start'], + value=start_tensor) + + limit_tensor = helper.make_tensor(name='limit_val', + data_type=TensorProto.FLOAT, + dims=limit_val.reshape(()).shape, + vals=limit_val.astype(np.float)) + limit = onnx.helper.make_node('Constant', + inputs=[], + outputs=['limit'], + value=limit_tensor) + + delta_tensor = helper.make_tensor(name='delta_val', + data_type=TensorProto.FLOAT, + dims=delta_val.reshape(()).shape, + vals=delta_val.astype(np.float)) + delta = onnx.helper.make_node('Constant', + inputs=[], + outputs=['delta'], + value=delta_tensor) + + node = onnx.helper.make_node('Range', + inputs=['start', 'limit', 'delta'], + outputs=['1']) + + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) + + return ([start, limit, delta, node], [], [y]) + + +@onnx_test +def recip_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + + node = onnx.helper.make_node( + 'Reciprocal', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def reducel1_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6]) + axes = [-2] + + node = onnx.helper.make_node('ReduceL1', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reducel2_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5]) + axes = [-1] + + node = onnx.helper.make_node('ReduceL2', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reduce_log_sum_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 1, 5, 6]) + axes = [-3] + + node = onnx.helper.make_node('ReduceLogSum', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reduce_log_sum_exp_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 5, 6]) + axes = [-4] + + node = onnx.helper.make_node('ReduceLogSumExp', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reducemax_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6]) + axes = [2] + + node = onnx.helper.make_node('ReduceMax', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reducemean_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4]) + axes = [2, 3] + + node = onnx.helper.make_node('ReduceMean', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reducemean_keepdims_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + axes = [2] + + node = onnx.helper.make_node('ReduceMean', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reducemin_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 1, 5, 1]) + axes = [1, 3] + + node = onnx.helper.make_node('ReduceMin', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reduceprod_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + axes = [2] + + node = onnx.helper.make_node('ReduceProd', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reducesum_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + axes = [2] + + node = onnx.helper.make_node('ReduceSum', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reducesum_empty_axes_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + axes = np.array([], dtype=np.int64) + axes_tensor = helper.make_tensor(name="axes", + data_type=TensorProto.INT64, + dims=axes.shape, + vals=axes.astype(np.int64)) + + node = onnx.helper.make_node('ReduceSum', + inputs=['x', 'axes'], + outputs=['y'], + keepdims=0, + noop_with_empty_axes=False) + + return ([node], [x], [y], [axes_tensor]) + + +@onnx_test +def reducesum_noop_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + axes = np.array([], dtype=np.int64) + axes_tensor = helper.make_tensor(name="axes", + data_type=TensorProto.INT64, + dims=axes.shape, + vals=axes.astype(np.int64)) + + node = onnx.helper.make_node('ReduceSum', + inputs=['x', 'axes'], + outputs=['y'], + keepdims=0, + noop_with_empty_axes=True) + + return ([node], [x], [y], [axes_tensor]) @onnx_test @@ -1174,171 +4157,1046 @@ def reducesum_keepdims_test(): y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) axes = [2, 3] - node = onnx.helper.make_node('ReduceSum', + node = onnx.helper.make_node('ReduceSum', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=1) + + return ([node], [x], [y]) + + +@onnx_test +def reducesum_multiaxis_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 1]) + axes = [2, 3] + + node = onnx.helper.make_node('ReduceSum', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reducesum_square_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6]) + axes = [-2] + + node = onnx.helper.make_node('ReduceSumSquare', + inputs=['x'], + outputs=['y'], + axes=axes, + keepdims=0) + + return ([node], [x], [y]) + + +@onnx_test +def reshape_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3]) + x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2]) + x_shape_list = [3, 8] + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8]) + y2 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3, 8]) + + node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2']) + + node2 = onnx.helper.make_node('Reshape', + inputs=['0'], + shape=x_shape_list, + outputs=['3']) + + return ([node, node2], [x, x_shape], [y, y2], + [helper.make_tensor('1', TensorProto.INT64, [2], [3, 8])]) + + +@onnx_test +def reshape_non_standard_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 3, 2]) + + trans = helper.make_node( + 'Transpose', + inputs=['x'], + outputs=['trans_x'], + perm=[0, 2, 1], + ) + + res = onnx.helper.make_node('Reshape', + inputs=['trans_x'], + outputs=['y'], + shape=[4, 3, 2]) + + return ([trans, res], [x], [y]) + + +@onnx_test +def resize_downsample_f_test(): + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) + + node = onnx.helper.make_node( + 'Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + coordinate_transformation_mode='align_corners', + mode='nearest', + nearest_mode='floor') + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def resize_downsample_c_test(): + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 1, 2]) + + node = onnx.helper.make_node('Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + coordinate_transformation_mode='asymmetric', + mode='nearest', + nearest_mode='ceil') + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def resize_downsample_linear_test(): + scales = np.array([1.0, 1.0, 0.6, 0.5], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='linear') + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def resize_nonstd_input_test(): + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 4, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 1, 2]) + + trn = onnx.helper.make_node('Transpose', + inputs=['X'], + outputs=['TX'], + perm=[0, 1, 3, 2]) + + node = onnx.helper.make_node('Resize', + inputs=['TX', '', 'scales'], + outputs=['Y'], + coordinate_transformation_mode='asymmetric', + mode='nearest', + nearest_mode='ceil') + + return ([trn, node], [X], [Y], [scale_tensor]) + + +@onnx_test +def resize_outsize_test(): + out_lens = np.array([1, 1, 4, 6], dtype=np.int64) + out_lens_tensor = helper.make_tensor(name='out_lens', + data_type=TensorProto.INT64, + dims=out_lens.shape, + vals=out_lens.flatten().astype( + np.int64)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6]) + + node = onnx.helper.make_node( + 'Resize', + inputs=['X', '', '', 'out_lens'], + outputs=['Y'], + coordinate_transformation_mode='tf_half_pixel_for_nn', + mode='nearest', + nearest_mode='round_prefer_floor') + + return ([node], [X], [Y], [out_lens_tensor]) + + +@onnx_test +def resize_upsample_linear_ac_test(): + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + scales_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype( + np.float32)) + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) + + node = onnx.helper.make_node( + 'Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='linear', + coordinate_transformation_mode='align_corners') + + return ([node], [X], [Y], [scales_tensor]) + + +@onnx_test +def resize_upsample_linear_test(): + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + scales_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype( + np.float32)) + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='linear') + + return ([node], [X], [Y], [scales_tensor]) + + +@onnx_test +def resize_upsample_pf_test(): + scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6]) + + node = onnx.helper.make_node('Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='nearest') + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def resize_upsample_pc_test(): + scales = np.array([1.0, 1.0, 2.0, 1.5], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 4]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6]) + + node = onnx.helper.make_node( + 'Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + coordinate_transformation_mode='pytorch_half_pixel', + mode='nearest', + exclude_outside=0, + nearest_mode='round_prefer_ceil') + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def reversesequence_4D_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2, 2]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=1, + sequence_lens=[2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_batch_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + seq_lens = np.array([1, 2, 3, 4]) + seq_lens_tensor = helper.make_tensor( + name="sequence_lens", + data_type=TensorProto.INT64, + dims=seq_lens.shape, + vals=seq_lens.astype(np.int64), + ) + arg_seq_lens = helper.make_node( + "Constant", + inputs=[], + outputs=['arg_seq_lens'], + value=seq_lens_tensor, + ) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x', 'arg_seq_lens'], + outputs=['y'], + time_axis=1, + batch_axis=0, + ) + return ([arg_seq_lens, node], [x], [y]) + + +@onnx_test +def reversesequence_batch_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=2, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_rank_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_sequence_lens_shape_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + sequence_lens=[4, 3, 2], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_same_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=1, + batch_axis=1, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_time_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2, 3]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=3, + batch_axis=0, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_time_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=1, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def roialign_default_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) + roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [8, 4]) + bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [8]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8, 4, 1, 1]) + + node = onnx.helper.make_node('RoiAlign', + inputs=['x', 'rois', 'batch_ind'], + outputs=['y']) + + return ([node], [x, roi, bi], [y]) + + +@onnx_test +def roialign_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 5, 4, 7]) + roi = helper.make_tensor_value_info('rois', TensorProto.FLOAT, [8, 4]) + bi = helper.make_tensor_value_info('batch_ind', TensorProto.INT64, [8]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [8, 4, 5, 5]) + + node = onnx.helper.make_node( + 'RoiAlign', + inputs=['x', 'rois', 'batch_ind'], + outputs=['y'], + spatial_scale=2.0, + output_height=5, + output_width=5, + sampling_ratio=3, + mode="avg", + coordinate_transformation_mode="output_half_pixel") + + return ([node], [x, roi, bi], [y]) + + +@onnx_test +def scatter_add_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, + [2, 3, 4, 5]) + u = helper.make_tensor_value_info('update', TensorProto.FLOAT, + [2, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + + node = onnx.helper.make_node( + 'ScatterElements', + reduction='add', + inputs=['data', 'indices', 'update'], + outputs=['y'], + axis=-2, + ) + + return ([node], [x, i, u], [y]) + + +@onnx_test +def scatter_mul_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, + [2, 3, 4, 5]) + u = helper.make_tensor_value_info('update', TensorProto.FLOAT, + [2, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + + node = onnx.helper.make_node( + 'ScatterElements', + reduction='mul', + inputs=['data', 'indices', 'update'], + outputs=['y'], + axis=-2, + ) + + return ([node], [x, i, u], [y]) + + +@onnx_test +def scatter_none_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) + i = helper.make_tensor_value_info('indices', TensorProto.INT32, + [2, 3, 4, 5]) + u = helper.make_tensor_value_info('update', TensorProto.FLOAT, + [2, 3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) + + node = onnx.helper.make_node( + 'ScatterElements', + reduction='none', + inputs=['data', 'indices', 'update'], + outputs=['y'], + axis=-2, + ) + + return ([node], [x, i, u], [y]) + + +@onnx_test +def scatternd_add_test(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) + indices = helper.make_tensor_value_info('indices', TensorProto.INT64, + [2, 1, 2]) + updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, + [2, 1, 2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 2, 2]) + + node = onnx.helper.make_node('ScatterND', + inputs=['data', 'indices', 'updates'], + outputs=['output'], + reduction="add") + + return ([node], [data, indices, updates], [output]) + + +@onnx_test +def scatternd_mul_test(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) + indices = helper.make_tensor_value_info('indices', TensorProto.INT64, + [2, 1, 2]) + updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, + [2, 1, 2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 2, 2]) + + node = onnx.helper.make_node('ScatterND', + inputs=['data', 'indices', 'updates'], + outputs=['output'], + reduction="mul") + + return ([node], [data, indices, updates], [output]) + + +@onnx_test +def scatternd_test(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2]) + indices = helper.make_tensor_value_info('indices', TensorProto.INT64, + [2, 1, 2]) + updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, + [2, 1, 2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [2, 2, 2]) + + node = onnx.helper.make_node('ScatterND', + inputs=['data', 'indices', 'updates'], + outputs=['output']) + + return ([node], [data, indices, updates], [output]) + + +@onnx_test +def selu_test(): + x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.DOUBLE, [2, 3]) + + node = onnx.helper.make_node('Selu', inputs=['x'], outputs=['y'], - axes=axes, - keepdims=1) + alpha=0.3, + gamma=0.5) return ([node], [x], [y]) @onnx_test -def reshape_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3]) - x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2]) - x_shape_list = [3, 8] - y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8]) - y2 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3, 8]) +def shape_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [4]) - node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2']) + node = onnx.helper.make_node( + 'Shape', + inputs=['x'], + outputs=['y'], + ) - node2 = onnx.helper.make_node('Reshape', - inputs=['0'], - shape=x_shape_list, - outputs=['3']) + return ([node], [x], [y]) - return ([node, node2], [x, x_shape], [y, y2], - [helper.make_tensor('1', TensorProto.INT64, [2], [3, 8])]) + +@onnx_test +def shape_gather_test(): + values = np.array([1]) + # value = helper.make_tensor_value_info('value', TensorProto.INT32, [1]) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [7, 3, 10]) + z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [1]) + + value_tensor = helper.make_tensor(name='const_tensor', + data_type=TensorProto.INT32, + dims=values.shape, + vals=values.flatten().astype(int)) + + node_const = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['value'], + value=value_tensor, + ) + + node_shape = onnx.helper.make_node( + 'Shape', + inputs=['x'], + outputs=['y'], + ) + + node_gather = helper.make_node( + 'Gather', + inputs=['y', 'value'], + outputs=['z'], + axis=0, + ) + + return ([node_const, node_shape, node_gather], [x], [z]) @onnx_test -def reshape_non_standard_test(): +def sign_test(): + x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5]) + y = helper.make_tensor_value_info('y', TensorProto.DOUBLE, [10, 5]) + + node = onnx.helper.make_node( + 'Sign', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def sin_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Sin', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def sinh_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) + + node = onnx.helper.make_node( + 'Sinh', + inputs=['x'], + outputs=['y'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def size_float_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 4]) - trans_x = helper.make_tensor_value_info('trans_x', TensorProto.FLOAT, - [2, 4, 3]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [1]) + node = onnx.helper.make_node( + 'Size', + inputs=['x'], + outputs=['y'], + ) + return ([node], [x], [y]) - trans = helper.make_node( - 'Transpose', + +@onnx_test +def size_half_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [3, 1]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [1]) + node = onnx.helper.make_node( + 'Size', inputs=['x'], - outputs=['trans_x'], - perm=[0, 2, 1], + outputs=['y'], ) + return ([node], [x], [y]) - res = onnx.helper.make_node('Reshape', - inputs=['trans_x'], - outputs=['y'], - shape=[4, 3, 2]) - return ([trans, res], [x], [y]) +@onnx_test +def size_int_test(): + x = helper.make_tensor_value_info('x', TensorProto.INT32, [8, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [1]) + node = onnx.helper.make_node( + 'Size', + inputs=['x'], + outputs=['y'], + ) + return ([node], [x], [y]) + + +@onnx_test +def size_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5, 3]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [1]) + node = onnx.helper.make_node( + 'Size', + inputs=['x'], + outputs=['y'], + ) + return ([node], [x], [y]) + + +@onnx_test +def slice_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('Slice', + inputs=['0'], + axes=[0, 1], + starts=[1, 0], + ends=[2, 2], + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def slice_3arg_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5]) + start = np.array([0, 0]) + start_tensor = helper.make_tensor(name="start", + data_type=TensorProto.INT32, + dims=start.shape, + vals=start.astype(int)) + + arg_start = helper.make_node("Constant", + inputs=[], + outputs=['arg_start'], + value=start_tensor) + + end = np.array([2, 5]) + end_tensor = helper.make_tensor(name="end", + data_type=TensorProto.INT32, + dims=end.shape, + vals=end.astype(int)) + arg_end = helper.make_node("Constant", + inputs=[], + outputs=['arg_end'], + value=end_tensor) + + node = onnx.helper.make_node('Slice', + inputs=['0', 'arg_start', 'arg_end'], + outputs=['1']) + + return ([arg_start, arg_end, node], [x], [y]) + + +@onnx_test +def slice_5arg_test(): + step = np.array([1, 1]) + step_tensor = helper.make_tensor(name="step", + data_type=TensorProto.INT32, + dims=step.shape, + vals=step.astype(int)) + arg_step = helper.make_node("Constant", + inputs=[], + outputs=['arg_step'], + value=step_tensor) + + axis = np.array([-1, -2]) + axis_tensor = helper.make_tensor(name="axis", + data_type=TensorProto.INT32, + dims=axis.shape, + vals=axis.astype(int)) + arg_axis = helper.make_node("Constant", + inputs=[], + outputs=['arg_axis'], + value=axis_tensor) + + end = np.array([-1, -1]) + end_tensor = helper.make_tensor(name="end", + data_type=TensorProto.INT32, + dims=end.shape, + vals=end.astype(int)) + arg_end = helper.make_node("Constant", + inputs=[], + outputs=['arg_end'], + value=end_tensor) + + start = np.array([-5, -3]) + start_tensor = helper.make_tensor(name="start", + data_type=TensorProto.INT32, + dims=start.shape, + vals=start.astype(int)) + arg_start = helper.make_node("Constant", + inputs=[], + outputs=['arg_start'], + value=start_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2]) + + node = onnx.helper.make_node( + 'Slice', + inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'], + outputs=['1']) + + return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y]) + + +@onnx_test +def slice_5arg_reverse_test(): + step = np.array([-1, 1]) + step_tensor = helper.make_tensor(name="step", + data_type=TensorProto.INT32, + dims=step.shape, + vals=step.astype(int)) + arg_step = helper.make_node("Constant", + inputs=[], + outputs=['arg_step'], + value=step_tensor) + + axis = np.array([-1, -2]) + axis_tensor = helper.make_tensor(name="axis", + data_type=TensorProto.INT32, + dims=axis.shape, + vals=axis.astype(int)) + arg_axis = helper.make_node("Constant", + inputs=[], + outputs=['arg_axis'], + value=axis_tensor) + + end = np.array([-5, -1]) + end_tensor = helper.make_tensor(name="end", + data_type=TensorProto.INT32, + dims=end.shape, + vals=end.astype(int)) + arg_end = helper.make_node("Constant", + inputs=[], + outputs=['arg_end'], + value=end_tensor) + + start = np.array([-1, -3]) + start_tensor = helper.make_tensor(name="start", + data_type=TensorProto.INT32, + dims=start.shape, + vals=start.astype(int)) + arg_start = helper.make_node("Constant", + inputs=[], + outputs=['arg_start'], + value=start_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2]) + + node = onnx.helper.make_node( + 'Slice', + inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'], + outputs=['1']) + + return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y]) + + +@onnx_test +def slice_5arg_step_test(): + step = np.array([-2, 2]) + step_tensor = helper.make_tensor(name="step", + data_type=TensorProto.INT32, + dims=step.shape, + vals=step.astype(int)) + arg_step = helper.make_node("Constant", + inputs=[], + outputs=['arg_step'], + value=step_tensor) + + axis = np.array([-1, -2]) + axis_tensor = helper.make_tensor(name="axis", + data_type=TensorProto.INT32, + dims=axis.shape, + vals=axis.astype(int)) + arg_axis = helper.make_node("Constant", + inputs=[], + outputs=['arg_axis'], + value=axis_tensor) + + end = np.array([-5, -1]) + end_tensor = helper.make_tensor(name="end", + data_type=TensorProto.INT32, + dims=end.shape, + vals=end.astype(int)) + arg_end = helper.make_node("Constant", + inputs=[], + outputs=['arg_end'], + value=end_tensor) + + start = np.array([-1, -3]) + start_tensor = helper.make_tensor(name="start", + data_type=TensorProto.INT32, + dims=start.shape, + vals=start.astype(int)) + arg_start = helper.make_node("Constant", + inputs=[], + outputs=['arg_start'], + value=start_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2]) + + node = onnx.helper.make_node( + 'Slice', + inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'], + outputs=['1']) + + return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y]) + + +@onnx_test +def slice_max_end_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [10, 20]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [9, 17]) + + node = onnx.helper.make_node('Slice', + inputs=['0'], + axes=[0, 1], + starts=[1, 2], + ends=[3000000000, -1], + outputs=['1']) + + return ([node], [x], [y]) @onnx_test -def shape_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) - y = helper.make_tensor_value_info('y', TensorProto.INT64, [4]) +def softmax_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3]) - node = onnx.helper.make_node( - 'Shape', - inputs=['x'], - outputs=['y'], - ) + node = onnx.helper.make_node('Softmax', inputs=['0'], outputs=['1']) return ([node], [x], [y]) @onnx_test -def shape_gather_test(): - values = np.array([1]) - value = helper.make_tensor_value_info('value', TensorProto.INT32, [1]) - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [7, 3, 10]) - y = helper.make_tensor_value_info('y', TensorProto.INT64, [3]) - z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [1]) +def softmax_nonstd_input_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [6, 8]) + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 4]) - value_tensor = helper.make_tensor(name='const_tensor', - data_type=TensorProto.INT32, - dims=values.shape, - vals=values.flatten().astype(int)) + node0 = onnx.helper.make_node('Slice', + inputs=['0'], + axes=[0, 1], + starts=[1, 0], + ends=[4, 4], + outputs=['1']) - node_const = onnx.helper.make_node( - 'Constant', - inputs=[], - outputs=['value'], - value=value_tensor, - ) + node1 = onnx.helper.make_node('Softmax', inputs=['1'], outputs=['2']) - node_shape = onnx.helper.make_node( - 'Shape', - inputs=['x'], - outputs=['y'], - ) + return ([node0, node1], [x], [y]) - node_gather = helper.make_node( - 'Gather', - inputs=['y', 'value'], - outputs=['z'], - axis=0, - ) - return ([node_const, node_shape, node_gather], [x], [z]) +@onnx_test +def softsign_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + node = onnx.helper.make_node('Softsign', inputs=['x'], outputs=['y']) -@onnx_test -def sign_test(): - x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5]) - y = helper.make_tensor_value_info('y', TensorProto.DOUBLE, [10, 5]) + return ([node], [x], [y]) - node = onnx.helper.make_node( - 'Sign', - inputs=['x'], - outputs=['y'], - ) + +def softplus_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5]) + + node = onnx.helper.make_node('Softplus', inputs=['x'], outputs=['y']) return ([node], [x], [y]) @onnx_test -def sin_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) +def softsign_nd_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [3, 4, 5]) - node = onnx.helper.make_node( - 'Sin', - inputs=['x'], - outputs=['y'], - ) + node = onnx.helper.make_node('Softsign', inputs=['x'], outputs=['y']) + + return ([node], [x], [y]) + + +def softplus_nd_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [3, 4, 5]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [3, 4, 5]) + + node = onnx.helper.make_node('Softplus', inputs=['x'], outputs=['y']) return ([node], [x], [y]) @onnx_test -def sinh_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [10]) +def split_minus_axis_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 5]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 5]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 5]) node = onnx.helper.make_node( - 'Sinh', + 'Split', inputs=['x'], - outputs=['y'], + outputs=['y1', 'y2', 'y3'], + axis=-1, ) - return ([node], [x], [y]) + return ([node], [x], [y1, y2, y3]) @onnx_test -def slice_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 2]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2]) +def split_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4]) + y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4]) - node = onnx.helper.make_node('Slice', - inputs=['0'], - axes=[0, 1], - starts=[1, 0], - ends=[2, 2], - outputs=['1']) + node = onnx.helper.make_node('Split', + inputs=['x'], + outputs=['y1', 'y2', 'y3'], + axis=1, + split=[7, 4, 4]) - return ([node], [x], [y]) + return ([node], [x], [y1, y2, y3]) @onnx_test -def softmax_test(): - x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3]) +def split_test_default(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) + y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [5, 15]) + y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [5, 15]) - node = onnx.helper.make_node('Softmax', inputs=['0'], outputs=['1']) + node = onnx.helper.make_node( + 'Split', + inputs=['x'], + outputs=['y1', 'y2'], + ) - return ([node], [x], [y]) + return ([node], [x], [y1, y2]) @onnx_test @@ -1355,12 +5213,45 @@ def sqrt_test(): return ([node], [x], [y]) +@onnx_test +def squeeze_axes_input_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 5, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 5]) + axes = np.array([1, 3], dtype=np.int64) + axes_tensor = helper.make_tensor(name="axes", + data_type=TensorProto.INT64, + dims=axes.shape, + vals=axes.astype(np.int64)) + + node = onnx.helper.make_node('Squeeze', + inputs=['x', 'axes'], + outputs=['y']) + + return ([node], [x], [y], [axes_tensor]) + + +@onnx_test +def squeeze_empty_axes_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 5, 1]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 5]) + axes = np.array([], dtype=np.int64) + axes_tensor = helper.make_tensor(name="axes", + data_type=TensorProto.INT64, + dims=axes.shape, + vals=axes.astype(np.int64)) + + node = onnx.helper.make_node('Squeeze', + inputs=['x', 'axes'], + outputs=['y']) + + return ([node], [x], [y], [axes_tensor]) + + @onnx_test def squeeze_unsqueeze_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 1, 1, 2, 1]) - y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2]) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, + y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 3, 1, 2, 1]) node = onnx.helper.make_node('Squeeze', @@ -1373,7 +5264,7 @@ def squeeze_unsqueeze_test(): axes=[0, 1, 3, 5], outputs=['2']) - return ([node, node2], [x], [z]) + return ([node, node2], [x], [y]) @onnx_test @@ -1404,7 +5295,7 @@ def sub_scalar_test(): values_tensor = helper.make_tensor(name='const', data_type=TensorProto.FLOAT, - dims=values.shape, + dims=values.reshape(()).shape, vals=values.flatten().astype(float)) arg_const = onnx.helper.make_node( @@ -1424,20 +5315,23 @@ def sub_scalar_test(): @onnx_test -def sum_test(): - a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) - b = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3]) - c = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3]) +def sum_int_test(): + a = helper.make_tensor_value_info('0', TensorProto.INT16, [3]) + b = helper.make_tensor_value_info('1', TensorProto.UINT16, [3]) + c = helper.make_tensor_value_info('2', TensorProto.UINT32, [3]) + y = helper.make_tensor_value_info('3', TensorProto.UINT32, [3]) - y = helper.make_tensor_value_info('3', TensorProto.FLOAT, [3]) + cnode1 = onnx.helper.make_node('Cast', inputs=['0'], outputs=['c0'], to=12) + + cnode2 = onnx.helper.make_node('Cast', inputs=['1'], outputs=['c1'], to=12) node = onnx.helper.make_node( 'Sum', - inputs=['0', '1', '2'], + inputs=['c0', 'c1', '2'], outputs=['3'], ) - return ([node], [a, b, c], [y]) + return ([cnode1, cnode2, node], [a, b, c], [y]) @onnx_test @@ -1456,6 +5350,100 @@ def sum_test(): return ([node], [a, b, c], [y]) +@onnx_test +def sum_type_test(): + valb = np.array([1, 0]) + t_bool = helper.make_tensor(name="bool", + data_type=TensorProto.BOOL, + dims=valb.shape, + vals=valb.astype(np.bool)) + + val = np.array([1, 1]) + t_int8 = helper.make_tensor(name="int8", + data_type=TensorProto.INT8, + dims=val.shape, + vals=val.astype(np.int8)) + + t_uint8 = helper.make_tensor(name="uint8", + data_type=TensorProto.UINT8, + dims=val.shape, + vals=val.astype(np.uint8)) + + t_uint16 = helper.make_tensor(name="uint16", + data_type=TensorProto.UINT16, + dims=val.shape, + vals=val.astype(np.uint16)) + + t_uint32 = helper.make_tensor(name="uint32", + data_type=TensorProto.UINT32, + dims=val.shape, + vals=val.astype(np.uint32)) + + t_uint64 = helper.make_tensor(name="uint64", + data_type=TensorProto.UINT64, + dims=val.shape, + vals=val.astype(np.uint64)) + + t_double = helper.make_tensor(name="double", + data_type=TensorProto.DOUBLE, + dims=val.shape, + vals=val.astype(np.float64)) + + valr = np.array([1.5, 2.0]) + t_raw = helper.make_tensor(name="raw", + data_type=TensorProto.DOUBLE, + dims=valr.shape, + vals=valr.tobytes(), + raw=True) + + n_bool = onnx.helper.make_node('Cast', + inputs=['bool'], + outputs=['o_bool'], + to=11) + + n_int8 = onnx.helper.make_node('Cast', + inputs=['int8'], + outputs=['o_int8'], + to=11) + + n_uint8 = onnx.helper.make_node('Cast', + inputs=['uint8'], + outputs=['o_uint8'], + to=11) + + n_uint16 = onnx.helper.make_node('Cast', + inputs=['uint16'], + outputs=['o_uint16'], + to=11) + + n_uint32 = onnx.helper.make_node('Cast', + inputs=['uint32'], + outputs=['o_uint32'], + to=11) + + n_uint64 = onnx.helper.make_node('Cast', + inputs=['uint64'], + outputs=['o_uint64'], + to=11) + + node = onnx.helper.make_node( + 'Sum', + inputs=[ + 'o_bool', 'o_int8', 'o_uint8', 'o_uint16', 'o_uint32', 'o_uint64', + 'double', 'raw' + ], + outputs=['out'], + ) + + y = helper.make_tensor_value_info('out', TensorProto.DOUBLE, [2]) + + return ([n_bool, n_int8, n_uint8, n_uint16, n_uint32, n_uint64, + node], [], [y], [ + t_bool, t_int8, t_uint8, t_uint16, t_uint32, t_uint64, + t_double, t_raw + ]) + + @onnx_test def tan_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) @@ -1484,6 +5472,154 @@ def tanh_test(): return ([node], [x], [y]) +@onnx_test +def thresholdedrelu_default_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3]) + + node = onnx.helper.make_node('ThresholdedRelu', + inputs=['x'], + outputs=['y']) + + return ([node], [x], [y]) + + +@onnx_test +def thresholdedrelu_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3]) + alpha = 3.0 + + node = onnx.helper.make_node('ThresholdedRelu', + inputs=['x'], + outputs=['y'], + alpha=alpha) + + return ([node], [x], [y]) + + +@onnx_test +def thresholdedrelu_int_test(): + x = helper.make_tensor_value_info('x', TensorProto.INT32, [2, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [2, 2, 3]) + alpha = 3.0 + + node = onnx.helper.make_node('ThresholdedRelu', + inputs=['x'], + outputs=['y'], + alpha=alpha) + + return ([node], [x], [y]) + + +@onnx_test +def tile_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [2]) + z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 4]) + + node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z']) + + return ([node], [x, y], [z], + [helper.make_tensor('y', TensorProto.INT64, [2], [1, 2])]) + + +@onnx_test +def tile_test_3x2(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.INT64, [2]) + z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [6, 4]) + + node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z']) + + return ([node], [x, y], [z], + [helper.make_tensor('y', TensorProto.INT64, [2], [3, 2])]) + + +@onnx_test +def topk_attrk_test(): + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2]) + val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 2, 3, 2]) + ind = helper.make_tensor_value_info('indices', TensorProto.INT64, + [2, 2, 3, 2]) + + node = onnx.helper.make_node('TopK', + inputs=['data'], + outputs=['val', 'indices'], + k=2) + return ([node], [x], [val, ind]) + + +@onnx_test +def topk_neg_axis_test(): + k = np.array([3]) + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) + val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [3, 3, 5, 6]) + ind = helper.make_tensor_value_info('indices', TensorProto.INT64, + [3, 3, 5, 6]) + + k_tensor = helper.make_tensor(name='k', + data_type=TensorProto.INT64, + dims=k.shape, + vals=k.astype(np.int64)) + + node = onnx.helper.make_node('TopK', + inputs=['data', 'k'], + outputs=['val', 'indices'], + axis=-2, + sorted=0) + return ([node], [x], [val, ind], [k_tensor]) + + +@onnx_test +def topk_test(): + k = np.array([4]) + x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2]) + val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 4, 3, 2]) + ind = helper.make_tensor_value_info('indices', TensorProto.INT64, + [2, 4, 3, 2]) + + k_tensor = helper.make_tensor(name='k', + data_type=TensorProto.INT64, + dims=k.shape, + vals=k.astype(np.int64)) + + node = onnx.helper.make_node('TopK', + inputs=['data', 'k'], + outputs=['val', 'indices'], + largest=0, + axis=1) + return ([node], [x], [val, ind], [k_tensor]) + + +def transpose_default_perm_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 5, 2, 3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 5, 1]) + + node = onnx.helper.make_node( + 'Transpose', + inputs=['0'], + outputs=['1'], + ) + + return ([node], [x], [y]) + + +@onnx_test +def transpose_invalid_perm_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 4, 3]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 2, 2]) + + node = onnx.helper.make_node( + 'Transpose', + perm=[0, 2, 1], + inputs=['0'], + outputs=['1'], + ) + + return ([node], [x], [y]) + + @onnx_test def transpose_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 2, 3]) @@ -1529,11 +5665,23 @@ def transpose_gather_test(): return ([td, ti, node], [x, i], [y]) +@onnx_test +def undefined_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 3, 4, 5]) + + node = onnx.helper.make_node('Identity', inputs=[''], outputs=['1']) + + return ([node], [x], [y]) + + @onnx_test def unknown_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4]) - z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) + + helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) + a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5]) node = onnx.helper.make_node('Unknown', inputs=['0', '1'], outputs=['2']) @@ -1541,3 +5689,97 @@ def unknown_test(): node2 = onnx.helper.make_node('Unknown', inputs=['2'], outputs=['3']) return ([node, node2], [x, y], [a]) + + +@onnx_test +def unknown_aten_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4]) + + helper.make_tensor_value_info('2', TensorProto.FLOAT, [2, 3, 4, 5]) + + a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [2, 3, 4, 5]) + + node = onnx.helper.make_node('ATen', + inputs=['0', '1'], + outputs=['2'], + operator='unknown') + + return ([node], [x, y], [a]) + + +@onnx_test +def upsample_linear_test(): + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + scales_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype( + np.float32)) + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, []) + + node = onnx.helper.make_node('Upsample', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='linear') + + return ([node], [X], [Y], [scales_tensor]) + + +@onnx_test +def upsample_test(): + scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32) + scale_tensor = helper.make_tensor(name='scales', + data_type=TensorProto.FLOAT, + dims=scales.shape, + vals=scales.flatten().astype(np.float32)) + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6]) + + node = onnx.helper.make_node( + 'Upsample', + inputs=['X', 'scales'], + outputs=['Y'], + mode='nearest', + ) + + return ([node], [X], [Y], [scale_tensor]) + + +@onnx_test +def variable_batch_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, + [None, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, + [None, 3, 16, 16]) + + node = onnx.helper.make_node('Identity', inputs=['0'], outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test +def variable_batch_leq_zero_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [0, 3, 16, 16]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [-1, 3, 16, 16]) + + z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [-1, 3, 16, 16]) + node = onnx.helper.make_node('Add', inputs=['0', '1'], outputs=['2']) + + return ([node], [x, y], [z]) + + +@onnx_test +def where_test(): + c = helper.make_tensor_value_info('c', TensorProto.BOOL, [2]) + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1, 2, 2]) + + z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 2, 2, 2]) + node = onnx.helper.make_node('Where', + inputs=['c', 'x', 'y'], + outputs=['z']) + + return ([node], [c, x, y], [z]) diff --git a/test/onnx/globallppool_test.onnx b/test/onnx/globallppool_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5dfb8761ede85c8960aa3c25b9199894dd2425f2 --- /dev/null +++ b/test/onnx/globallppool_test.onnx @@ -0,0 +1,15 @@ +globallppool_test:c + +01" GlobalLpPoolgloballppool_testZ +0 + + + + +b +1 + + + + +B \ No newline at end of file diff --git a/test/onnx/greater_bool_test.onnx b/test/onnx/greater_bool_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..24f0caa4ee136ca60ede37bc6c2bd19eca6cf4f8 --- /dev/null +++ b/test/onnx/greater_bool_test.onnx @@ -0,0 +1,19 @@ +greater_bool_test:‡ + +x1bx1"Cast* +to   + +bx1 +x2y"Greatergreater_bool_testZ +x1 +  + +Z +x2 +   + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/greater_test.onnx b/test/onnx/greater_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..41d8990c6445f935d6ea96ce3e7b56706f406866 Binary files /dev/null and b/test/onnx/greater_test.onnx differ diff --git a/test/onnx/greaterorequal_test.onnx b/test/onnx/greaterorequal_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8f83e069bb7343676ed3b820fcb178520e347f60 --- /dev/null +++ b/test/onnx/greaterorequal_test.onnx @@ -0,0 +1,16 @@ +greaterorequal_test:g + +x1 +x2y"GreaterOrEqualgreaterorequal_testZ +x1 + + +Z +x2 + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/hardsigmoid_default_test.onnx b/test/onnx/hardsigmoid_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..893416d9ab90fe4ae73c4d14220ac479408487af --- /dev/null +++ b/test/onnx/hardsigmoid_default_test.onnx @@ -0,0 +1,15 @@ +hardsigmoid_default_test:i + +xy" HardSigmoidhardsigmoid_default_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/hardsigmoid_double_test.onnx b/test/onnx/hardsigmoid_double_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a972352ddce376aec7953a0e2aa51d558e4943a0 --- /dev/null +++ b/test/onnx/hardsigmoid_double_test.onnx @@ -0,0 +1,17 @@ +hardsigmoid_double_test:‰ +4 +xy" HardSigmoid* +alphaš™™> * +beta333? hardsigmoid_double_testZ +x +  + + + +b +y +  + + + +B \ No newline at end of file diff --git a/test/onnx/hardsigmoid_half_test.onnx b/test/onnx/hardsigmoid_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..71939c2a97ef71d76eb490faef5ef31c21f77937 --- /dev/null +++ b/test/onnx/hardsigmoid_half_test.onnx @@ -0,0 +1,17 @@ +hardsigmoid_half_test:f + +xy" HardSigmoidhardsigmoid_half_testZ +x + + + + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/hardsigmoid_verify_test.onnx b/test/onnx/hardsigmoid_verify_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f2cc586055d9725131b5fd7bf05c669cd7d75fdd --- /dev/null +++ b/test/onnx/hardsigmoid_verify_test.onnx @@ -0,0 +1,11 @@ +hardsigmoid_verify_test:X + +xy" HardSigmoidhardsigmoid_verify_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/hardswish_test.onnx b/test/onnx/hardswish_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..532bc2de6d5fcfc1fbf9af37425a8dac6bd250c1 --- /dev/null +++ b/test/onnx/hardswish_test.onnx @@ -0,0 +1,11 @@ +hardswish_test:M + +xy" HardSwishhardswish_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/if_else_test.onnx b/test/onnx/if_else_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..555458ac07e05fa820e2d09d462ebf3ec568ebdf Binary files /dev/null and b/test/onnx/if_else_test.onnx differ diff --git a/test/onnx/if_literal_test.onnx b/test/onnx/if_literal_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c61cf489287bb7b2fabe60eb37f06d1d8374bbfd Binary files /dev/null and b/test/onnx/if_literal_test.onnx differ diff --git a/test/onnx/if_param_excp1_test.onnx b/test/onnx/if_param_excp1_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a5e8ed77196a5715b6eb60148f427b336e71c208 Binary files /dev/null and b/test/onnx/if_param_excp1_test.onnx differ diff --git a/test/onnx/if_param_excp_test.onnx b/test/onnx/if_param_excp_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d2f09df30548de1383ee04877fd1edd016c163b0 Binary files /dev/null and b/test/onnx/if_param_excp_test.onnx differ diff --git a/test/onnx/if_param_test.onnx b/test/onnx/if_param_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..492a4e354d912f054db9eec34f785dcce0b40131 Binary files /dev/null and b/test/onnx/if_param_test.onnx differ diff --git a/test/onnx/if_pl_test.onnx b/test/onnx/if_pl_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9bb7b626bc133ce4a0f69d3df3ec76206247b25c Binary files /dev/null and b/test/onnx/if_pl_test.onnx differ diff --git a/test/onnx/if_then_test.onnx b/test/onnx/if_then_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..26c08341e451c2ed2016e33a5a36dfd771a1a2f6 Binary files /dev/null and b/test/onnx/if_then_test.onnx differ diff --git a/test/onnx/if_tuple_test.onnx b/test/onnx/if_tuple_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f8d6973c7e8daddd7436c82c5ea732085636feb7 Binary files /dev/null and b/test/onnx/if_tuple_test.onnx differ diff --git a/test/onnx/imagescaler_half_test.onnx b/test/onnx/imagescaler_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8d47329e5a8c3df10791b5bdf45d78898b1be7ce Binary files /dev/null and b/test/onnx/imagescaler_half_test.onnx differ diff --git a/test/onnx/implicit_sub_bcast_test.onnx b/test/onnx/implicit_sub_bcast_test.onnx index 8e80534963dedff92d6ccfdf953e96f60a187049..519e23b8c97215bdf7496fbaf174f905ea3bd269 100644 --- a/test/onnx/implicit_sub_bcast_test.onnx +++ b/test/onnx/implicit_sub_bcast_test.onnx @@ -1,20 +1,20 @@ -add2:q +implicit_sub_bcast_test:|  0 -1out"Sub subtraction2Z +1out"Subimplicit_sub_bcast_testZ 0 - +     Z 1 -  +    b out - +     -B +B \ No newline at end of file diff --git a/test/onnx/instance_norm_test.onnx b/test/onnx/instance_norm_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6dd7c57c5a4bbecbafb9e2f31a5ef6d8d46eef4b --- /dev/null +++ b/test/onnx/instance_norm_test.onnx @@ -0,0 +1,25 @@ +instance_norm_test:• +# +0 +1 +23"InstanceNormalizationinstance_norm_testZ +0 + + + + +Z +1 + + +Z +2 + + +b +3 + + + + +B diff --git a/test/onnx/instance_norm_val_3d_test.onnx b/test/onnx/instance_norm_val_3d_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c6dcb1045c4bd76cbb521171ba1489c02f0233fc Binary files /dev/null and b/test/onnx/instance_norm_val_3d_test.onnx differ diff --git a/test/onnx/instance_norm_val_test.onnx b/test/onnx/instance_norm_val_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..13566f1f366110aff97eb51359e447c6f5ec8766 Binary files /dev/null and b/test/onnx/instance_norm_val_test.onnx differ diff --git a/test/onnx/isnan_float_test.onnx b/test/onnx/isnan_float_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e451ad185f41b3e9a87dafd3d97b83acbea8e6cd --- /dev/null +++ b/test/onnx/isnan_float_test.onnx @@ -0,0 +1,11 @@ +isnan_float_test:O + +t1t2"IsNaNisnan_float_testZ +t1 +  + +b +t2 +  + +B \ No newline at end of file diff --git a/test/onnx/isnan_half_test.onnx b/test/onnx/isnan_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2079294486f750cdd9a553f4d3b1cf40f01ee85b --- /dev/null +++ b/test/onnx/isnan_half_test.onnx @@ -0,0 +1,13 @@ +isnan_half_test:N + +t1t2"IsNaNisnan_half_testZ +t1 +  + + +b +t2 +  + + +B \ No newline at end of file diff --git a/test/onnx/less_bool_test.onnx b/test/onnx/less_bool_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..44cf895e89a882663c8f71116f4965c2fccab3be --- /dev/null +++ b/test/onnx/less_bool_test.onnx @@ -0,0 +1,19 @@ +less_bool_test: + +x1bx1"Cast* +to   + +bx1 +x2y"Lessless_bool_testZ +x1 +  + +Z +x2 +   + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/less_test.onnx b/test/onnx/less_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3cee969b90b424e424438e1a495fe48a2c86ecca Binary files /dev/null and b/test/onnx/less_test.onnx differ diff --git a/test/onnx/lessorequal_test.onnx b/test/onnx/lessorequal_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4c876d569131c30510588c0df6da74ad49c061d3 --- /dev/null +++ b/test/onnx/lessorequal_test.onnx @@ -0,0 +1,16 @@ +lessorequal_test:a + +x1 +x2y" LessOrEquallessorequal_testZ +x1 + + +Z +x2 + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/logical_and_bcast_test.onnx b/test/onnx/logical_and_bcast_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f2b4cb421d4f10ebf6bef1a652cb14f734e3e82c --- /dev/null +++ b/test/onnx/logical_and_bcast_test.onnx @@ -0,0 +1,20 @@ +logical_and_bcast_test:w + +0 +12"Andlogical_and_bcast_testZ +0 +  + + + +Z +1 +   + +b +2 +  + + + +B \ No newline at end of file diff --git a/test/onnx/logical_or_test.onnx b/test/onnx/logical_or_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4a54339ecb609f1e87118ec49fefe70ad145895b --- /dev/null +++ b/test/onnx/logical_or_test.onnx @@ -0,0 +1,22 @@ +logical_or_test:w + +0 +12"Orlogical_or_testZ +0 +  + + + +Z +1 +  + + + +b +2 +  + + + +B \ No newline at end of file diff --git a/test/onnx/logical_xor_bcast_test.onnx b/test/onnx/logical_xor_bcast_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc8c9c8073cb77fa3aa492a5beeeb26d4ac31c13 --- /dev/null +++ b/test/onnx/logical_xor_bcast_test.onnx @@ -0,0 +1,20 @@ +logical_xor_bcast_test:w + +0 +12"Xorlogical_xor_bcast_testZ +0 +  + + + +Z +1 +   + +b +2 +  + + + +B \ No newline at end of file diff --git a/test/onnx/logsoftmax_nonstd_input_test.onnx b/test/onnx/logsoftmax_nonstd_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ef20f6a6fc442dcfd00f0a19f344f87487483b68 Binary files /dev/null and b/test/onnx/logsoftmax_nonstd_input_test.onnx differ diff --git a/test/onnx/loop_default_test.onnx b/test/onnx/loop_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..17dccad67d7d058380e6f7fa811c493c5c4a2738 Binary files /dev/null and b/test/onnx/loop_default_test.onnx differ diff --git a/test/onnx/loop_test.onnx b/test/onnx/loop_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7bf0bf0cedd7accda7dd2f70bc74f59982988674 --- /dev/null +++ b/test/onnx/loop_test.onnx @@ -0,0 +1,77 @@ + loop_test:ì +¿ +max_trip_count +keep_going_cond +bb_loop my_local_loopuser_defined_vals_loop"Loop*ã +body2× + +a +b_inmy_local"Add + +a +b_in +a_sub_b_in"Sub ++ +my_local + +a_sub_b_in +keep_going"Greater +0 + +a_sub_b_in + +a_sub_b_inuser_defined_vals"AddbodyZ + iteration_num + + +Z +keep_going_inp + +  +Z +b_in + + +b + +keep_going + +  +b + +a_sub_b_in + + +b +my_local + + +b +user_defined_vals + + +  loop_testZ +max_trip_count + + +Z +keep_going_cond + +  +Z +a + + +Z +b + + +b +b_loop + + +b( +user_defined_vals_loop +  + +B \ No newline at end of file diff --git a/test/onnx/lpnormalization_axis_error_test.onnx b/test/onnx/lpnormalization_axis_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f73d279ab1e511a33ddc2f4b2067576dd2e483de --- /dev/null +++ b/test/onnx/lpnormalization_axis_error_test.onnx @@ -0,0 +1,12 @@ +lpnormalization_axis_error_test:q +$ +xy"LpNormalization* +axis lpnormalization_axis_error_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/lpnormalization_default_test.onnx b/test/onnx/lpnormalization_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9f3accd444366f5ce47472792db3b5a3572a2150 Binary files /dev/null and b/test/onnx/lpnormalization_default_test.onnx differ diff --git a/test/onnx/lpnormalization_l1_test.onnx b/test/onnx/lpnormalization_l1_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2af3ec34864e61aa36d4e8b483a8d36eff982a62 --- /dev/null +++ b/test/onnx/lpnormalization_l1_test.onnx @@ -0,0 +1,12 @@ +lpnormalization_l1_test:f +! +xy"LpNormalization* +p lpnormalization_l1_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/lpnormalization_l2_test.onnx b/test/onnx/lpnormalization_l2_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c607b672f162761ba7310d14ea189ea8a1ff7f58 --- /dev/null +++ b/test/onnx/lpnormalization_l2_test.onnx @@ -0,0 +1,12 @@ +lpnormalization_l2_test:f +! +xy"LpNormalization* +p lpnormalization_l2_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/lpnormalization_p_error_test.onnx b/test/onnx/lpnormalization_p_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..517621a73218daa291391af4b388660512716d46 --- /dev/null +++ b/test/onnx/lpnormalization_p_error_test.onnx @@ -0,0 +1,12 @@ +lpnormalization_p_error_test:k +! +xy"LpNormalization* +p lpnormalization_p_error_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/lppool_l1_test.onnx b/test/onnx/lppool_l1_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..676504e75d51d639b65121c9959bf177b059b369 --- /dev/null +++ b/test/onnx/lppool_l1_test.onnx @@ -0,0 +1,15 @@ +lppool_l1_test:q +- +xy"LpPool* + kernel_shape@ * +p lppool_l1_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/lppool_l2_test.onnx b/test/onnx/lppool_l2_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa4435ffba37db0c7e78b8afbe15b1ee3a66bc50 --- /dev/null +++ b/test/onnx/lppool_l2_test.onnx @@ -0,0 +1,15 @@ +lppool_l2_test:q +- +xy"LpPool* + kernel_shape@ * +p lppool_l2_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/matmulinteger_test.onnx b/test/onnx/matmulinteger_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..369c9b077a482b495a1d1a0915db60bc16d98bcd --- /dev/null +++ b/test/onnx/matmulinteger_test.onnx @@ -0,0 +1,19 @@ +matmulinteger_test:y + +1 +2y" MatMulIntegermatmulinteger_testZ +1 + + + +Z +2 + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/maxpool_notset_test.onnx b/test/onnx/maxpool_notset_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..810c231f5fda656dbaad1bacce008d05505eb24c Binary files /dev/null and b/test/onnx/maxpool_notset_test.onnx differ diff --git a/test/onnx/maxpool_same_upper_test.onnx b/test/onnx/maxpool_same_upper_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e1fa39dde6495794f5d281662f4eec6d4d152d31 --- /dev/null +++ b/test/onnx/maxpool_same_upper_test.onnx @@ -0,0 +1,18 @@ +maxpool_same_upper_test:– +A +xy"MaxPool* +auto_pad" +SAME_UPPER * + kernel_shape@@ maxpool_same_upper_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_broadcast_test.onnx b/test/onnx/mean_broadcast_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..556dd1dc902aa517f377caff07c8430f4c495d6a --- /dev/null +++ b/test/onnx/mean_broadcast_test.onnx @@ -0,0 +1,37 @@ +mean_broadcast_test:à + +0 +1 +2 +3 +4mean"Meanmean_broadcast_testZ +0 + + + +Z +1 + + + + +Z +2 + + +Z +3 + + +Z +4 + + + +b +mean + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_fp16_test.onnx b/test/onnx/mean_fp16_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8ad779e6da59f499c8f945968f90ea38530fc783 --- /dev/null +++ b/test/onnx/mean_fp16_test.onnx @@ -0,0 +1,29 @@ +mean_fp16_test:Ž + +0 +1 +2mean"Meanmean_fp16_testZ +0 + + + + +Z +1 + + + + +Z +2 + + + + +b +mean + + + + +B \ No newline at end of file diff --git a/test/onnx/mean_integral_test.onnx b/test/onnx/mean_integral_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..39b19b326929dc74f723a6640f02578bc0b930af --- /dev/null +++ b/test/onnx/mean_integral_test.onnx @@ -0,0 +1,67 @@ +mean_integral_test:Ö +* +0 +1 +2 +3 +4 +5 +6 +7 +8 +9mean"Meanmean_integral_testZ +0 + + + +Z +1 + + + +Z +2 + + + +Z +3 + + + +Z +4 + + + +Z +5 + + + +Z +6 + + + +Z +7 + + + +Z +8 + + + +Z +9 + + + +b +mean + + + +B \ No newline at end of file diff --git a/test/onnx/mean_invalid_broadcast_test.onnx b/test/onnx/mean_invalid_broadcast_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..48f5f098bc6185b835c13595bd5ecaa05ba3496e --- /dev/null +++ b/test/onnx/mean_invalid_broadcast_test.onnx @@ -0,0 +1,25 @@ +mean_invalid_broadcast_test:› + +0 +1 +2mean"Meanmean_invalid_broadcast_testZ +0 + + + +Z +1 + + + +Z +2 + + + +b +mean + + + +B \ No newline at end of file diff --git a/test/onnx/mean_single_input_test.onnx b/test/onnx/mean_single_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..51b9dd1da5996ccfcfe4ab51ce98cc9435551e17 --- /dev/null +++ b/test/onnx/mean_single_input_test.onnx @@ -0,0 +1,13 @@ +mean_single_input_test:^ + +0mean"Meanmean_single_input_testZ +0 + + + +b +mean + + + +B \ No newline at end of file diff --git a/test/onnx/mean_test.onnx b/test/onnx/mean_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5626c452ea67fc9bba4e4f808373f94a6ac930a4 --- /dev/null +++ b/test/onnx/mean_test.onnx @@ -0,0 +1,67 @@ + mean_test:Í +* +0 +1 +2 +3 +4 +5 +6 +7 +8 +9mean"Mean mean_testZ +0 +  + + +Z +1 +  + + +Z +2 +  + + +Z +3 +  + + +Z +4 +  + + +Z +5 +  + + +Z +6 +  + + +Z +7 +  + + +Z +8 +  + + +Z +9 +  + + +b +mean +  + + +B \ No newline at end of file diff --git a/test/onnx/multinomial_dtype_error_test.onnx b/test/onnx/multinomial_dtype_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..26a991ae43f182f0194f958815984ed225649945 Binary files /dev/null and b/test/onnx/multinomial_dtype_error_test.onnx differ diff --git a/test/onnx/multinomial_generated_seed_test.onnx b/test/onnx/multinomial_generated_seed_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1f3af38ee2865691937776052225bf5f2bdd0d4b --- /dev/null +++ b/test/onnx/multinomial_generated_seed_test.onnx @@ -0,0 +1,15 @@ +multinomial_generated_seed_test:† +0 +inputoutput" Multinomial* + sample_size + multinomial_generated_seed_testZ +input +  + + +b +output +  + + +B \ No newline at end of file diff --git a/test/onnx/multinomial_int64_test.onnx b/test/onnx/multinomial_int64_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f4c4114a109670ba1b48b2cc38a79492952feed0 Binary files /dev/null and b/test/onnx/multinomial_int64_test.onnx differ diff --git a/test/onnx/multinomial_test.onnx b/test/onnx/multinomial_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e2a40ac7f0bc49cfa05c6ec5729aafd0687dd80c Binary files /dev/null and b/test/onnx/multinomial_test.onnx differ diff --git a/test/onnx/neg_test.onnx b/test/onnx/neg_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..16806cddfd42a8c917b329da47fb911425a3f47e --- /dev/null +++ b/test/onnx/neg_test.onnx @@ -0,0 +1,11 @@ +neg_test:A + +01"Negneg_testZ +0 +  + +b +1 +  + +B \ No newline at end of file diff --git a/test/onnx/nms_test.onnx b/test/onnx/nms_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..484a8b81a1fb1316a5a125fb9380aa794553e368 --- /dev/null +++ b/test/onnx/nms_test.onnx @@ -0,0 +1,34 @@ +nms_test:Û +‰ +boxes +scores +max_output_boxes_per_class + iou_threshold +score_thresholdselected_indices"NonMaxSuppression* +center_point_box nms_testZ +boxes + + + +Z +scores + + + +Z( +max_output_boxes_per_class + + +Z + iou_threshold + + +Z +score_threshold + + +b" +selected_indices +  + +B \ No newline at end of file diff --git a/test/onnx/nonzero_dynamic_test.onnx b/test/onnx/nonzero_dynamic_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cfe3b2ce50066f2f070e2054ab54042d9aae2059 --- /dev/null +++ b/test/onnx/nonzero_dynamic_test.onnx @@ -0,0 +1,11 @@ +nonzero_dynamic_test:c + +dataindices"NonZerononzero_dynamic_testZ +data +   + +b +indices +  + +B \ No newline at end of file diff --git a/test/onnx/nonzero_int_test.onnx b/test/onnx/nonzero_int_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..44b4e77eeff88a8a7ffcb548e9d94041c9c33001 Binary files /dev/null and b/test/onnx/nonzero_int_test.onnx differ diff --git a/test/onnx/nonzero_test.onnx b/test/onnx/nonzero_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..16631d7ecbab998e90cd522c80a9b520774f15f5 Binary files /dev/null and b/test/onnx/nonzero_test.onnx differ diff --git a/test/onnx/not_bool_test.onnx b/test/onnx/not_bool_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0ae7ad92b12bad67cdae8ced1f9d58daf518ad2c --- /dev/null +++ b/test/onnx/not_bool_test.onnx @@ -0,0 +1,11 @@ + not_bool_test:> + +01"Not not_bool_testZ +0 + +  +b +1 + +  +B \ No newline at end of file diff --git a/test/onnx/not_test.onnx b/test/onnx/not_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f993833c53fbe37fd6688be449abec89e6b869ac --- /dev/null +++ b/test/onnx/not_test.onnx @@ -0,0 +1,11 @@ +not_test:9 + +01"Notnot_testZ +0 + + +b +1 + + +B \ No newline at end of file diff --git a/test/onnx/onehot_test.onnx b/test/onnx/onehot_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa4e4f1ee994f6aed472d582f2afb8783aafe57c Binary files /dev/null and b/test/onnx/onehot_test.onnx differ diff --git a/test/onnx/onnx_lstm_cell.onnx b/test/onnx/onnx_lstm_cell.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5bdfda48f7d60b08c42d3f25fb19a097ca83a26e Binary files /dev/null and b/test/onnx/onnx_lstm_cell.onnx differ diff --git a/test/onnx/onnx_lstm_hs.onnx b/test/onnx/onnx_lstm_hs.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e6ca5ccbaf36f15c4f7aac6c3376b61f2f14a4e6 Binary files /dev/null and b/test/onnx/onnx_lstm_hs.onnx differ diff --git a/test/onnx/onnx_lstm_last.onnx b/test/onnx/onnx_lstm_last.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05849ceb15b1b2b6047c85c22f6eb3441d670a3f Binary files /dev/null and b/test/onnx/onnx_lstm_last.onnx differ diff --git a/test/onnx/onnx_rnn_test.cpp b/test/onnx/onnx_rnn_test.cpp old mode 100644 new mode 100755 index c2016a6dd82aa99ab99dce877b5ead5d4c8f9cd1..343a6b164e90ee9af678a078c3938200b1f86b61 --- a/test/onnx/onnx_rnn_test.cpp +++ b/test/onnx/onnx_rnn_test.cpp @@ -4,9 +4,33 @@ #include #include #include +#include +#include +#include #include +#include + +#include + #include "test.hpp" +migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true) +{ + auto prog = migraphx::parse_onnx(name); + auto* mm = prog.get_main_module(); + if(eliminate_deadcode) + migraphx::run_passes(*mm, {migraphx::dead_code_elimination{}}); + + // remove the last identity instruction + auto last_ins = std::prev(mm->end()); + if(last_ins->name() == "@return") + { + mm->remove_instruction(last_ins); + } + + return prog; +} + TEST_CASE(rnn_test_bidirectional) { std::size_t sl = 5; // sequence len @@ -23,27 +47,32 @@ TEST_CASE(rnn_test_bidirectional) migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}}; migraphx::program p; - - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - - auto out_hs = - p.add_instruction(migraphx::op::rnn{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx"); + auto* mm = p.get_main_module(); + + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_rnn_bi.onnx"); EXPECT(p == prog); } @@ -66,26 +95,31 @@ TEST_CASE(rnn_test_one_direction) // forward { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - - auto out_hs = - p.add_instruction(migraphx::op::rnn{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_rnn_forward.onnx"); EXPECT(p == prog); } @@ -93,25 +127,30 @@ TEST_CASE(rnn_test_one_direction) // reverse { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto out_hs = - p.add_instruction(migraphx::op::rnn{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_rnn_reverse.onnx"); EXPECT(p == prog); } @@ -119,23 +158,28 @@ TEST_CASE(rnn_test_one_direction) // 3 argumments { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto out_hs = - p.add_instruction(migraphx::op::rnn{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_rnn_3args.onnx"); EXPECT(p == prog); } @@ -143,27 +187,32 @@ TEST_CASE(rnn_test_one_direction) // 5 argumments { migraphx::program p; - - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::rnn{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - seq_len, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx"); + auto* mm = p.get_main_module(); + + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_rnn_5args.onnx"); EXPECT(p == prog); } @@ -181,33 +230,39 @@ TEST_CASE(gru_test) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip, - 1}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_forward.onnx"); EXPECT(p == prog); } @@ -216,32 +271,38 @@ TEST_CASE(gru_test) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_reverse.onnx"); EXPECT(p == prog); } @@ -250,35 +311,40 @@ TEST_CASE(gru_test) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::relu{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("relu"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_bi.onnx"); EXPECT(p == prog); } @@ -297,27 +363,32 @@ TEST_CASE(gru_test_args) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); - auto und = p.add_instruction(migraphx::op::undefined{}); - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_3arg.onnx"); EXPECT(p == prog); } @@ -326,30 +397,35 @@ TEST_CASE(gru_test_args) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::relu{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("relu"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_4arg.onnx"); EXPECT(p == prog); } @@ -358,35 +434,39 @@ TEST_CASE(gru_test_args) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::relu{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("relu"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_5arg.onnx"); EXPECT(p == prog); } @@ -404,35 +484,40 @@ TEST_CASE(gru_test_actv_funcs) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_bi_0.onnx"); EXPECT(p == prog); } @@ -441,35 +526,41 @@ TEST_CASE(gru_test_actv_funcs) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_bi_1.onnx"); EXPECT(p == prog); } @@ -478,35 +569,41 @@ TEST_CASE(gru_test_actv_funcs) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_bi_2.onnx"); EXPECT(p == prog); } @@ -515,35 +612,40 @@ TEST_CASE(gru_test_actv_funcs) { nd = 2; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_bi_3.onnx"); EXPECT(p == prog); } @@ -552,32 +654,38 @@ TEST_CASE(gru_test_actv_funcs) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_forward_0.onnx"); EXPECT(p == prog); } @@ -586,32 +694,38 @@ TEST_CASE(gru_test_actv_funcs) { nd = 1; migraphx::program p; + auto* mm = p.get_main_module(); auto seq = - p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); + mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}); auto w = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}); auto r = - p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); + mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}); auto bias = - p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); + mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}); auto seq_len = - p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); - auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); - - auto out_hs = - p.add_instruction(migraphx::op::gru{hs, - {migraphx::op::relu{}, migraphx::op::relu{}}, - migraphx::op::rnn_direction::reverse, - clip}, - seq, - w, - r, - bias, - seq_len, - ih); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); + mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); + auto ih = + mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("relu"), + migraphx::make_op("relu")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + seq_len, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_gru_reverse_1.onnx"); EXPECT(p == prog); } @@ -635,22 +749,27 @@ TEST_CASE(lstm_forward) migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto pph = p.add_parameter("pph", pph_shape); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -659,9 +778,8 @@ TEST_CASE(lstm_forward) ih, ic, pph); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_forward.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_forward.onnx"); EXPECT(p == prog); } @@ -669,18 +787,23 @@ TEST_CASE(lstm_forward) // 3 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -689,9 +812,109 @@ TEST_CASE(lstm_forward) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f3args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f3args.onnx"); + + EXPECT(p == prog); + } + + // 3 args, hs output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + auto prog = optimize_onnx("onnx_lstm_hs.onnx"); + + EXPECT(p == prog); + } + + // 3 args, last output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_last.onnx"); + + EXPECT(p == prog); + } + + // 3 args, cell output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_cell.onnx"); EXPECT(p == prog); } @@ -699,19 +922,24 @@ TEST_CASE(lstm_forward) // 4 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -720,9 +948,8 @@ TEST_CASE(lstm_forward) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f4args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f4args.onnx"); EXPECT(p == prog); } @@ -730,20 +957,25 @@ TEST_CASE(lstm_forward) // 5 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -752,9 +984,9 @@ TEST_CASE(lstm_forward) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f5args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f5args.onnx"); EXPECT(p == prog); } @@ -762,21 +994,26 @@ TEST_CASE(lstm_forward) // 6 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -785,9 +1022,9 @@ TEST_CASE(lstm_forward) ih, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f6args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f6args.onnx"); EXPECT(p == prog); } @@ -795,22 +1032,27 @@ TEST_CASE(lstm_forward) // 7 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -819,9 +1061,9 @@ TEST_CASE(lstm_forward) ih, ic, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f7args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f7args.onnx"); EXPECT(p == prog); } @@ -845,18 +1087,24 @@ TEST_CASE(lstm_forward_actv_func) // no activation function specified { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + // auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -865,9 +1113,8 @@ TEST_CASE(lstm_forward_actv_func) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f0af.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f0af.onnx"); EXPECT(p == prog); } @@ -875,19 +1122,25 @@ TEST_CASE(lstm_forward_actv_func) // 1 activation function specified { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -896,9 +1149,8 @@ TEST_CASE(lstm_forward_actv_func) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f1af.onnx"); EXPECT(p == prog); } @@ -906,20 +1158,26 @@ TEST_CASE(lstm_forward_actv_func) // 2 activation function specified { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::forward, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -928,9 +1186,9 @@ TEST_CASE(lstm_forward_actv_func) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_f2af.onnx"); EXPECT(p == prog); } @@ -954,22 +1212,27 @@ TEST_CASE(lstm_reverse) migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto pph = p.add_parameter("pph", pph_shape); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -978,9 +1241,8 @@ TEST_CASE(lstm_reverse) ih, ic, pph); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_reverse.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_reverse.onnx"); EXPECT(p == prog); } @@ -988,20 +1250,25 @@ TEST_CASE(lstm_reverse) // 5 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -1010,9 +1277,9 @@ TEST_CASE(lstm_reverse) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_r5args.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_r5args.onnx"); EXPECT(p == prog); } @@ -1020,18 +1287,23 @@ TEST_CASE(lstm_reverse) // no activation function specified { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = p.add_instruction( - migraphx::op::lstm{ - hs, - {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, - migraphx::op::rnn_direction::reverse, - clip, - input_forget}, + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", input_forget}}), seq, w, r, @@ -1040,9 +1312,8 @@ TEST_CASE(lstm_reverse) und, und, und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_r0af.onnx"); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_r0af.onnx"); EXPECT(p == prog); } @@ -1066,37 +1337,40 @@ TEST_CASE(lstm_bidirectional) migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto pph = p.add_parameter("pph", pph_shape); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - ih, - ic, - pph); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi.onnx"); EXPECT(p == prog); } @@ -1104,33 +1378,36 @@ TEST_CASE(lstm_bidirectional) // 3 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - und, - und, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi3args.onnx"); EXPECT(p == prog); } @@ -1138,34 +1415,37 @@ TEST_CASE(lstm_bidirectional) // 4 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - und, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi4args.onnx"); EXPECT(p == prog); } @@ -1173,35 +1453,38 @@ TEST_CASE(lstm_bidirectional) // 5 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi5args.onnx"); EXPECT(p == prog); } @@ -1209,36 +1492,39 @@ TEST_CASE(lstm_bidirectional) // 6 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - ih, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi6args.onnx"); EXPECT(p == prog); } @@ -1246,37 +1532,40 @@ TEST_CASE(lstm_bidirectional) // 7 args { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - ih, - ic, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi7args.onnx"); EXPECT(p == prog); } @@ -1301,33 +1590,36 @@ TEST_CASE(lstm_bi_actv_funcs) // 0 activation function { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - und, - und, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi0af.onnx"); EXPECT(p == prog); } @@ -1335,34 +1627,38 @@ TEST_CASE(lstm_bi_actv_funcs) // 1 activation function { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - und, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi1af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi1af.onnx"); EXPECT(p == prog); } @@ -1370,35 +1666,38 @@ TEST_CASE(lstm_bi_actv_funcs) // 2 activation functions { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi2af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi2af.onnx"); EXPECT(p == prog); } @@ -1406,36 +1705,39 @@ TEST_CASE(lstm_bi_actv_funcs) // 4 activation functions { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - ih, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi4af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi4af.onnx"); EXPECT(p == prog); } @@ -1443,37 +1745,41 @@ TEST_CASE(lstm_bi_actv_funcs) // 5 activation functions { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto bias = p.add_parameter("bias", bias_shape); - auto seq_len = p.add_parameter("seq_len", sl_shape); - auto ih = p.add_parameter("h0", ih_shape); - auto ic = p.add_parameter("c0", ih_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::sigmoid{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - bias, - seq_len, - ih, - ic, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi5af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi5af.onnx"); EXPECT(p == prog); } @@ -1481,33 +1787,36 @@ TEST_CASE(lstm_bi_actv_funcs) // 6 activation functions { migraphx::program p; - auto seq = p.add_parameter("seq", seq_shape); - auto w = p.add_parameter("w", w_shape); - auto r = p.add_parameter("r", r_shape); - auto und = p.add_instruction(migraphx::op::undefined{}); - - auto out_hs = - p.add_instruction(migraphx::op::lstm{hs, - {migraphx::op::sigmoid{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::tanh{}, - migraphx::op::sigmoid{}, - migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip, - input_forget}, - seq, - w, - r, - und, - und, - und, - und, - und); - p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); - p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); - auto prog = migraphx::parse_onnx("onnx_lstm_bi6af.onnx"); + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + und, + und, + und, + und, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto prog = optimize_onnx("onnx_lstm_bi6af.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 8060278d145347161ee1e9cbc73ca4479216ef53..dac235f06d21f4dce732a1448d4b7cadd8c9ea69 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -1,20 +1,105 @@ #include +#include #include +#include +#include +#include #include -#include #include #include #include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + #include "test.hpp" +migraphx::program optimize_onnx(const std::string& name, bool run_passes = false) +{ + migraphx::onnx_options options; + options.skip_unknown_operators = true; + auto prog = migraphx::parse_onnx(name, options); + auto* mm = prog.get_main_module(); + if(run_passes) + migraphx::run_passes(*mm, + {migraphx::rewrite_quantization{}, migraphx::dead_code_elimination{}}); + + // remove the last identity instruction + auto last_ins = std::prev(mm->end()); + if(last_ins->name() == "@return") + { + mm->remove_instruction(last_ins); + } + + return prog; +} + +void add_celu_instruction(migraphx::module* mm, const migraphx::shape& s, float alpha) +{ + auto x = mm->add_parameter("x", s); + const auto& input_lens = s.lens(); + const auto& input_type = s.type(); + auto zero_lit = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}})); + auto one_lit = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}})); + auto alpha_lit = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto linear_part = mm->add_instruction(migraphx::make_op("max"), zero_lit, x); + auto divi = mm->add_instruction(migraphx::make_op("div"), x, alpha_lit); + auto expo = mm->add_instruction(migraphx::make_op("exp"), divi); + auto sub = mm->add_instruction(migraphx::make_op("sub"), expo, one_lit); + auto mul = mm->add_instruction(migraphx::make_op("mul"), alpha_lit, sub); + auto exp_part = mm->add_instruction(migraphx::make_op("min"), zero_lit, mul); + mm->add_instruction(migraphx::make_op("add"), linear_part, exp_part); +} + +static std::vector make_r_eyelike(size_t num_rows, size_t num_cols, size_t k) +{ + std::vector eyelike_mat(num_rows * num_cols, 0); + for(size_t i = 0; i < num_rows; ++i) + { + if(i + k < num_cols) + eyelike_mat[(num_cols + 1) * i + k] = 1.; + } + return eyelike_mat; +} + TEST_CASE(acos_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::acos{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("acos"), input); + + auto prog = optimize_onnx("acos_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(acosh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("acosh"), input); - auto prog = migraphx::parse_onnx("acos_test.onnx"); + auto prog = optimize_onnx("acosh_test.onnx"); EXPECT(p == prog); } @@ -22,12 +107,14 @@ TEST_CASE(acos_test) TEST_CASE(add_bcast_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); - auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); - p.add_instruction(migraphx::op::add{}, l0, l2); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); - auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); + auto prog = optimize_onnx("add_bcast_test.onnx"); EXPECT(p == prog); } @@ -35,12 +122,13 @@ TEST_CASE(add_bcast_test) TEST_CASE(add_fp16_test) { migraphx::program p; + auto* mm = p.get_main_module(); auto l0 = - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}}); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {1.5}}); auto l1 = - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}}); - p.add_instruction(migraphx::op::add{}, l0, l1); - auto prog = migraphx::parse_onnx("add_fp16_test.onnx"); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type, {1}}, {2.5}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto prog = optimize_onnx("add_fp16_test.onnx"); EXPECT(p == prog); } @@ -48,11 +136,13 @@ TEST_CASE(add_fp16_test) TEST_CASE(add_scalar_test) { migraphx::program p; - auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}}); - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); - auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); - p.add_instruction(migraphx::op::add{}, m0, m1); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type}); + auto m1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); + auto r = mm->add_instruction(migraphx::make_op("add"), l0, m1); + mm->add_return({r}); auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); EXPECT(p == prog); @@ -61,10 +151,11 @@ TEST_CASE(add_scalar_test) TEST_CASE(argmax_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto ins = p.add_instruction(migraphx::op::argmax{2}, l0); - p.add_instruction(migraphx::op::squeeze{{2}}, ins); - auto prog = migraphx::parse_onnx("argmax_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); + auto prog = optimize_onnx("argmax_test.onnx"); EXPECT(p == prog); } @@ -72,10 +163,11 @@ TEST_CASE(argmax_test) TEST_CASE(argmin_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto ins = p.add_instruction(migraphx::op::argmin{3}, l0); - p.add_instruction(migraphx::op::squeeze{{3}}, ins); - auto prog = migraphx::parse_onnx("argmin_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 3}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), ins); + auto prog = optimize_onnx("argmin_test.onnx"); EXPECT(p == prog); } @@ -83,801 +175,3653 @@ TEST_CASE(argmin_test) TEST_CASE(asin_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::asin{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("asin"), input); - auto prog = migraphx::parse_onnx("asin_test.onnx"); + auto prog = optimize_onnx("asin_test.onnx"); EXPECT(p == prog); } -TEST_CASE(atan_test) +TEST_CASE(asinh_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::atan{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("asinh"), input); - auto prog = migraphx::parse_onnx("atan_test.onnx"); + auto prog = optimize_onnx("asinh_test.onnx"); EXPECT(p == prog); } -TEST_CASE(cast_test) +TEST_CASE(atan_test) { migraphx::program p; - auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}}); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("atan"), input); + + auto prog = optimize_onnx("atan_test.onnx"); - auto prog = migraphx::parse_onnx("cast_test.onnx"); EXPECT(p == prog); } -TEST_CASE(ceil_test) +TEST_CASE(atanh_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::ceil{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("atanh"), input); - auto prog = migraphx::parse_onnx("ceil_test.onnx"); + auto prog = optimize_onnx("atanh_test.onnx"); EXPECT(p == prog); } -TEST_CASE(clip_test) +TEST_CASE(averagepool_1d_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); - p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); - auto prog = migraphx::parse_onnx("clip_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0}}, + {"stride", {1}}, + {"lengths", {3}}}), + l0); + + auto prog = optimize_onnx("averagepool_1d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(concat_test) +TEST_CASE(averagepool_3d_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}}); - p.add_instruction(migraphx::op::concat{0}, l0, l1); - auto prog = migraphx::parse_onnx("concat_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0, 0, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {3, 3, 3}}}), + l0); + + auto prog = optimize_onnx("averagepool_3d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(constant_test) +TEST_CASE(averagepool_notset_test) { migraphx::program p; - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}}); - auto prog = migraphx::parse_onnx("constant_test.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + auto ins = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {2, 2, 2, 2}}, + {"stride", {2, 2}}, + {"lengths", {6, 6}}}), + input); + auto ret = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins); + mm->add_return({ret}); + auto prog = migraphx::parse_onnx("averagepool_notset_test.onnx"); EXPECT(p == prog); } -TEST_CASE(constant_fill_test) +TEST_CASE(averagepool_nt_cip_test) { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - std::vector value(s.elements(), 1.0); - p.add_literal(migraphx::literal{s, value}); - auto prog = migraphx::parse_onnx("constant_fill_test.onnx"); - + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + std::vector pads = {0, 0, 0, 0, 0, 0, 1, 1}; + auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); + auto ret = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {6, 6}}}), + ins_pad); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("averagepool_nt_cip_test.onnx"); EXPECT(p == prog); } -TEST_CASE(constant_fill_input_as_shape_test) +TEST_CASE(averagepool_same_lower_test) { migraphx::program p; - auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}}); - std::vector dims(l0->get_shape().elements()); - migraphx::literal ls = l0->get_literal(); - ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); - migraphx::shape s{migraphx::shape::float_type, dims}; - std::vector value(s.elements(), 1.0); - p.add_literal(migraphx::literal{s, value}); - auto prog = migraphx::parse_onnx("constant_fill_input_as_shape_test.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + auto ins = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1, 1, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), + input); + auto ret = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {0, 0}}, {"ends", {5, 5}}}), ins); + mm->add_return({ret}); + auto prog = migraphx::parse_onnx("averagepool_same_lower_test.onnx"); EXPECT(p == prog); } -TEST_CASE(constant_scalar_test) +TEST_CASE(averagepool_sl_cip_test) { migraphx::program p; - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); - auto prog = migraphx::parse_onnx("constant_scalar_test.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + std::vector pads = {0, 0, 1, 1, 0, 0, 0, 0}; + auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); + auto ret = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), + ins_pad); + mm->add_return({ret}); + auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx"); EXPECT(p == prog); } -TEST_CASE(const_of_shape_empty_input_test) +TEST_CASE(averagepool_same_upper_test) { migraphx::program p; - p.add_literal(migraphx::literal()); - migraphx::shape s(migraphx::shape::int64_type, {1}, {0}); - std::vector vec(s.elements(), 10); - p.add_literal(migraphx::literal(s, vec)); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + auto ins = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1, 1, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), + input); + auto ret = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins); + mm->add_return({ret}); + auto prog = migraphx::parse_onnx("averagepool_same_upper_test.onnx"); - auto prog = migraphx::parse_onnx("const_of_shape_empty_input_test.onnx"); EXPECT(p == prog); } -TEST_CASE(const_of_shape_float_test) +TEST_CASE(batchnorm_1d_test) { migraphx::program p; - migraphx::shape ss(migraphx::shape::int32_type, {3}); - p.add_literal(migraphx::literal(ss, {2, 3, 4})); - migraphx::shape s(migraphx::shape::float_type, {2, 3, 4}); - std::vector vec(s.elements(), 10.0f); - p.add_literal(migraphx::literal(s, vec)); - - auto prog = migraphx::parse_onnx("const_of_shape_float_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {3}}); + auto l3 = mm->add_parameter("3", {migraphx::shape::float_type, {3}}); + auto l4 = mm->add_parameter("4", {migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("batch_norm_inference"), l0, l1, l2, l3, l4); + + auto prog = optimize_onnx("batchnorm_1d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(const_of_shape_int64_test) +TEST_CASE(batchnorm_3d_test) { migraphx::program p; - migraphx::shape ss(migraphx::shape::int32_type, {3}); - p.add_literal(migraphx::literal(ss, {2, 3, 4})); - migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4}); - std::vector vec(s.elements(), 10); - p.add_literal(migraphx::literal(s, vec)); - - auto prog = migraphx::parse_onnx("const_of_shape_int64_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {3}}); + auto l3 = mm->add_parameter("3", {migraphx::shape::float_type, {3}}); + auto l4 = mm->add_parameter("4", {migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("batch_norm_inference"), l0, l1, l2, l3, l4); + + auto prog = optimize_onnx("batchnorm_3d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(const_of_shape_no_value_attr_test) +TEST_CASE(cast_test) { migraphx::program p; - migraphx::shape ss(migraphx::shape::int32_type, {3}); - p.add_literal(migraphx::literal(ss, {2, 3, 4})); - migraphx::shape s(migraphx::shape::float_type, {2, 3, 4}); - std::vector vec(s.elements(), 0.0f); - p.add_literal(migraphx::literal(s, vec)); - - auto prog = migraphx::parse_onnx("const_of_shape_no_value_attr_test.onnx"); + auto* mm = p.get_main_module(); + auto l = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}}); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l); + + auto prog = optimize_onnx("cast_test.onnx"); EXPECT(p == prog); } -TEST_CASE(conv_autopad_fail_test) +TEST_CASE(ceil_test) { - EXPECT(test::throws([&] { migraphx::parse_onnx("conv_autopad_fail_test.onnx"); })); + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("ceil"), input); + + auto prog = optimize_onnx("ceil_test.onnx"); + + EXPECT(p == prog); } -TEST_CASE(conv_bias_test) +TEST_CASE(celu_alpha_test) { migraphx::program p; - auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); - auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); - auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); - uint64_t axis = 1; - auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); - auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); - p.add_instruction(migraphx::op::add{}, l3, l4); - - auto prog = migraphx::parse_onnx("conv_bias_test.onnx"); + auto* mm = p.get_main_module(); + std::vector input_lens = {3}; + auto input_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + float alpha = 0.8; + add_celu_instruction(mm, s, alpha); + auto prog = optimize_onnx("celu_alpha_test.onnx"); EXPECT(p == prog); } -TEST_CASE(conv_bn_relu_maxpool_test) +TEST_CASE(celu_default_test) { migraphx::program p; - auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); - auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); - auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); + auto* mm = p.get_main_module(); + std::vector input_lens = {2, 3}; + auto input_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + float alpha = 1.0; + add_celu_instruction(mm, s, alpha); + auto prog = optimize_onnx("celu_default_test.onnx"); + EXPECT(p == prog); +} - auto p3 = p.add_parameter("3", {migraphx::shape::float_type, {1}}); - auto p4 = p.add_parameter("4", {migraphx::shape::float_type, {1}}); - auto p5 = p.add_parameter("5", {migraphx::shape::float_type, {1}}); - auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}}); - uint64_t axis = 1; - auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); - auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); - auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); - auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6); - auto l7 = p.add_instruction(migraphx::op::relu{}, l6); - p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7); +TEST_CASE(celu_wrong_type_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("celu_wrong_type_test.onnx"); })); +} - auto prog = migraphx::parse_onnx("conv_bn_relu_maxpool_test.onnx"); - EXPECT(p == prog); +TEST_CASE(celu_zero_alpha_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("celu_zero_alpha_test.onnx"); })); } -TEST_CASE(conv_relu_maxpool_test) +TEST_CASE(clip_test) { migraphx::program p; - auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); - auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); - auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); - uint64_t axis = 1; - auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); - auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); - auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); - auto l6 = p.add_instruction(migraphx::op::relu{}, l5); - p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_onnx("clip_test.onnx"); - auto prog = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); EXPECT(p == prog); } -TEST_CASE(conv_relu_maxpool_x2_test) +TEST_CASE(clip_test_op11_max_only) { migraphx::program p; - auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); - auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}}); - auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}}); - uint64_t axis = 1; - auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); - auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2); - auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); - auto l6 = p.add_instruction(migraphx::op::relu{}, l5); - auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); - - auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); - auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}}); - auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8); - auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9); - auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11); - auto l13 = p.add_instruction(migraphx::op::relu{}, l12); - p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); + auto* mm = p.get_main_module(); + auto max_val = mm->add_literal(0.0f); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("undefined")); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val); + auto r = mm->add_instruction(migraphx::make_op("min"), l0, max_val); + mm->add_return({r}); - auto prog = migraphx::parse_onnx("conv_relu_maxpool_x2_test.onnx"); + auto prog = migraphx::parse_onnx("clip_test_op11_max_only.onnx"); EXPECT(p == prog); } -TEST_CASE(cos_test) +TEST_CASE(clip_test_op11) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::cos{}, input); + auto* mm = p.get_main_module(); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_onnx("clip_test_op11.onnx"); - auto prog = migraphx::parse_onnx("cos_test.onnx"); EXPECT(p == prog); } -TEST_CASE(cosh_test) +TEST_CASE(clip_test_op11_min_only) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); - p.add_instruction(migraphx::op::cosh{}, input); - - auto prog = migraphx::parse_onnx("cosh_test.onnx"); + auto* mm = p.get_main_module(); + auto min_val = mm->add_literal(0.0f); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val); + mm->add_instruction(migraphx::make_op("max"), l0, min_val); + auto prog = optimize_onnx("clip_test_op11_min_only.onnx"); EXPECT(p == prog); } -TEST_CASE(dropout_test) +TEST_CASE(clip_test_op11_no_args) { migraphx::program p; - auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}); - p.add_instruction(migraphx::op::identity{}, input); - - auto prog = migraphx::parse_onnx("dropout_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_onnx("clip_test_op11_no_args.onnx"); EXPECT(p == prog); } -TEST_CASE(elu_test) +TEST_CASE(clip_test_op11_no_args1) { migraphx::program p; - auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); - p.add_instruction(migraphx::op::elu{0.01}, input); + auto* mm = p.get_main_module(); - auto prog = migraphx::parse_onnx("elu_test.onnx"); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("undefined")); + auto r = mm->add_instruction(migraphx::make_op("identity"), l0); + mm->add_return({r}); + auto prog = migraphx::parse_onnx("clip_test_op11_no_args1.onnx"); EXPECT(p == prog); } -TEST_CASE(erf_test) +TEST_CASE(clip_test_args_type_mismatch) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); - p.add_instruction(migraphx::op::erf{}, input); - - auto prog = migraphx::parse_onnx("erf_test.onnx"); + auto* mm = p.get_main_module(); + auto min_val = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1, 3}}, {1.5, 2.5, 3.5}}); + auto max_val = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {3, 1}}, {2, 3, 4}}); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 3}}); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), max_val); + max_val = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), max_val); + auto r = mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + mm->add_return({r}); + auto prog = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx"); EXPECT(p == prog); } -TEST_CASE(exp_test) +TEST_CASE(concat_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::exp{}, input); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l0, l1); + auto prog = optimize_onnx("concat_test.onnx"); - auto prog = migraphx::parse_onnx("exp_test.onnx"); EXPECT(p == prog); } -TEST_CASE(expand_test) +TEST_CASE(constant_test) { migraphx::program p; - migraphx::shape s(migraphx::shape::float_type, {3, 1, 1}); - auto param = p.add_parameter("x", s); - migraphx::shape ss(migraphx::shape::int32_type, {4}); - p.add_literal(migraphx::literal(ss, {2, 3, 4, 5})); - p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param); + auto* mm = p.get_main_module(); + mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}}); + auto prog = optimize_onnx("constant_test.onnx"); - auto prog = migraphx::parse_onnx("expand_test.onnx"); EXPECT(p == prog); } -TEST_CASE(flatten_test) +TEST_CASE(constant_fill_test) { + migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - p.add_instruction(migraphx::op::flatten{2}, l0); - p.add_instruction(migraphx::op::flatten{1}, l0); - auto prog = migraphx::parse_onnx("flatten_test.onnx"); + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector value(s.elements(), 1.0); + mm->add_literal(migraphx::literal{s, value}); + auto prog = optimize_onnx("constant_fill_test.onnx"); EXPECT(p == prog); } -TEST_CASE(floor_test) +TEST_CASE(constant_fill_input_as_shape_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::floor{}, input); - - auto prog = migraphx::parse_onnx("floor_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}}); + std::vector dims(l0->get_shape().elements()); + migraphx::literal ls = l0->get_literal(); + ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); + migraphx::shape s{migraphx::shape::float_type, dims}; + std::vector value(s.elements(), 1.0); + mm->add_literal(migraphx::literal{s, value}); + auto prog = optimize_onnx("constant_fill_input_as_shape_test.onnx"); EXPECT(p == prog); } -TEST_CASE(gather_test) +TEST_CASE(constant_scalar_test) { migraphx::program p; - auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); - int axis = 1; - p.add_instruction(migraphx::op::gather{axis}, l0, l1); - auto prog = migraphx::parse_onnx("gather_test.onnx"); + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); + auto prog = optimize_onnx("constant_scalar_test.onnx"); EXPECT(p == prog); } -TEST_CASE(gemm_test) +TEST_CASE(const_of_shape_empty_input_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); - p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}}); - auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto alpha = 2.f; - auto beta = 2.0f; - p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1); - auto prog = migraphx::parse_onnx("gemm_test.onnx"); + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::literal()); + migraphx::shape s(migraphx::shape::int64_type, {1}, {0}); + std::vector vec(s.elements(), 10); + mm->add_literal(migraphx::literal(s, vec)); + auto prog = optimize_onnx("const_of_shape_empty_input_test.onnx"); EXPECT(p == prog); } -TEST_CASE(gemm_ex_test) +TEST_CASE(const_of_shape_float_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); - auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); - auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); - auto alpha = 0.5f; - auto beta = 0.8f; - p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2); - auto prog = migraphx::parse_onnx("gemm_ex_test.onnx"); + auto* mm = p.get_main_module(); + migraphx::shape ss(migraphx::shape::int32_type, {3}); + mm->add_literal(migraphx::literal(ss, {2, 3, 4})); + migraphx::shape s(migraphx::shape::float_type, {2, 3, 4}); + std::vector vec(s.elements(), 10.0f); + mm->add_literal(migraphx::literal(s, vec)); + auto prog = optimize_onnx("const_of_shape_float_test.onnx"); EXPECT(p == prog); } -TEST_CASE(gemm_ex_brcst_test) +TEST_CASE(const_of_shape_int64_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); - auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}}); - auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); - std::vector out_lens{1, 1, 6, 7}; - auto t2 = p.add_instruction(migraphx::op::multibroadcast{out_lens}, l2); - auto alpha = 0.5f; - auto beta = 0.8f; - p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2); - auto prog = migraphx::parse_onnx("gemm_ex_brcst_test.onnx"); + auto* mm = p.get_main_module(); + migraphx::shape ss(migraphx::shape::int32_type, {3}); + mm->add_literal(migraphx::literal(ss, {2, 3, 4})); + migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4}); + std::vector vec(s.elements(), 10); + mm->add_literal(migraphx::literal(s, vec)); + auto prog = optimize_onnx("const_of_shape_int64_test.onnx"); EXPECT(p == prog); } -TEST_CASE(globalavgpool_test) +TEST_CASE(const_of_shape_no_value_attr_test) { migraphx::program p; - auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto op = migraphx::op::pooling{"average"}; - auto lens = input->get_shape().lens(); - op.lengths = {lens[2], lens[3]}; - p.add_instruction(op, input); - - auto prog = migraphx::parse_onnx("globalavgpool_test.onnx"); + auto* mm = p.get_main_module(); + migraphx::shape ss(migraphx::shape::int32_type, {3}); + mm->add_literal(migraphx::literal(ss, {2, 3, 4})); + migraphx::shape s(migraphx::shape::float_type, {2, 3, 4}); + std::vector vec(s.elements(), 0.0f); + mm->add_literal(migraphx::literal(s, vec)); + auto prog = optimize_onnx("const_of_shape_no_value_attr_test.onnx"); EXPECT(p == prog); } -TEST_CASE(globalmaxpool_test) +TEST_CASE(conv_autopad_fail_test) { - migraphx::program p; - auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto op = migraphx::op::pooling{"max"}; - auto lens = input->get_shape().lens(); - op.lengths = {lens[2], lens[3]}; - p.add_instruction(op, input); + EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); })); +} - auto prog = migraphx::parse_onnx("globalmaxpool_test.onnx"); +TEST_CASE(conv_1d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3}}); + mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + l0, + l1); + + auto prog = optimize_onnx("conv_1d_test.onnx"); + EXPECT(p == prog); +} +TEST_CASE(conv_3d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3, 3}}); + mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + l0, + l1); + + auto prog = optimize_onnx("conv_3d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(group_conv_test) +TEST_CASE(conv_attr_fail_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("conv_attr_fail_test.onnx"); })); +} + +TEST_CASE(conv_autopad_same_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}}); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}}); migraphx::op::convolution op; - op.group = 4; - p.add_instruction(op, l0, l1); - auto prog = migraphx::parse_onnx("group_conv_test.onnx"); + op.padding = {1, 1, 1, 1}; + op.padding_mode = migraphx::op::padding_mode_t::same; + mm->add_instruction(op, l0, l1); + auto prog = optimize_onnx("conv_autopad_same_test.onnx"); EXPECT(p == prog); } -TEST_CASE(imagescaler_test) +TEST_CASE(conv_bias_test) { migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}}; - auto l0 = p.add_parameter("0", s); - auto scale_val = p.add_literal(0.5f); - auto bias_vals = p.add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); - auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val); - auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor); - auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals); - p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); - - auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}}); + uint64_t axis = 1; + auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto prog = optimize_onnx("conv_bias_test.onnx"); EXPECT(p == prog); } -TEST_CASE(implicit_add_bcast_test) +TEST_CASE(conv_bn_relu_maxpool_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); - auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); - auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); - p.add_instruction(migraphx::op::add{}, l2, l3); - - auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}}); + + auto p3 = mm->add_parameter("3", {migraphx::shape::float_type, {1}}); + auto p4 = mm->add_parameter("4", {migraphx::shape::float_type, {1}}); + auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}}); + auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}}); + uint64_t axis = 1; + auto l3 = + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto l6 = mm->add_instruction( + migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6); + auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {2, 2}}}), + l7); + + auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx"); EXPECT(p == prog); } -TEST_CASE(implicit_pow_bcast_test) +TEST_CASE(conv_relu_maxpool_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); - auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); - auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); - p.add_instruction(migraphx::op::pow{}, l2, l3); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}}); + uint64_t axis = 1; + auto l3 = + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {2, 2}}}), + l6); + + auto prog = optimize_onnx("conv_relu_maxpool_test.onnx"); + EXPECT(p == prog); +} - auto prog = migraphx::parse_onnx("implicit_pow_bcast_test.onnx"); +TEST_CASE(conv_relu_maxpool_x2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::float_type, {5}}); + uint64_t axis = 1; + auto l3 = + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); + auto l7 = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {2, 2}}}), + l6); + + auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); + auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}}); + auto l10 = + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l7, l8); + auto l11 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l10->get_shape().lens()}}), + l9); + auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11); + auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 0, 0}}, + {"stride", {2, 2}}, + {"lengths", {2, 2}}}), + l13); + + auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx"); EXPECT(p == prog); } -TEST_CASE(implicit_sub_bcast_test) +TEST_CASE(convinteger_bias_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}}); - auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); - auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); - p.add_instruction(migraphx::op::sub{}, l2, l3); - - auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::int32_type, {1}}); + uint64_t axis = 1; + auto l3 = mm->add_instruction(migraphx::make_op("quant_convolution"), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto prog = optimize_onnx("convinteger_bias_test.onnx"); EXPECT(p == prog); } -TEST_CASE(initializer_not_an_input) +TEST_CASE(cos_test) { migraphx::program p; - std::vector w = {1, 2, 3, 4, 5, 6, 7, 8}; - auto l1 = p.add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w)); - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); - p.add_instruction(migraphx::op::dot{}, l0, l1); - - auto prog = migraphx::parse_onnx("initializer_not_an_input.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("cos"), input); + auto prog = optimize_onnx("cos_test.onnx"); EXPECT(p == prog); } -TEST_CASE(leaky_relu_test) +TEST_CASE(cosh_test) { migraphx::program p; - float alpha = 0.01f; - auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {3}}); - p.add_instruction(migraphx::op::leaky_relu{alpha}, l0); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); + mm->add_instruction(migraphx::make_op("cosh"), input); - auto prog = migraphx::parse_onnx("leaky_relu_test.onnx"); + auto prog = optimize_onnx("cosh_test.onnx"); EXPECT(p == prog); } -TEST_CASE(log_test) +TEST_CASE(deconv_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::log{}, input); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}}); + mm->add_instruction(migraphx::make_op("deconvolution"), l0, l1); - auto prog = migraphx::parse_onnx("log_test.onnx"); + auto prog = optimize_onnx("deconv_test.onnx"); EXPECT(p == prog); } -TEST_CASE(logsoftmax_test) +TEST_CASE(deconv_bias_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - int axis = 1; - p.add_instruction(migraphx::op::logsoftmax{axis}, l0); - auto prog = migraphx::parse_onnx("logsoftmax_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l2 = mm->add_parameter("b", {migraphx::shape::float_type, {1}}); + uint64_t axis = 1; + auto l3 = mm->add_instruction(migraphx::make_op("deconvolution"), l0, l1); + auto l4 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); + mm->add_instruction(migraphx::make_op("add"), l3, l4); + auto prog = optimize_onnx("deconv_bias_test.onnx"); EXPECT(p == prog); } -TEST_CASE(lrn_test) +TEST_CASE(deconv_input_pads_strides_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}}); - migraphx::op::lrn op; - op.size = 5; - op.alpha = 0.0001; - op.beta = 0.75; - op.bias = 1.0; - p.add_instruction(op, l0); - auto prog = migraphx::parse_onnx("lrn_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}}); + mm->add_instruction( + migraphx::make_op("deconvolution", {{"padding", {1, 1}}, {"stride", {3, 2}}}), l0, l1); + auto prog = optimize_onnx("deconv_input_pads_strides_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_bmbm_test) +TEST_CASE(deconv_input_pads_asymm_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}}); - auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0); - auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1); - p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1); - - auto prog = migraphx::parse_onnx("matmul_bmbm_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", {{"padding", {0, 0}}, {"stride", {3, 2}}}), l0, l1); + mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {0, 0}}, {"ends", {8, 6}}}), l2); + + auto prog = optimize_onnx("deconv_input_pads_asymm_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_bmv_test) +TEST_CASE(deconv_input_pads_asymm_1d_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); - auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); - auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1); - auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1); - p.add_instruction(migraphx::op::squeeze{{2}}, res); - - auto prog = migraphx::parse_onnx("matmul_bmv_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {0, 0}}, {"stride", {2}}, {"dilation", {1}}}), + l0, + l1); + mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}), + l2); + + auto prog = optimize_onnx("deconv_input_pads_asymm_1d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_mv_test) +TEST_CASE(deconv_output_padding_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); - auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); - auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1); - p.add_instruction(migraphx::op::squeeze{{1}}, res); - - auto prog = migraphx::parse_onnx("matmul_mv_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", {{"padding", {0, 0}}, {"stride", {3, 2}}}), l0, l1); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 1, 1}}}), l2); + + auto prog = optimize_onnx("deconv_output_padding_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_vbm_test) +TEST_CASE(deconv_output_padding_3d_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); - auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); - auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0); - std::cout << "ONNX_TEST" << std::endl; - auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1); - std::cout << "After Dot" << std::endl; - p.add_instruction(migraphx::op::squeeze{{1}}, res); - - auto prog = migraphx::parse_onnx("matmul_vbm_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), + l0, + l1); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2); + + auto prog = optimize_onnx("deconv_output_padding_3d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_vm_test) +TEST_CASE(deconv_output_shape_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}}); - auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); - auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1); - p.add_instruction(migraphx::op::squeeze{{0}}, res); - - auto prog = migraphx::parse_onnx("matmul_vm_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", {{"padding", {0, 0}}, {"stride", {3, 2}}}), l0, l1); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 1, 1}}}), l2); + + auto prog = optimize_onnx("deconv_output_shape_test.onnx"); EXPECT(p == prog); } -TEST_CASE(matmul_vv_test) +TEST_CASE(deconv_output_shape_3d_test) { migraphx::program p; - auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); - auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); - auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); - auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); - auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1); - auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res); - p.add_instruction(migraphx::op::squeeze{{0}}, sr0); - - auto prog = migraphx::parse_onnx("matmul_vv_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 1, 3, 3, 3}}); + auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}}); + auto l2 = mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}), + l0, + l1); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2); + + auto prog = optimize_onnx("deconv_output_shape_3d_test.onnx"); EXPECT(p == prog); } -TEST_CASE(max_test) +TEST_CASE(depthtospace_test) { migraphx::program p; - auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); - auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1); - p.add_instruction(migraphx::op::max{}, l0, input2); - - migraphx::parse_onnx("max_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto tmp1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0); + auto tmp2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1); + auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3); + auto prog = optimize_onnx("depthtospace_test.onnx"); + EXPECT(p == prog); } -TEST_CASE(min_test) +TEST_CASE(depthtospace_crd_test) { migraphx::program p; - auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); - auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1); - p.add_instruction(migraphx::op::min{}, l0, input2); - - migraphx::parse_onnx("min_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto tmp1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0); + auto tmp2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 5, 3}}}), tmp1); + auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3); + auto prog = optimize_onnx("depthtospace_crd_test.onnx"); + EXPECT(p == prog); } -TEST_CASE(no_pad_test) +TEST_CASE(depthtospace_simple_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); - p.add_instruction(migraphx::op::identity{}, l0); - auto prog = migraphx::parse_onnx("no_pad_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 3}}); + auto tmp1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 2, 3}}}), l0); + auto tmp2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1); + auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp3); + auto prog = optimize_onnx("depthtospace_simple_test.onnx"); EXPECT(p == prog); } -TEST_CASE(pad_test) +TEST_CASE(spacetodepth_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); - p.add_instruction(migraphx::op::pad{{1, 1, 1, 1}}, l0); - auto prog = migraphx::parse_onnx("pad_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 10, 10}}); + auto tmp1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 5, 2, 5, 2}}}), l0); + auto tmp2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1); + auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp3); + auto prog = optimize_onnx("spacetodepth_test.onnx"); EXPECT(p == prog); } -TEST_CASE(pow_test) +TEST_CASE(spacetodepth_simple_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - p.add_instruction(migraphx::op::pow{}, l0, l1); - - auto prog = migraphx::parse_onnx("pow_test.onnx"); - + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 2, 4, 6}}); + auto tmp1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 3, 2}}}), l0); + auto tmp2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1); + auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp3); + auto prog = optimize_onnx("spacetodepth_simple_test.onnx"); EXPECT(p == prog); } -TEST_CASE(reducemax_test) +TEST_CASE(spacetodepth_invalid_blocksize) { - migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - p.add_instruction(migraphx::op::reduce_max{{2}}, l0); - auto prog = migraphx::parse_onnx("reducemax_test.onnx"); + EXPECT(test::throws([&] { migraphx::parse_onnx("spacetodepth_invalid_blocksize_test.onnx"); })); +} - EXPECT(p == prog); +TEST_CASE(spacetodepth_nondivisibility_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("spacetodepth_nondivisibility_test.onnx"); })); } -TEST_CASE(reducemean_test) +TEST_CASE(dequantizelinear_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0); - p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); - auto prog = migraphx::parse_onnx("reducemean_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); + auto l1_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); + auto dequant = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l0); + mm->add_instruction(migraphx::make_op("mul"), dequant, l1_mbcast); + + auto prog = optimize_onnx("dequantizelinear_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} - EXPECT(p == prog); +TEST_CASE(dequantizelinear_zero_point_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}}); + auto l1_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); + auto l2_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2); + l2_mbcast = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l2_mbcast); + l0 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l0); + + auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_mbcast); + mm->add_instruction(migraphx::make_op("mul"), sub, l1_mbcast); + + auto prog = optimize_onnx("dequantizelinear_zero_point_test.onnx", true); + EXPECT(p.sort() == prog.sort()); } -TEST_CASE(reducemean_keepdims_test) +migraphx::program make_dequantizelinear_axis_prog() { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - p.add_instruction(migraphx::op::reduce_mean{{2}}, l0); - auto prog = migraphx::parse_onnx("reducemean_keepdims_test.onnx"); + std::vector input_lens{1, 1, 5, 1}; + int axis = 2; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, input_lens}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}}); + auto l1_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1); + auto l2_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2); + l2_bcast = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l2_bcast); + l0 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l0); + auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, l2_bcast); + + mm->add_instruction(migraphx::make_op("mul"), sub, l1_bcast); + return p; +} - EXPECT(p == prog); +TEST_CASE(dequantizelinear_axis_test) +{ + migraphx::program p = make_dequantizelinear_axis_prog(); + + auto prog = optimize_onnx("dequantizelinear_axis_test.onnx", true); + EXPECT(p.sort() == prog.sort()); } -TEST_CASE(reducemin_test) +TEST_CASE(dequantizelinear_neg_axis_test) { - migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto l1 = p.add_instruction(migraphx::op::reduce_min{{2, 3}}, l0); - p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); - auto prog = migraphx::parse_onnx("reducemin_test.onnx"); + migraphx::program p = make_dequantizelinear_axis_prog(); - EXPECT(p == prog); + auto prog = optimize_onnx("dequantizelinear_neg_axis_test.onnx", true); + EXPECT(p.sort() == prog.sort()); } -TEST_CASE(reducesum_test) +TEST_CASE(dropout_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0); - p.add_instruction(migraphx::op::squeeze{{2}}, l1); - auto prog = migraphx::parse_onnx("reducesum_test.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}); + auto out = mm->add_instruction(migraphx::make_op("identity"), input); + migraphx::shape s{migraphx::shape::bool_type, {1, 3, 2, 2}}; + std::vector vec(s.elements(), 1); + mm->add_literal(migraphx::literal(s, vec)); + mm->add_return({out}); + auto prog = migraphx::parse_onnx("dropout_test.onnx"); EXPECT(p == prog); } -TEST_CASE(reducesum_multiaxis_test) +TEST_CASE(elu_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); - p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1); - auto prog = migraphx::parse_onnx("reducesum_multiaxis_test.onnx"); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("elu", {{"alpha", 0.01}}), input); + + auto prog = optimize_onnx("elu_test.onnx"); EXPECT(p == prog); } -TEST_CASE(reducesum_keepdims_test) +TEST_CASE(embedding_bag_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); - p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0); - auto prog = migraphx::parse_onnx("reducesum_keepdims_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("weight", migraphx::shape{migraphx::shape::float_type, {4, 2}}); + migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {3}}, {1, 0, 2}}; + auto l1 = mm->add_literal(l); + mm->add_literal(0); + auto l4 = mm->add_instruction(migraphx::make_op("gather"), l0, l1); + auto r1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), l4); + auto l5 = mm->add_instruction(migraphx::make_op("gather"), l0, l1); + auto r2 = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0}}}), l5); + auto l6 = mm->add_instruction(migraphx::make_op("gather"), l0, l1); + auto r3 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), l6); + mm->add_return({r1, r2, r3}); + + auto prog = migraphx::parse_onnx("embedding_bag_test.onnx"); EXPECT(p == prog); } -TEST_CASE(reshape_test) +TEST_CASE(embedding_bag_offset_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("embedding_bag_offset_test.onnx"); })); +} + +TEST_CASE(equal_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto input1 = mm->add_literal(migraphx::literal(s, data)); + auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}}); + auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2); + auto ret = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + eq); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("equal_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(equal_bool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sf{migraphx::shape::float_type, {2, 3}}; + migraphx::shape sb{migraphx::shape::bool_type, {2, 3}}; + + auto input1 = mm->add_parameter("x1", sf); + auto input2 = mm->add_parameter("x2", sb); + auto cin1 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + input1); + auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("equal_bool_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(erf_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); + mm->add_instruction(migraphx::make_op("erf"), input); + + auto prog = optimize_onnx("erf_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(exp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("exp"), input); + + auto prog = optimize_onnx("exp_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(expand_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s(migraphx::shape::float_type, {3, 1, 1}); + auto param = mm->add_parameter("x", s); + migraphx::shape ss(migraphx::shape::int32_type, {4}); + mm->add_literal(migraphx::literal(ss, {2, 3, 4, 5})); + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), param); + + auto prog = optimize_onnx("expand_test.onnx"); + EXPECT(p == prog); +} + +migraphx::program create_external_data_prog() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s(migraphx::shape::float_type, {1, 1, 224, 224}); + migraphx::shape s2(migraphx::shape::float_type, {10, 1, 11, 11}); + std::vector weight_data(1210, 1); + std::vector bias_data(10, 1); + auto bias = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data)); + auto weights = mm->add_literal(migraphx::literal(s2, weight_data)); + auto param = mm->add_parameter("input", s); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), param, weights); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 10, 214, 214}}}), bias); + mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast); + return p; +} + +TEST_CASE(external_data_test) +{ + migraphx::program p = create_external_data_prog(); + + auto prog = optimize_onnx("external_data_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(external_data_diff_path_test) +{ + migraphx::program p = create_external_data_prog(); + + auto prog = optimize_onnx("ext_path/external_data_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(eyelike_default_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{3, 4}; + const size_t k = 0; + auto num_rows = input_lens.front(); + auto num_cols = input_lens.back(); + auto input_type = migraphx::shape::float_type; + auto output_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + mm->add_parameter("T1", s); + + auto eyelike_mat = make_r_eyelike(num_rows, num_cols, k); + mm->add_literal(migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + + auto prog = optimize_onnx("eyelike_default_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(eyelike_double_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{6, 15}; + const size_t k = 0; + auto num_rows = input_lens.front(); + auto num_cols = input_lens.back(); + auto input_type = migraphx::shape::double_type; + auto output_type = migraphx::shape::double_type; + migraphx::shape s{input_type, input_lens}; + mm->add_parameter("T1", s); + + auto eyelike_mat = make_r_eyelike(num_rows, num_cols, k); + mm->add_literal(migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + + auto prog = optimize_onnx("eyelike_double_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(eyelike_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{8, 8}; + const size_t k = 0; + auto num_rows = input_lens.front(); + auto num_cols = input_lens.back(); + auto input_type = migraphx::shape::half_type; + auto output_type = migraphx::shape::half_type; + migraphx::shape s{input_type, input_lens}; + mm->add_parameter("T1", s); + + auto eyelike_mat = make_r_eyelike(num_rows, num_cols, k); + mm->add_literal(migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + + auto prog = optimize_onnx("eyelike_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(eyelike_k_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{3, 4}; + const size_t k = 1; + auto num_rows = input_lens.front(); + auto num_cols = input_lens.back(); + auto input_type = migraphx::shape::float_type; + auto output_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + mm->add_parameter("T1", s); + + auto eyelike_mat = make_r_eyelike(num_rows, num_cols, k); + mm->add_literal(migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + + auto prog = optimize_onnx("eyelike_k_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(eyelike_k_outofbounds_neg_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_k_outofbounds_neg_test.onnx"); })); +} + +TEST_CASE(eyelike_k_outofbounds_pos_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_k_outofbounds_pos_test.onnx"); })); +} + +TEST_CASE(eyelike_not_rank2_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("eyelike_not_rank2_test.onnx"); })); +} + +TEST_CASE(eyelike_set_dtype_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{3, 4}; + const size_t k = 0; + auto num_rows = input_lens.front(); + auto num_cols = input_lens.back(); + auto input_type = migraphx::shape::float_type; + auto output_type = migraphx::shape::double_type; + migraphx::shape s{input_type, input_lens}; + mm->add_parameter("T1", s); + + auto eyelike_mat = make_r_eyelike(num_rows, num_cols, k); + mm->add_literal(migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat}); + + auto prog = optimize_onnx("eyelike_set_dtype_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(flatten_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l0); + mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l0); + auto prog = optimize_onnx("flatten_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(flatten_nonstd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5, 4}}); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0); + auto l2 = mm->add_instruction(migraphx::make_op("contiguous"), l1); + mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l2); + auto l3 = mm->add_instruction(migraphx::make_op("contiguous"), l1); + mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l3); + auto prog = optimize_onnx("flatten_nonstd_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(floor_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("floor"), input); + + auto prog = optimize_onnx("floor_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gather_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); + int axis = 1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1); + auto prog = optimize_onnx("gather_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gather_elements_axis0_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 4}}); + auto indices = mm->add_parameter("indices", {migraphx::shape::int32_type, {2, 3}}); + std::vector ind_indices{0, 1, 2, 4, 5, 6}; + std::vector ind_axis_indices{0, 0, 0, 1, 1, 1}; + migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}}; + auto l_data_indices = + mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()}); + auto l_ind_axis_indices = + mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()}); + auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {4}}); + + auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto lbst_stride = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride); + auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices); + auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); + auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); + auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("gather_elements_axis0_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gather_elements_axis1_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 4}}); + auto indices = mm->add_parameter("indices", {migraphx::shape::int32_type, {2, 3}}); + std::vector ind_indices{0, 1, 2, 4, 5, 6}; + std::vector ind_axis_indices{0, 1, 2, 0, 1, 2}; + migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}}; + auto l_data_indices = + mm->add_literal(migraphx::literal{ind_s, ind_indices.begin(), ind_indices.end()}); + auto l_ind_axis_indices = + mm->add_literal(migraphx::literal{ind_s, ind_axis_indices.begin(), ind_axis_indices.end()}); + auto l_stride = mm->add_literal(migraphx::literal{{migraphx::shape::int32_type, {1}}, {1}}); + + auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto lbst_stride = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride); + auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices); + auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); + auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); + auto ret = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp_data, ind); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("gather_elements_axis1_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gathernd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 2}}); + mm->add_instruction(migraphx::make_op("gathernd"), l0, l1); + auto prog = optimize_onnx("gathernd_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gathernd_batch_dims_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1}}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), l0, l1); + auto prog = optimize_onnx("gathernd_batch_dims_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(gemm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); + auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type}); + auto alpha = 2.f; + auto beta = 2.0f; + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto b_l = mm->add_literal(beta); + auto l2_b = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2); + auto b_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l); + auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); + + auto prog = optimize_onnx("gemm_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(gemm_ex_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}}); + auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); + auto alpha = 0.5f; + auto beta = 0.8f; + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto b_l = mm->add_literal(beta); + auto b_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l); + auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_b); + + auto prog = optimize_onnx("gemm_ex_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(gemm_ex_brcst_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); + auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}}); + std::vector out_lens{1, 1, 6, 7}; + auto alpha = 0.5f; + auto beta = 0.8f; + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto b_l = mm->add_literal(beta); + auto l2_b = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2); + auto b_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l); + auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); + + auto prog = optimize_onnx("gemm_ex_brcst_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(gemm_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}}); + auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}}); + auto alpha = 0.5f; + auto beta = 0.8f; + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); + std::vector lens = {1, 1, 6, 7}; + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); + l2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); + auto b_l = mm->add_literal(beta); + auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); + auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); + l2_b = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_b); + + auto prog = optimize_onnx("gemm_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(globalavgpool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + op.padding = {0, 0, 0, 0}; + mm->add_instruction(op, input); + + auto prog = optimize_onnx("globalavgpool_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(globallppool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + op.padding = {0, 0, 0, 0}; + mm->add_instruction(op, input); + + auto prog = optimize_onnx("globallppool_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(globalmaxpool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + op.padding = {0, 0, 0, 0}; + mm->add_instruction(op, input); + + auto prog = optimize_onnx("globalmaxpool_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(greater_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto input1 = mm->add_literal(migraphx::literal(s, data)); + auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}}); + auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2); + auto ret = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + gr); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("greater_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(greater_bool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sf{migraphx::shape::float_type, {2, 3}}; + migraphx::shape sb{migraphx::shape::bool_type, {2, 3}}; + + auto input1 = mm->add_parameter("x1", sf); + auto input2 = mm->add_parameter("x2", sb); + auto cin1 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + input1); + auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("greater_bool_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(greaterorequal_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}}); + auto temp = mm->add_instruction(migraphx::make_op("less"), input1, input2); + auto bt = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), temp); + auto ge = mm->add_instruction(migraphx::make_op("not"), bt); + + mm->add_return({ge}); + + auto prog = migraphx::parse_onnx("greaterorequal_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(group_conv_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}}); + migraphx::op::convolution op; + op.group = 4; + mm->add_instruction(op, l0, l1); + auto prog = optimize_onnx("group_conv_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(hardsigmoid_default_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 4, 5}; + auto input_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + auto x = mm->add_parameter("x", s); + + float alpha = 0.2; + float beta = 0.5; + + auto mb_alpha = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto mb_beta = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}})); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}})); + auto mb_one = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + + auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x); + auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul); + mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); + + auto prog = optimize_onnx("hardsigmoid_default_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(hardsigmoid_double_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 4, 5}; + auto input_type = migraphx::shape::double_type; + migraphx::shape s{input_type, input_lens}; + auto x = mm->add_parameter("x", s); + + float alpha = 0.3; + float beta = 0.7; + + auto mb_alpha = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto mb_beta = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}})); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}})); + auto mb_one = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + + auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x); + auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul); + mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); + + auto prog = optimize_onnx("hardsigmoid_double_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(hardsigmoid_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 4, 5}; + auto input_type = migraphx::shape::half_type; + migraphx::shape s{input_type, input_lens}; + auto x = mm->add_parameter("x", s); + + float alpha = 0.2; + float beta = 0.5; + + auto mb_alpha = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto mb_beta = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}})); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}})); + auto mb_one = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + + auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x); + auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul); + mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); + + auto prog = optimize_onnx("hardsigmoid_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(hardswish_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{2, 5}; + auto input_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + auto x = mm->add_parameter("x", s); + + float alpha = 1.0 / 6.0; + float beta = 0.5; + + auto mb_alpha = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}})); + auto mb_beta = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}})); + auto mb_zero = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}})); + auto mb_one = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + + auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x); + auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul); + auto hardsigmoid = mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one); + mm->add_instruction(migraphx::make_op("mul"), x, hardsigmoid); + + auto prog = optimize_onnx("hardswish_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(if_else_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {0})); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + auto l1 = mm->add_literal(s, ones); + std::vector rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589}; + auto l2 = mm->add_literal(s, rand); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + + auto* then_mod = p.create_module("If_5_if"); + auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); + then_mod->add_return({rt}); + + auto* else_mod = p.create_module("If_5_else"); + auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({re}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + std::ifstream ifs("if_else_test.onnx", std::ios::binary); + ifs.seekg(0, std::ios::end); + auto length = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::vector onnx_buffer(length); + ifs.read(onnx_buffer.data(), length); + ifs.close(); + + auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {}); + EXPECT(p == prog); +} + +TEST_CASE(if_literal_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + + auto* then_mod = p.create_module("If_1_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_literal({}); + then_mod->add_return({l1}); + + auto* else_mod = p.create_module("If_1_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + else_mod->add_literal({}); + else_mod->add_return({l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("if_literal_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(if_param_excp_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("if_param_excp_test.onnx"); })); +} + +TEST_CASE(if_param_excp1_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("if_param_excp1_test.onnx"); })); +} + +TEST_CASE(if_param_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", ds); + auto y = mm->add_parameter("y", ds); + + auto* then_mod = p.create_module("If_3_if"); + std::vector data1 = {0.384804, -1.77948, -0.453775, 0.477438, -1.06333, -1.12893}; + auto l1 = then_mod->add_literal(migraphx::literal(ds, data1)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x, l1); + then_mod->add_return({a1}); + + auto* else_mod = p.create_module("If_3_else"); + std::vector data2 = {-0.258047, 0.360394, 0.536804, -0.577762, 1.0217, 1.02442}; + auto l2 = else_mod->add_literal(migraphx::literal(ds, data2)); + auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({a2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("if_param_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(if_pl_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + auto cond = mm->add_parameter("cond", cond_s); + auto x = mm->add_parameter("x", xs); + auto y = mm->add_parameter("y", ys); + + auto* then_mod = p.create_module("If_5_if"); + auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x, lx); + then_mod->add_return({a1, l1}); + + auto* else_mod = p.create_module("If_5_else"); + auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); + auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, ly); + else_mod->add_return({l2, a2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("if_pl_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(if_then_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_literal(migraphx::literal(sc, {1})); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector ones(s.elements(), 1.0f); + auto l1 = mm->add_literal(s, ones); + std::vector rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; + auto l2 = mm->add_literal(s, rand); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + + auto* then_mod = p.create_module("If_5_if"); + auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); + then_mod->add_return({rt}); + + auto* else_mod = p.create_module("If_5_else"); + auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({re}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("if_then_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(if_tuple_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto* then_mod = p.create_module("If_6_if"); + auto m1 = + then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); + auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); + auto m2 = + then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); + auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); + then_mod->add_return({add0, mul0}); + + auto* else_mod = p.create_module("If_6_else"); + auto me1 = + else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); + auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); + auto me2 = + else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); + auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); + else_mod->add_return({mul1, add1}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r0, r1}); + + auto prog = migraphx::parse_onnx("if_tuple_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(isnan_float_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto t1 = mm->add_parameter("t1", s); + auto ret = mm->add_instruction(migraphx::make_op("isnan"), t1); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("isnan_float_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(isnan_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {2, 3}}; + auto t1 = mm->add_parameter("t1", s); + auto ret = mm->add_instruction(migraphx::make_op("isnan"), t1); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("isnan_half_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(imagescaler_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}}; + auto l0 = mm->add_parameter("0", s); + auto scale_val = mm->add_literal(0.5f); + auto bias_vals = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); + auto scaled_tensor = mm->add_instruction( + migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val); + auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals); + mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); + + auto prog = optimize_onnx("imagescaler_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(imagescaler_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {1, 3, 16, 16}}; + auto l0 = mm->add_parameter("0", s); + auto scale_val = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5f}}); + auto bias_vals = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::half_type, {3}}, {0.01, 0.02, 0.03}}); + auto scaled_tensor = mm->add_instruction( + migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val); + auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals); + mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); + + auto prog = optimize_onnx("imagescaler_half_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(implicit_add_bcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); + auto l3 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l3); + + auto prog = optimize_onnx("implicit_add_bcast_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(implicit_add_bcast_user_input_shape_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}}); + auto l3 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 5, 6}}}), l1); + auto r = mm->add_instruction(migraphx::make_op("add"), l0, l3); + mm->add_return({r}); + + migraphx::onnx_options options; + options.map_input_dims["0"] = {3, 4, 5, 6}; + options.map_input_dims["1"] = {4, 5, 1}; + auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx", options); + + EXPECT(p == prog); +} + +TEST_CASE(implicit_pow_bcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); + auto l3 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); + mm->add_instruction(migraphx::make_op("pow"), l0, l3); + + auto prog = optimize_onnx("implicit_pow_bcast_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(implicit_sub_bcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}}); + auto l3 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); + mm->add_instruction(migraphx::make_op("sub"), l0, l3); + + auto prog = optimize_onnx("implicit_sub_bcast_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(initializer_not_an_input) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector w = {1, 2, 3, 4, 5, 6, 7, 8}; + auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w)); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); + migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto prog = optimize_onnx("initializer_not_an_input.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(instance_norm_test) +{ + std::vector dims{1, 2, 3, 3}; + migraphx::shape s1{migraphx::shape::float_type, dims}; + migraphx::shape s2{migraphx::shape::float_type, {2}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("0", s1); + auto scale = mm->add_parameter("1", s2); + auto bias = mm->add_parameter("2", s2); + + auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); + auto mean_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast); + auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); + auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast); + auto epsilon_literal = mm->add_literal(1e-5f); + auto epsilon_bcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); + auto variance_bcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance); + auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast); + auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2); + auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3); + auto scale_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); + auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast); + mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast); + + auto prog = optimize_onnx("instance_norm_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(leaky_relu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + float alpha = 0.01f; + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", alpha}}), l0); + + auto prog = optimize_onnx("leaky_relu_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(less_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto input1 = mm->add_literal(migraphx::literal(s, data)); + auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}}); + auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2); + auto ret = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + le); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("less_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(less_bool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sf{migraphx::shape::float_type, {2, 3}}; + migraphx::shape sb{migraphx::shape::bool_type, {2, 3}}; + + auto input1 = mm->add_parameter("x1", sf); + auto input2 = mm->add_parameter("x2", sb); + auto cin1 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + input1); + auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("less_bool_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(lessorequal_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}}); + auto temp = mm->add_instruction(migraphx::make_op("greater"), input1, input2); + auto bt = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), temp); + auto le = mm->add_instruction(migraphx::make_op("not"), bt); + + mm->add_return({le}); + + auto prog = migraphx::parse_onnx("lessorequal_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(log_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("log"), input); + + auto prog = optimize_onnx("log_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(logical_and_bcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}}); + auto l2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1); + auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("logical_and_bcast_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(logical_or_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); + auto ret = mm->add_instruction(migraphx::make_op("logical_or"), l0, l1); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("logical_or_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(logical_xor_bcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}}); + auto l2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1); + auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("logical_xor_bcast_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(logsoftmax_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + int axis = 1; + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), l0); + auto prog = optimize_onnx("logsoftmax_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(logsoftmax_nonstd_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {6, 9}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {4, 4}}}), l0); + auto l2 = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", -1}}), l1); + mm->add_return({l2}); + + auto prog = migraphx::parse_onnx("logsoftmax_nonstd_input_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(loop_default_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape su{migraphx::shape::float_type}; + auto a = mm->add_parameter("a", su); + auto b = mm->add_parameter("b", su); + migraphx::shape si{migraphx::shape::int64_type}; + auto max_iter = mm->add_literal(migraphx::literal(si, {10})); + migraphx::shape sc{migraphx::shape::bool_type}; + auto icond = mm->add_literal(migraphx::literal(sc, {1})); + mm->add_instruction(migraphx::make_op("undefined")); + + auto* body = p.create_module("Loop_3_loop"); + body->add_parameter("iteration_num", {migraphx::shape::int64_type}); + body->add_parameter("keep_going_inp", {migraphx::shape::bool_type}); + auto var = body->add_parameter("b_in", su); + + auto ad = body->add_instruction(migraphx::make_op("add"), a, var); + auto sb = body->add_instruction(migraphx::make_op("sub"), a, var); + auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb); + auto cv = body->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt); + auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb); + body->add_return({cv, sb, ad, ad1}); + + auto lp = mm->add_instruction( + migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp); + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp); + auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp); + mm->add_return({r0, r2}); + + auto prog = migraphx::parse_onnx("loop_default_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(loop_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape si{migraphx::shape::int64_type, {1}}; + auto max_iter = mm->add_parameter("max_trip_count", si); + migraphx::shape sc{migraphx::shape::bool_type, {1}}; + auto icond = mm->add_parameter("keep_going_cond", sc); + migraphx::shape su{migraphx::shape::float_type, {1}}; + auto a = mm->add_parameter("a", su); + auto b = mm->add_parameter("b", su); + + auto* body = p.create_module("Loop_4_loop"); + body->add_parameter("iteration_num", si); + body->add_parameter("keep_going_inp", sc); + auto var = body->add_parameter("b_in", su); + + auto ad = body->add_instruction(migraphx::make_op("add"), a, var); + auto sb = body->add_instruction(migraphx::make_op("sub"), a, var); + auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb); + auto cv = body->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt); + auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb); + body->add_return({cv, sb, ad, ad1}); + + auto lp = mm->add_instruction( + migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp); + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp); + auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp); + mm->add_return({r0, r2}); + + auto prog = migraphx::parse_onnx("loop_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(lpnormalization_default_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{3, 4}; + auto input_type = migraphx::shape::float_type; + migraphx::shape s{input_type, input_lens}; + auto x = mm->add_parameter("x", s); + + std::ptrdiff_t axis = 0; + auto p_val = mm->add_instruction(migraphx::make_op("mul"), x, x); + auto norms = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val); + norms = mm->add_instruction(migraphx::make_op("sqrt"), norms); + norms = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms); + auto zero_mb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}})); + auto one_mb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}})); + auto is_zero = mm->add_instruction(migraphx::make_op("equal"), norms, zero_mb); + auto norms_zeros_to_one = + mm->add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms); + mm->add_instruction(migraphx::make_op("div"), x, norms_zeros_to_one); + + auto prog = optimize_onnx("lpnormalization_default_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(lpnormalization_axis_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_axis_error_test.onnx"); })); +} + +TEST_CASE(lpnormalization_p_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_p_error_test.onnx"); })); +} + +TEST_CASE(lppool_l1_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 3, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {0, 0}}, + {"stride", {1}}, + {"lengths", {3}}, + {"lp_order", 1}}), + l0); + auto prog = optimize_onnx("lppool_l1_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(lppool_l2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", {migraphx::shape::float_type, {1, 3, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {0, 0}}, + {"stride", {1}}, + {"lengths", {3}}, + {"lp_order", 2}}), + l0); + auto prog = optimize_onnx("lppool_l2_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(lrn_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}}); + migraphx::op::lrn op; + op.size = 5; + op.alpha = 0.0001; + op.beta = 0.75; + op.bias = 1.0; + mm->add_instruction(op, l0); + auto prog = optimize_onnx("lrn_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_bmbm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}}); + auto bl0 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1); + migraphx::add_apply_alpha_beta(*mm, {bl0, bl1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto prog = optimize_onnx("matmul_bmbm_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_bmv_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); + auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); + auto bsl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1); + auto res = + migraphx::add_apply_alpha_beta(*mm, {l0, bsl1}, migraphx::make_op("dot"), 1.0f, 0.0f); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res); + + auto prog = optimize_onnx("matmul_bmv_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_mv_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); + auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); + auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); + + auto prog = optimize_onnx("matmul_mv_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_vbm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); + auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); + auto bsl0 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0); + auto res = + migraphx::add_apply_alpha_beta(*mm, {bsl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); + + auto prog = optimize_onnx("matmul_vbm_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_vm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}}); + auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); + auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); + + auto prog = optimize_onnx("matmul_vm_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmul_vv_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); + auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); + auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); + auto res = + migraphx::add_apply_alpha_beta(*mm, {sl0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0); + + auto prog = optimize_onnx("matmul_vv_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(matmulinteger_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}}); + mm->add_instruction(migraphx::make_op("quant_dot"), l0, l1); + + auto prog = optimize_onnx("matmulinteger_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(max_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); + auto l0 = mm->add_instruction(migraphx::make_op("max"), input0, input1); + mm->add_instruction(migraphx::make_op("max"), l0, input2); + + optimize_onnx("max_test.onnx"); +} + +TEST_CASE(maxpool_notset_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 1, 1}}, + {"stride", {2, 2}}, + {"lengths", {6, 6}}}), + input); + + auto prog = optimize_onnx("maxpool_notset_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(maxpool_same_upper_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 1, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}}), + input); + + auto prog = optimize_onnx("maxpool_same_upper_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(mean_invalid_broadcast_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("mean_invalid_broadcast_test.onnx"); })); +} + +TEST_CASE(mean_single_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}); + mm->add_return({data0}); + + auto prog = migraphx::parse_onnx("mean_single_input_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(mean_test) +{ + const std::size_t num_data = 3; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {1, 2, 3}}; + auto data0 = mm->add_parameter("0", s); + auto data1 = mm->add_parameter("1", s); + auto data2 = mm->add_parameter("2", s); + auto div_lit = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {num_data}}); + auto divisor = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit); + auto mean = mm->add_instruction(migraphx::make_op("div"), data0, divisor); + divisor = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit); + data1 = mm->add_instruction(migraphx::make_op("div"), data1, divisor); + mean = mm->add_instruction(migraphx::make_op("add"), mean, data1); + divisor = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit); + data2 = mm->add_instruction(migraphx::make_op("div"), data2, divisor); + mean = mm->add_instruction(migraphx::make_op("add"), mean, data2); + + auto prog = optimize_onnx("mean_fp16_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(mean_integral_test) +{ + const std::size_t num_data = 10; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}}; + + auto mean = mm->add_parameter("0", s); + for(std::size_t i = 1; i < num_data; ++i) + { + auto data = mm->add_parameter(std::to_string(i), s); + mean = mm->add_instruction(migraphx::make_op("add"), mean, data); + } + + auto div_lit = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {num_data}}); + auto divisor = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit); + mean = mm->add_instruction(migraphx::make_op("div"), mean, divisor); + + auto prog = optimize_onnx("mean_integral_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(min_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); + auto l0 = mm->add_instruction(migraphx::make_op("min"), input0, input1); + mm->add_instruction(migraphx::make_op("min"), l0, input2); + + optimize_onnx("min_test.onnx"); +} + +TEST_CASE(multinomial_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + size_t sample_size = 10; + float seed = 0.0f; + + auto input = mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {1, 10}}); + auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); + auto mb_maxes = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 10}}}), maxes); + auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes); + cdf = mm->add_instruction(migraphx::make_op("exp"), cdf); + cdf = mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); + + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector rand_samples(sample_size); + std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); + migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}}; + auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); + + mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit); + + auto prog = optimize_onnx("multinomial_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(multinomial_dtype_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); })); +} + +TEST_CASE(multinomial_generated_seed_test) +{ + auto p1 = optimize_onnx("multinomial_generated_seed_test.onnx"); + auto p2 = optimize_onnx("multinomial_generated_seed_test.onnx"); + + EXPECT(p1 != p2); +} + +TEST_CASE(multinomial_int64_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + size_t sample_size = 10; + float seed = 1.0f; + migraphx::shape::type_t dtype = migraphx::shape::type_t::int64_type; + + auto input = mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {1, 10}}); + auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); + auto mb_maxes = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 10}}}), maxes); + auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes); + cdf = mm->add_instruction(migraphx::make_op("exp"), cdf); + cdf = mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); + + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector rand_samples(sample_size); + std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); + migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}}; + auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); + + mm->add_instruction(migraphx::make_op("multinomial", {{"dtype", dtype}}), cdf, rs_lit); + + auto prog = optimize_onnx("multinomial_int64_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(no_pad_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_onnx("no_pad_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(neg_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int64_type, {2, 3}}; + auto input = mm->add_parameter("0", s); + auto ret = mm->add_instruction(migraphx::make_op("neg"), input); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("neg_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(nms_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sb{migraphx::shape::float_type, {1, 6, 4}}; + auto b = mm->add_parameter("boxes", sb); + + migraphx::shape ss{migraphx::shape::float_type, {1, 1, 6}}; + auto s = mm->add_parameter("scores", ss); + + migraphx::shape smo{migraphx::shape::int64_type, {1}}; + auto mo = mm->add_parameter("max_output_boxes_per_class", smo); + + migraphx::shape siou{migraphx::shape::float_type, {1}}; + auto iou = mm->add_parameter("iou_threshold", siou); + + migraphx::shape sst{migraphx::shape::float_type, {1}}; + auto st = mm->add_parameter("score_threshold", sst); + + auto ret = mm->add_instruction( + migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("nms_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(nonzero_dynamic_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {2, 2}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction(migraphx::make_op("nonzero"), data); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("nonzero_dynamic_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(nonzero_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector data = {1, 0, 1, 1}; + mm->add_literal(migraphx::literal(s, data)); + + migraphx::shape si{migraphx::shape::int64_type, {2, 3}}; + std::vector indices = {0, 1, 1, 0, 0, 1}; + auto r = mm->add_literal(migraphx::literal(si, indices)); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("nonzero_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(nonzero_int_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int16_type, {2, 3}}; + std::vector data = {1, 1, 0, 1, 0, 1}; + mm->add_literal(migraphx::literal(s, data.begin(), data.end())); + + migraphx::shape si{migraphx::shape::int64_type, {2, 4}}; + std::vector indices = {0, 0, 1, 1, 0, 1, 0, 2}; + auto r = mm->add_literal(migraphx::literal(si, indices)); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("nonzero_int_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(not_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int32_type, {4}}); + auto ret = mm->add_instruction(migraphx::make_op("not"), l0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("not_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(not_bool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {4}}); + auto ret = mm->add_instruction(migraphx::make_op("not"), l0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("not_bool_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(onehot_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_ind{migraphx::shape::int32_type, {5, 2}}; + migraphx::shape s_val{migraphx::shape::half_type, {2}}; + mm->add_literal(3); + auto l_ind = mm->add_parameter("indices", s_ind); + auto l_val = mm->add_parameter("values", s_val); + migraphx::shape s_dep{migraphx::shape::half_type, {3, 3}}; + std::vector data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1}; + auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep)); + auto gather_out = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), l_dep, l_ind); + auto tr_out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), + gather_out); + auto off_val = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), l_val); + auto on_val = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l_val); + auto diff = mm->add_instruction(migraphx::make_op("sub"), on_val, off_val); + auto mb_off_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), off_val); + auto mb_diff = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), diff); + auto mul = mm->add_instruction(migraphx::make_op("mul"), tr_out, mb_diff); + auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("onehot_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), l0); + auto prog = optimize_onnx("pad_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_3arg_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 1, 2, 2}}); + auto r = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {1, 1, 2, 2}}, {"value", 1.0f}}), l0); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_3arg_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_reflect_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_reflect_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_reflect_multiaxis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 2, 0}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, 3}}}), l0); + auto l3 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0); + auto l4 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {1, 5}}}), l3); + auto l5 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {2, 5}}}), l3); + auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l3, l4, l5); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_reflect_multiaxis_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pow_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + mm->add_instruction(migraphx::make_op("pow"), l0, l1); + + auto prog = optimize_onnx("pow_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pow_fp32_i64_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}}); + auto l1f = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l1); + auto ret = mm->add_instruction(migraphx::make_op("pow"), l0, l1f); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("pow_fp32_i64_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pow_i64_fp32_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l0f = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0); + auto fr = mm->add_instruction(migraphx::make_op("pow"), l0f, l1); + auto ir = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), fr); + mm->add_return({ir}); + + auto prog = migraphx::parse_onnx("pow_i64_fp32_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(prefix_scan_sum) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}, {1}}, {0}}); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto ret = mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}), + l0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("prefix_scan_sum_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(prelu_brcst_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}}); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1); + auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(quantizelinear_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); + auto l1_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); + auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); + auto round = mm->add_instruction(migraphx::make_op("round"), div); + auto s = round->get_shape(); + std::vector min_data(s.elements(), 0); + std::vector max_data(s.elements(), 255); + auto min_arg = mm->add_literal(s, min_data); + auto max_arg = mm->add_literal(s, max_data); + auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), + clip); + + auto prog = optimize_onnx("quantizelinear_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} + +TEST_CASE(quantizelinear_int32_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); + auto l1_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); + l0 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l0); + auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); + auto round = mm->add_instruction(migraphx::make_op("round"), div); + auto s = round->get_shape(); + std::vector min_data(s.elements(), 0); + std::vector max_data(s.elements(), 255); + auto min_arg = mm->add_literal(s, min_data); + auto max_arg = mm->add_literal(s, max_data); + auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_arg, max_arg); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), + clip); + + auto prog = optimize_onnx("quantizelinear_int32_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} + +TEST_CASE(quantizelinear_zero_point_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}}); + auto l1_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); + auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); + auto round = mm->add_instruction(migraphx::make_op("round"), div); + auto l2_mbcast = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2); + l2_mbcast = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l2_mbcast); + auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast); + auto s = round->get_shape(); + std::vector min_data(s.elements(), -128); + std::vector max_data(s.elements(), 127); + auto min_arg = mm->add_literal(s, min_data); + auto max_arg = mm->add_literal(s, max_data); + auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), + clip); + + auto prog = optimize_onnx("quantizelinear_zero_point_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} + +migraphx::program make_quantizelinear_axis_prog() +{ + migraphx::program p; + std::vector input_lens{1, 1, 5, 1}; + int axis = 2; + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, input_lens}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}}); + auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}}); + auto l1_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1); + + auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast); + auto round = mm->add_instruction(migraphx::make_op("round"), div); + auto l2_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2); + l2_bcast = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + l2_bcast); + auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast); + auto s = round->get_shape(); + std::vector min_data(s.elements(), -128); + std::vector max_data(s.elements(), 127); + auto min_arg = mm->add_literal(s, min_data); + auto max_arg = mm->add_literal(s, max_data); + auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_arg, max_arg); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), + clip); + return p; +} + +TEST_CASE(quantizelinear_axis_test) +{ + migraphx::program p = make_quantizelinear_axis_prog(); + + auto prog = optimize_onnx("quantizelinear_axis_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} + +TEST_CASE(quantizelinear_neg_axis_test) +{ + migraphx::program p = make_quantizelinear_axis_prog(); + + auto prog = optimize_onnx("quantizelinear_neg_axis_test.onnx", true); + EXPECT(p.sort() == prog.sort()); +} + +TEST_CASE(randomnormal_test) +{ + float mean = 10.0; + float scale = 1.5; + float seed = 0.0; + std::vector shape_attr{2, 3, 4}; + + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, shape_attr}; + std::vector rand_vals(s.elements()); + std::mt19937 gen(seed); + std::normal_distribution<> d(mean, scale); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + mm->add_literal(migraphx::literal{s, rand_vals}); + + auto prog = optimize_onnx("randomnormal_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(randomnormal_dtype_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_dtype_error_test.onnx"); })); +} + +TEST_CASE(randomnormal_generated_seed_test) +{ + auto p1 = optimize_onnx("randomnormal_generated_seed_test.onnx"); + auto p2 = optimize_onnx("randomnormal_generated_seed_test.onnx"); + + EXPECT(p1 != p2); +} + +TEST_CASE(randomnormal_shape_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_shape_error_test.onnx"); })); +} + +TEST_CASE(randomnormallike_test) +{ + float mean = 10.0; + float scale = 1.5; + float seed = 0.0; + std::vector shape_attr{2, 3, 4}; + + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::half_type, shape_attr}; + std::vector rand_vals(s.elements()); + std::mt19937 gen(seed); + std::normal_distribution<> d(mean, scale); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + mm->add_parameter("input", s); + mm->add_literal(migraphx::literal{s, rand_vals}); + + auto prog = optimize_onnx("randomnormallike_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(randomnormallike_type_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormallike_type_error_test.onnx"); })); +} + +TEST_CASE(randomuniform_test) +{ + float high = 1.0; + float low = 0.0; + float seed = 0.0; + std::vector shape_attr{2, 3, 4}; + + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, shape_attr}; + std::vector rand_vals(s.elements()); + std::mt19937 gen(seed); + std::uniform_real_distribution<> d(low, high); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + mm->add_literal(migraphx::literal{s, rand_vals}); + + auto prog = optimize_onnx("randomuniform_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(randomuniform_dtype_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_dtype_error_test.onnx"); })); +} + +TEST_CASE(randomuniform_generated_seed_test) +{ + auto p1 = optimize_onnx("randomuniform_generated_seed_test.onnx"); + auto p2 = optimize_onnx("randomuniform_generated_seed_test.onnx"); + + EXPECT(p1 != p2); +} + +TEST_CASE(randomuniform_shape_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_shape_error_test.onnx"); })); +} + +TEST_CASE(randomuniformlike_test) +{ + float high = 10.0; + float low = 1.0; + float seed = 0.0; + std::vector shape_attr{2, 3, 4}; + + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::half_type, shape_attr}; + std::vector rand_vals(s.elements()); + std::mt19937 gen(seed); + std::uniform_real_distribution<> d(low, high); + std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); + + mm->add_parameter("input", s); + mm->add_literal(migraphx::literal{s, rand_vals}); + + auto prog = optimize_onnx("randomuniformlike_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(randomuniformlike_type_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniformlike_type_error_test.onnx"); })); +} + +TEST_CASE(range_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(int64_t{10}); + mm->add_literal(int64_t{6}); + mm->add_literal(int64_t{-3}); + mm->add_literal(migraphx::literal{{migraphx::shape::int64_type, {2}}, {10, 7}}); + + auto prog = optimize_onnx("range_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(range_float_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(float{2}); + mm->add_literal(float{11}); + mm->add_literal(float{2}); + mm->add_literal(migraphx::literal{{migraphx::shape::float_type, {5}}, {2, 4, 6, 8, 10}}); + + auto prog = optimize_onnx("range_float_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(recip_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); + mm->add_instruction(migraphx::make_op("recip"), input); + + auto prog = optimize_onnx("recip_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducel1_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto abs_l0 = mm->add_instruction(migraphx::make_op("abs"), l0); + auto sum_l0 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {-2}}}), sum_l0); + auto prog = optimize_onnx("reducel1_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducel2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto square_l0 = mm->add_instruction(migraphx::make_op("mul"), l0, l0); + auto sum_l0 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-1}}}), square_l0); + auto squ_l0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {-1}}}), sum_l0); + mm->add_instruction(migraphx::make_op("sqrt"), squ_l0); + auto prog = optimize_onnx("reducel2_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reduce_log_sum_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto sum_l0 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-3}}}), l0); + mm->add_instruction(migraphx::make_op("log"), sum_l0); + auto prog = optimize_onnx("reduce_log_sum_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reduce_log_sum_exp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto exp_l0 = mm->add_instruction(migraphx::make_op("exp"), l0); + auto sum_l0 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-4}}}), exp_l0); + mm->add_instruction(migraphx::make_op("log"), sum_l0); + auto prog = optimize_onnx("reduce_log_sum_exp_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducemax_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0); + auto prog = optimize_onnx("reducemax_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducemean_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l1); + auto prog = optimize_onnx("reducemean_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducemean_keepdims_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), l0); + auto prog = optimize_onnx("reducemean_keepdims_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducemin_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {2, 3}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l1); + auto prog = optimize_onnx("reducemin_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reduceprod_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_instruction(migraphx::make_op("reduce_prod", {{"axes", {2}}}), l0); + auto prog = optimize_onnx("reduceprod_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), l1); + auto prog = optimize_onnx("reducesum_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_empty_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal({}); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), x); + auto r = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1, 2, 3}}}), l1); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("reducesum_empty_axes_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_noop_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal({}); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_return({x}); + auto prog = migraphx::parse_onnx("reducesum_noop_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_multiaxis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l1); + auto prog = optimize_onnx("reducesum_multiaxis_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_keepdims_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), l0); + auto prog = optimize_onnx("reducesum_keepdims_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reducesum_square_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto squ_l0 = mm->add_instruction(migraphx::make_op("mul"), l0, l0); + auto sum_l0 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), squ_l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {-2}}}), sum_l0); + auto prog = optimize_onnx("reducesum_square_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reshape_test) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::op::reshape op; std::vector reshape_dims{3, 8}; - p.add_literal( + mm->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims}); - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); op.dims = reshape_dims; - p.add_instruction(op, l0); - p.add_instruction(op, l0); - auto prog = migraphx::parse_onnx("reshape_test.onnx"); + mm->add_instruction(op, l0); + mm->add_instruction(op, l0); + auto prog = optimize_onnx("reshape_test.onnx"); EXPECT(p == prog); } @@ -885,36 +3829,758 @@ TEST_CASE(reshape_test) TEST_CASE(reshape_non_standard_test) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::op::reshape op; - std::vector reshape_dims{4, 3, 2}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto tran_x = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x); - auto cont_x = p.add_instruction(migraphx::op::contiguous{}, tran_x); - p.add_instruction(migraphx::op::reshape{{4, 3, 2}}, cont_x); - auto prog = migraphx::parse_onnx("reshape_non_standard_test.onnx"); + auto x = mm->add_parameter("x", s); + auto tran_x = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x); + auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x); + auto prog = optimize_onnx("reshape_non_standard_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(resize_downsample_c_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector ds = {1.0f, 1.0f, 0.6f, 0.6f}; + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal{ss, ds}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; + auto inx = mm->add_parameter("X", sx); + + mm->add_instruction(migraphx::make_op("undefined")); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}}; + std::vector ind = {0, 2}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_downsample_c_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(resize_downsample_f_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector ds = {1.0f, 1.0f, 0.6f, 0.6f}; + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal{ss, ds}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; + auto inx = mm->add_parameter("X", sx); + + mm->add_instruction(migraphx::make_op("undefined")); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}}; + std::vector ind = {0, 3}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_downsample_f_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(resize_downsample_linear_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ss{migraphx::shape::float_type, {4}}; + std::vector ds = {1, 1, 0.6, 0.5}; + mm->add_literal(migraphx::literal(ss, ds)); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; + auto x = mm->add_parameter("X", sx); + migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 1, 2}}; + std::vector d_ind = {0, 2, 0, 2, 0, 2, 0, 2, 4, 6, 4, 6, 4, 6, 4, 6, + 1, 3, 1, 3, 1, 3, 1, 3, 5, 7, 5, 7, 5, 7, 5, 7}; + auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); + + migraphx::shape s8{migraphx::shape::float_type, {8, 1, 1, 2}}; + std::vector d8(16, 0.5f); + auto l8 = mm->add_literal(migraphx::literal(s8, d8)); + + migraphx::shape s4{migraphx::shape::float_type, {4, 1, 1, 2}}; + std::vector d4(8, 1.0f / 3.0f); + auto l4 = mm->add_literal(migraphx::literal(s4, d4)); + + migraphx::shape s2{migraphx::shape::float_type, {2, 1, 1, 2}}; + std::vector d2(4, 0); + auto l2 = mm->add_literal(migraphx::literal(s2, d2)); + + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 2}}; + std::vector d1(2, 0.0f); + auto l1 = mm->add_literal(migraphx::literal(s1, d1)); + + mm->add_instruction(migraphx::make_op("undefined")); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); + auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); + auto slc80 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); + auto slc81 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); + auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); + auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); + auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80); + auto slc40 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); + auto slc41 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); + auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); + auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); + auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40); + auto slc20 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); + auto slc21 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); + auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); + auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); + auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); + auto slc10 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2); + auto slc11 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2); + auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10); + auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1); + auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); + mm->add_return({add1}); + + auto prog = migraphx::parse_onnx("resize_downsample_linear_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(resize_outsize_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector out_len = {1, 1, 4, 6}; + migraphx::shape so{migraphx::shape::int64_type, {4}}; + mm->add_literal(migraphx::literal(so, out_len)); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto inx = mm->add_parameter("X", sx); + + mm->add_instruction(migraphx::make_op("undefined")); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_outsize_test.onnx"); EXPECT(p == prog); } -TEST_CASE(round_test) +TEST_CASE(resize_nonstd_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector ds = {1.0f, 1.0f, 0.6f, 0.6f}; + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal{ss, ds}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 4, 2}}; + auto inx = mm->add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 1, 2}}; + std::vector ind = {0, 4}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx); + mm->add_instruction(migraphx::make_op("undefined")); + auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx_cont); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_nonstd_input_test.onnx"); + + EXPECT(p == prog); +} + +static auto create_upsample_linear_prog() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ss{migraphx::shape::float_type, {4}}; + std::vector ds = {1, 1, 2, 2}; + mm->add_literal(migraphx::literal(ss, ds)); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto x = mm->add_parameter("X", sx); + migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}}; + std::vector d_ind = { + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, + 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, + 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, + 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, + 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, + 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, + 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, + 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, + 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3}; + auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); + + migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}}; + std::vector d8 = { + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0}; + auto l8 = mm->add_literal(migraphx::literal(s8, d8)); + + migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}}; + std::vector d4 = { + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0}; + auto l4 = mm->add_literal(migraphx::literal(s4, d4)); + + migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}}; + std::vector d2(32, 0); + auto l2 = mm->add_literal(migraphx::literal(s2, d2)); + + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}}; + std::vector d1(16, 0.0f); + auto l1 = mm->add_literal(migraphx::literal(s1, d1)); + + mm->add_instruction(migraphx::make_op("undefined")); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); + auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); + auto slc80 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); + auto slc81 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); + auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); + auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); + auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80); + auto slc40 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); + auto slc41 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); + auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); + auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); + auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40); + auto slc20 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); + auto slc21 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); + auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); + auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); + auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); + auto slc10 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2); + auto slc11 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2); + auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10); + auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1); + auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); + mm->add_return({add1}); + + return p; +} + +TEST_CASE(resize_upsample_linear_ac_test) +{ + auto p = create_upsample_linear_prog(); + auto prog = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(resize_upsample_linear_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ss{migraphx::shape::float_type, {4}}; + std::vector ds = {1, 1, 2, 2}; + mm->add_literal(migraphx::literal(ss, ds)); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto x = mm->add_parameter("X", sx); + migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}}; + std::vector d_ind = { + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, + 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, + 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, + 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, + 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, + 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, + 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, + 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, + 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3}; + auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); + + migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}}; + std::vector d8 = { + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, + 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0}; + auto l8 = mm->add_literal(migraphx::literal(s8, d8)); + + migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}}; + std::vector d4 = { + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, + 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0}; + auto l4 = mm->add_literal(migraphx::literal(s4, d4)); + + migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}}; + std::vector d2(32, 0); + auto l2 = mm->add_literal(migraphx::literal(s2, d2)); + + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}}; + std::vector d1(16, 0.0f); + auto l1 = mm->add_literal(migraphx::literal(s1, d1)); + + mm->add_instruction(migraphx::make_op("undefined")); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); + auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); + auto slc80 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); + auto slc81 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); + auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); + auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); + auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80); + auto slc40 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); + auto slc41 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); + auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); + auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); + auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40); + auto slc20 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); + auto slc21 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); + auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); + auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); + auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); + auto slc10 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2); + auto slc11 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2); + auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10); + auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1); + auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); + mm->add_return({add1}); + + auto prog = migraphx::parse_onnx("resize_upsample_linear_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(resize_upsample_pc_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector ds = {1.0f, 1.0f, 2.0f, 1.5f}; + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal{ss, ds}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; + auto inx = mm->add_parameter("X", sx); + + mm->add_instruction(migraphx::make_op("undefined")); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 1, 1, 2, 3, 3, 0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 4, 5, 5, 6, 7, 7}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), inx); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_upsample_pc_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(resize_upsample_pf_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector ds = {1.0f, 1.0f, 2.0f, 3.0f}; + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal{ss, ds}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto inx = mm->add_parameter("X", sx); + + mm->add_instruction(migraphx::make_op("undefined")); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = mm->add_literal(migraphx::literal(si, ind)); + + auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("resize_upsample_pf_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(reversesequence_batch_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + int batch_axis = 0; + int time_axis = 1; + + migraphx::shape sx{migraphx::shape::float_type, {4, 4}}; + auto input = mm->add_parameter("x", sx); + + std::vector sequence_lens = {1, 2, 3, 4}; + mm->add_literal({{migraphx::shape::int64_type, {4}}, sequence_lens}); + + int batch_size = sx.lens()[batch_axis]; + int time_size = sx.lens()[time_axis]; + + auto add_slice = + [&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) { + return mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b_start, t_start}}, + {"ends", {b_end, t_end}}}), + input); + }; + auto ret = add_slice(0, 1, 0, time_size); + for(int b = 1; b < batch_size; ++b) + { + auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]); + s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); + if(sequence_lens[b] < time_size) + { + auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size); + s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); + } + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(reversesequence_batch_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_batch_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_rank_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_rank_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_sequence_lens_shape_err_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_same_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_same_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_time_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_time_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_time_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + int batch_axis = 1; + int time_axis = 0; + + migraphx::shape sx{migraphx::shape::float_type, {4, 4}}; + auto input = mm->add_parameter("x", sx); + + int batch_size = sx.lens()[batch_axis]; + int time_size = sx.lens()[time_axis]; + std::vector sequence_lens = {4, 3, 2, 1}; + + auto add_slice = + [&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) { + return mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b_start, t_start}}, + {"ends", {b_end, t_end}}}), + input); + }; + + migraphx::instruction_ref ret; + for(int b = 0; b < batch_size - 1; ++b) + { + auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]); + s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); + if(sequence_lens[b] < time_size) + { + auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size); + s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); + } + if(b == 0) + { + ret = s0; + } + else + { + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + } + auto s0 = add_slice(batch_size - 1, batch_size, 0, time_size); + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("reversesequence_time_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(roialign_default_test) +{ + migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}}; + migraphx::shape srois{migraphx::shape::float_type, {8, 4}}; + migraphx::shape sbi{migraphx::shape::int64_type, {8}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", sx); + auto rois = mm->add_parameter("rois", srois); + auto bi = mm->add_parameter("batch_ind", sbi); + + auto r = mm->add_instruction(migraphx::make_op("roialign"), x, rois, bi); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("roialign_default_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(roialign_test) +{ + migraphx::shape sx{migraphx::shape::float_type, {10, 5, 4, 7}}; + migraphx::shape srois{migraphx::shape::float_type, {8, 4}}; + migraphx::shape sbi{migraphx::shape::int64_type, {8}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", sx); + auto rois = mm->add_parameter("rois", srois); + auto bi = mm->add_parameter("batch_ind", sbi); + + auto r = mm->add_instruction( + migraphx::make_op("roialign", + {{"coordinate_transformation_mode", "output_half_pixel"}, + {"spatial_scale", 2.0f}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", 3}}), + x, + rois, + bi); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("roialign_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(round_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); + mm->add_instruction(migraphx::make_op("round"), input); + + auto prog = optimize_onnx("round_test.onnx"); + EXPECT(p == prog); +} + +// the ScatterElements op has 3 reduction modes, which map to separate reference ops +migraphx::program create_scatter_program(const std::string& scatter_mode, int axis) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto l1 = + mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}}); + auto l2 = + mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto r = mm->add_instruction(migraphx::make_op(scatter_mode, {{"axis", axis}}), l0, l1, l2); + mm->add_return({r}); + return p; +} + +TEST_CASE(scatter_add_test) +{ + migraphx::program p = create_scatter_program("scatter_add", -2); + auto prog = migraphx::parse_onnx("scatter_add_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(scatter_mul_test) +{ + migraphx::program p = create_scatter_program("scatter_mul", -2); + auto prog = migraphx::parse_onnx("scatter_mul_test.onnx"); + + EXPECT(p == prog); +} +TEST_CASE(scatter_none_test) +{ + migraphx::program p = create_scatter_program("scatter_none", -2); + auto prog = migraphx::parse_onnx("scatter_none_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(scatternd_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto l1 = + mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); + auto l2 = + mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); + auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2); + mm->add_return({r}); + auto prog = migraphx::parse_onnx("scatternd_test.onnx"); + + EXPECT(p == prog); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto l1 = + mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); + auto l2 = + mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); + auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2); + mm->add_return({r}); + auto prog = migraphx::parse_onnx("scatternd_add_test.onnx"); + + EXPECT(p == prog); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto l1 = + mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); + auto l2 = + mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); + auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2); + mm->add_return({r}); + auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx"); + + EXPECT(p == prog); + } +} + +TEST_CASE(selu_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); - p.add_instruction(migraphx::op::round{}, input); + auto* mm = p.get_main_module(); + std::vector lens = {2, 3}; + migraphx::shape s{migraphx::shape::double_type, lens}; + auto x = mm->add_parameter("x", s); + + migraphx::shape ls{migraphx::shape::double_type, {1}}; + auto la = mm->add_literal({ls, {0.3}}); + auto lg = mm->add_literal({ls, {0.25}}); + auto mbla = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), la); + auto mblg = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lg); + + auto sign_x = mm->add_instruction(migraphx::make_op("sign"), x); + auto exp_x = mm->add_instruction(migraphx::make_op("exp"), x); + + auto mlax = mm->add_instruction(migraphx::make_op("mul"), mbla, exp_x); + auto smlax = mm->add_instruction(migraphx::make_op("sub"), mlax, mbla); + + auto item1 = mm->add_instruction(migraphx::make_op("add"), smlax, x); + auto item2 = mm->add_instruction(migraphx::make_op("sub"), smlax, x); + + auto sitem2 = mm->add_instruction(migraphx::make_op("mul"), sign_x, item2); + auto item12 = mm->add_instruction(migraphx::make_op("sub"), item1, sitem2); + auto r = mm->add_instruction(migraphx::make_op("mul"), item12, mblg); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("selu_test.onnx"); - auto prog = migraphx::parse_onnx("round_test.onnx"); EXPECT(p == prog); } TEST_CASE(shape_test) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; - auto l0 = p.add_parameter("x", s); + auto l0 = mm->add_parameter("x", s); migraphx::shape s_shape{migraphx::shape::int64_type, {4}}; - p.add_literal(s_shape, l0->get_shape().lens()); - auto prog = migraphx::parse_onnx("shape_test.onnx"); + mm->add_literal(s_shape, l0->get_shape().lens()); + auto prog = optimize_onnx("shape_test.onnx"); EXPECT(p == prog); } @@ -922,14 +4588,15 @@ TEST_CASE(shape_test) TEST_CASE(shape_gather_test) { migraphx::program p; - auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}}); - auto l1 = - p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {7, 3, 10}}); migraphx::shape const_shape{migraphx::shape::int32_type, {1}}; - auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); + auto l2 = mm->add_literal(migraphx::literal{const_shape, {1}}); + auto l1 = + mm->add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, l1, l2); - auto prog = migraphx::parse_onnx("shape_gather_test.onnx"); + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l2); + auto prog = optimize_onnx("shape_gather_test.onnx"); EXPECT(p == prog); } @@ -937,40 +4604,173 @@ TEST_CASE(shape_gather_test) TEST_CASE(sign_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); - p.add_instruction(migraphx::op::sign{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); + mm->add_instruction(migraphx::make_op("sign"), input); - auto prog = migraphx::parse_onnx("sign_test.onnx"); + auto prog = optimize_onnx("sign_test.onnx"); EXPECT(p == prog); } TEST_CASE(sin_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::sin{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("sin"), input); - auto prog = migraphx::parse_onnx("sin_test.onnx"); + auto prog = optimize_onnx("sin_test.onnx"); EXPECT(p == prog); } TEST_CASE(sinh_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::sinh{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("sinh"), input); + + auto prog = optimize_onnx("sinh_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(size_float_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + mm->add_parameter("x", s); + mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {s.elements()}}); + + auto prog = optimize_onnx("size_float_test.onnx"); + EXPECT(p == prog); +} - auto prog = migraphx::parse_onnx("sinh_test.onnx"); +TEST_CASE(size_half_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::half_type, {3, 1}}; + mm->add_parameter("x", s); + mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {s.elements()}}); + auto prog = optimize_onnx("size_half_test.onnx"); + EXPECT(p == prog); +} +TEST_CASE(size_int_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 2, 3}}; + mm->add_parameter("x", s); + mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {s.elements()}}); + auto prog = optimize_onnx("size_int_test.onnx"); EXPECT(p == prog); } TEST_CASE(slice_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}}); - p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 2}}, l0); - auto prog = migraphx::parse_onnx("slice_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}}); + mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {2, 2}}}), l0); + auto prog = optimize_onnx("slice_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_3arg_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {0, 0}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {2, 5}}); + auto ret = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 5}}}), l0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("slice_3arg_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_5arg_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {1, 1}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}}); + auto ret = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {-1, -2}}, {"starts", {-5, -3}}, {"ends", {-1, -1}}}), + l0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("slice_5arg_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_5arg_reverse_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, 1}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}}); + auto slice_out = mm->add_instruction( + migraphx::make_op("slice", + {{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}), + l0); + auto ret = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("slice_5arg_reverse_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_5arg_step_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-2, 2}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}}); + mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}}); + auto slice_out = mm->add_instruction( + migraphx::make_op("slice", + {{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}), + l0); + auto reverse_out = + mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out); + auto step_out = mm->add_instruction( + migraphx::make_op("step", {{"axes", {-1, -2}}, {"steps", {2, 2}}}), reverse_out); + mm->add_return({step_out}); + + auto prog = migraphx::parse_onnx("slice_5arg_step_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_max_end_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {10, 20}}); + mm->add_instruction( + migraphx::make_op("slice", + {{"axes", {0, 1}}, {"starts", {1, 2}}, {"ends", {3000000000, -1}}}), + l0); + auto prog = optimize_onnx("slice_max_end_test.onnx"); EXPECT(p == prog); } @@ -978,33 +4778,209 @@ TEST_CASE(slice_test) TEST_CASE(softmax_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); - p.add_instruction(migraphx::op::softmax{1}, l0); - auto prog = migraphx::parse_onnx("softmax_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0); + auto prog = optimize_onnx("softmax_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(softmax_nonstd_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {6, 8}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {4, 4}}}), l0); + auto l2 = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l1); + mm->add_return({l2}); + + auto prog = migraphx::parse_onnx("softmax_nonstd_input_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(softplus_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector input_lens{5}; + auto input_type = migraphx::shape::float_type; + + auto x = mm->add_parameter("x", migraphx::shape{input_type, input_lens}); + auto mb_ones = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + auto exp = mm->add_instruction(migraphx::make_op("exp"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), exp, mb_ones); + mm->add_instruction(migraphx::make_op("log"), add); + + auto prog = optimize_onnx("softplus_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(softplus_nd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector input_lens{3, 4, 5}; + auto input_type = migraphx::shape::half_type; + + auto x = mm->add_parameter("x", migraphx::shape{input_type, input_lens}); + auto mb_ones = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + auto exp = mm->add_instruction(migraphx::make_op("exp"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), exp, mb_ones); + mm->add_instruction(migraphx::make_op("log"), add); + + auto prog = optimize_onnx("softplus_nd_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(softsign_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector input_lens{5}; + auto input_type = migraphx::shape::float_type; + + auto x = mm->add_parameter("x", migraphx::shape{input_type, input_lens}); + auto mb_ones = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + auto abs = mm->add_instruction(migraphx::make_op("abs"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), abs, mb_ones); + mm->add_instruction(migraphx::make_op("div"), x, add); + + auto prog = optimize_onnx("softsign_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(softsign_nd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector input_lens{3, 4, 5}; + auto input_type = migraphx::shape::half_type; + + auto x = mm->add_parameter("x", migraphx::shape{input_type, input_lens}); + auto mb_ones = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}})); + auto abs = mm->add_instruction(migraphx::make_op("abs"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), abs, mb_ones); + mm->add_instruction(migraphx::make_op("div"), x, add); + + auto prog = optimize_onnx("softsign_nd_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(split_minus_axis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); + auto r1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {0}}, {"ends", {5}}}), input); + auto r2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {5}}, {"ends", {10}}}), input); + auto r3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {10}}, {"ends", {15}}}), input); + mm->add_return({r1, r2, r3}); + + auto prog = migraphx::parse_onnx("split_minus_axis_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(split_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); + auto r1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {7}}}), input); + auto r2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {7}}, {"ends", {11}}}), input); + auto r3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {11}}, {"ends", {15}}}), input); + mm->add_return({r1, r2, r3}); + + auto prog = migraphx::parse_onnx("split_test.onnx"); + EXPECT(p == prog); +} +TEST_CASE(split_test_default) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); + auto r1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {5}}}), input); + auto r2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {10}}}), input); + mm->add_return({r1, r2}); + + auto prog = migraphx::parse_onnx("split_test_default.onnx"); EXPECT(p == prog); } TEST_CASE(sqrt_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); - p.add_instruction(migraphx::op::sqrt{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}}); + mm->add_instruction(migraphx::make_op("sqrt"), input); - auto prog = migraphx::parse_onnx("sqrt_test.onnx"); + auto prog = optimize_onnx("sqrt_test.onnx"); EXPECT(p == prog); } TEST_CASE(squeeze_unsqueeze_test) { migraphx::program p; + auto* mm = p.get_main_module(); std::vector squeeze_axes{0, 2, 3, 5}; std::vector unsqueeze_axes{0, 1, 3, 5}; auto l0 = - p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}}); - auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0); - p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1); - auto prog = migraphx::parse_onnx("squeeze_unsqueeze_test.onnx"); + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}}); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), l0); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), l1); + auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(squeeze_axes_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3})); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}}); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0); + mm->add_return({l1}); + + auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(squeeze_empty_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal({}); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}}); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0); + mm->add_return({l1}); + + auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx"); EXPECT(p == prog); } @@ -1012,12 +4988,14 @@ TEST_CASE(squeeze_unsqueeze_test) TEST_CASE(sub_bcast_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); - auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1); - p.add_instruction(migraphx::op::sub{}, l0, l2); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("sub"), l0, l2); - auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); + auto prog = optimize_onnx("sub_bcast_test.onnx"); EXPECT(p == prog); } @@ -1025,59 +5003,293 @@ TEST_CASE(sub_bcast_test) TEST_CASE(sub_scalar_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); - auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); - auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); - p.add_instruction(migraphx::op::sub{}, m0, m1); - auto prog = migraphx::parse_onnx("sub_scalar_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}}); + auto m1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1); + mm->add_instruction(migraphx::make_op("sub"), l0, m1); + auto prog = optimize_onnx("sub_scalar_test.onnx"); + + EXPECT(p == prog); +} +TEST_CASE(sum_int_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int16_type, {3}}); + auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint16_type, {3}}); + auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}}); + auto cin0 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}), + input0); + auto cin1 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}), + input1); + auto l0 = mm->add_instruction(migraphx::make_op("add"), cin0, cin1); + mm->add_instruction(migraphx::make_op("add"), l0, input2); + + auto prog = optimize_onnx("sum_int_test.onnx"); EXPECT(p == prog); } TEST_CASE(sum_test) { migraphx::program p; - auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); - auto input2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); - auto l0 = p.add_instruction(migraphx::op::add{}, input0, input1); - p.add_instruction(migraphx::op::add{}, l0, input2); + auto* mm = p.get_main_module(); + auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}}); + auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}}); + auto l0 = mm->add_instruction(migraphx::make_op("add"), input0, input1); + mm->add_instruction(migraphx::make_op("add"), l0, input2); + + auto prog = optimize_onnx("sum_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(sum_type_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l_bool = mm->add_literal({migraphx::shape{migraphx::shape::bool_type, {2}}, {1, 0}}); + auto l_int8 = mm->add_literal({migraphx::shape{migraphx::shape::int8_type, {2}}, {1, 1}}); + auto l_uint8 = mm->add_literal({migraphx::shape{migraphx::shape::uint8_type, {2}}, {1, 1}}); + auto l_uint16 = mm->add_literal({migraphx::shape{migraphx::shape::uint16_type, {2}}, {1, 1}}); + auto l_uint32 = mm->add_literal({migraphx::shape{migraphx::shape::uint32_type, {2}}, {1, 1}}); + auto l_uint64 = mm->add_literal({migraphx::shape{migraphx::shape::uint64_type, {2}}, {1, 1}}); + auto l_double = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1, 1}}); + auto l_raw = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1.5, 2.0}}); + auto o_bool = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_bool); + auto o_int8 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_int8); + auto o_uint8 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_uint8); + auto o_uint16 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_uint16); + auto o_uint32 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_uint32); + auto o_uint64 = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::double_type)}}), + l_uint64); + auto s0 = mm->add_instruction(migraphx::make_op("add"), o_bool, o_int8); + auto s1 = mm->add_instruction(migraphx::make_op("add"), s0, o_uint8); + auto s2 = mm->add_instruction(migraphx::make_op("add"), s1, o_uint16); + auto s3 = mm->add_instruction(migraphx::make_op("add"), s2, o_uint32); + auto s4 = mm->add_instruction(migraphx::make_op("add"), s3, o_uint64); + auto s5 = mm->add_instruction(migraphx::make_op("add"), s4, l_double); + auto s6 = mm->add_instruction(migraphx::make_op("add"), s5, l_raw); + mm->add_return({s6}); + + auto prog = migraphx::parse_onnx("sum_type_test.onnx"); - auto prog = migraphx::parse_onnx("sum_test.onnx"); EXPECT(p == prog); } TEST_CASE(tan_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); - p.add_instruction(migraphx::op::tan{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); + mm->add_instruction(migraphx::make_op("tan"), input); - auto prog = migraphx::parse_onnx("tan_test.onnx"); + auto prog = optimize_onnx("tan_test.onnx"); EXPECT(p == prog); } TEST_CASE(tanh_test) { migraphx::program p; - auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); - p.add_instruction(migraphx::op::tanh{}, input); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); + mm->add_instruction(migraphx::make_op("tanh"), input); + + auto prog = optimize_onnx("tanh_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(thresholdedrelu_default_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}}); + auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}}); + auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {1.0f}}); + auto mbz = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz); + auto mba = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la); + auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba); + mm->add_instruction(migraphx::make_op("where"), condition, x, mbz); + + auto prog = optimize_onnx("thresholdedrelu_default_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(thresholdedrelu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}}); + auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}}); + auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3.0f}}); + auto mbz = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz); + auto mba = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la); + auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba); + mm->add_instruction(migraphx::make_op("where"), condition, x, mbz); + + auto prog = optimize_onnx("thresholdedrelu_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(thresholdedrelu_int_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2, 3}}); + auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}}); + auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3}}); + auto mbz = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz); + auto mba = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la); + auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba); + mm->add_instruction(migraphx::make_op("where"), condition, x, mbz); + + auto prog = optimize_onnx("thresholdedrelu_int_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(tile_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {1, 2}}); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), input, input); + + auto prog = optimize_onnx("tile_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(tile_test_3x2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {3, 2}}); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto l0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), input, input); + auto l1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l0, input); + mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l1); + + auto prog = optimize_onnx("tile_test_3x2.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(transpose_default_perm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 3}}); + std::vector perm{3, 2, 1, 0}; + auto r = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input); + mm->add_return({r}); - auto prog = migraphx::parse_onnx("tanh_test.onnx"); + auto prog = migraphx::parse_onnx("transpose_default_perm_test.onnx"); EXPECT(p == prog); } +TEST_CASE(transpose_invalid_perm_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("transpose_invalid_perm_test.onnx"); })); +} + TEST_CASE(transpose_test) { migraphx::program p; - auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); std::vector perm{0, 3, 1, 2}; - p.add_instruction(migraphx::op::transpose{perm}, input); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input); + + auto prog = optimize_onnx("transpose_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(topk_attrk_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}}; + auto data = mm->add_parameter("data", s); + auto out = mm->add_instruction(migraphx::make_op("topk", {{"k", 2}, {"axis", -1}}), data); + auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out); + auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); + mm->add_return({val, ind}); + + auto prog = migraphx::parse_onnx("topk_attrk_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(topk_neg_axis_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sk{migraphx::shape::int64_type, {1}}; + mm->add_literal(migraphx::literal(sk, {3})); + migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; + auto data = mm->add_parameter("data", s); + auto out = mm->add_instruction( + migraphx::make_op("topk", {{"k", 3}, {"axis", -2}, {"largest", 1}}), data); + auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out); + auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); + mm->add_return({val, ind}); + + auto prog = migraphx::parse_onnx("topk_neg_axis_test.onnx"); + + EXPECT(p == prog); +} - auto prog = migraphx::parse_onnx("transpose_test.onnx"); +TEST_CASE(topk_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sk{migraphx::shape::int64_type, {1}}; + mm->add_literal(migraphx::literal(sk, {4})); + migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}}; + auto data = mm->add_parameter("data", s); + auto out = mm->add_instruction( + migraphx::make_op("topk", {{"k", 4}, {"axis", 1}, {"largest", 0}}), data); + auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out); + auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out); + mm->add_return({val, ind}); + + auto prog = migraphx::parse_onnx("topk_test.onnx"); EXPECT(p == prog); } @@ -1085,25 +5297,44 @@ TEST_CASE(transpose_test) TEST_CASE(transpose_gather_test) { migraphx::program p; - auto make_contiguous = [&p](migraphx::instruction_ref ins) { + auto* mm = p.get_main_module(); + auto make_contiguous = [&mm](migraphx::instruction_ref ins) { if(ins->get_shape().standard()) { return ins; } - return p.add_instruction(migraphx::op::contiguous{}, ins); + return mm->add_instruction(migraphx::make_op("contiguous"), ins); }; - auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}}); + auto data = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}}); auto ind = - p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}}); - auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data); - auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind); - int axis = 1; - p.add_instruction( - migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind)); + mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}}); + auto tr_data = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), data); + auto tr_ind = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), ind); + int axis = 1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), + make_contiguous(tr_data), + make_contiguous(tr_ind)); + + auto prog = optimize_onnx("transpose_gather_test.onnx"); + + EXPECT(p.sort() == prog.sort()); +} - auto prog = migraphx::parse_onnx("transpose_gather_test.onnx"); +TEST_CASE(undefined_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_instruction(migraphx::make_op("undefined")); + auto l2 = mm->add_instruction(migraphx::make_op("identity"), l1); + mm->add_return({l2}); + + auto prog = migraphx::parse_onnx("undefined_test.onnx"); EXPECT(p == prog); } @@ -1111,11 +5342,121 @@ TEST_CASE(transpose_gather_test) TEST_CASE(unknown_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); - auto l2 = p.add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1); - p.add_instruction(migraphx::op::unknown{"Unknown"}, l2); - auto prog = migraphx::parse_onnx("unknown_test.onnx"); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); + auto l2 = mm->add_instruction(migraphx::op::unknown{"Unknown"}, l0, l1); + mm->add_instruction(migraphx::op::unknown{"Unknown"}, l2); + auto prog = optimize_onnx("unknown_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(unknown_aten_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_aten_test.onnx"); })); +} + +TEST_CASE(unknown_test_throw) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); })); +} + +TEST_CASE(upsample_linear_test) +{ + auto p = create_upsample_linear_prog(); + auto prog = migraphx::parse_onnx("upsample_linear_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(upsample_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ss{migraphx::shape::float_type, {4}}; + mm->add_literal(migraphx::literal(ss, {1.0f, 1.0f, 2.0f, 3.0f})); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto ix = mm->add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + + auto li = mm->add_literal(migraphx::literal(si, ind)); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("upsample_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(unknown_test_throw_print_error) +{ + migraphx::onnx_options options; + options.print_program_on_error = true; + EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx", options); })); +} + +TEST_CASE(variable_batch_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_onnx("variable_batch_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(variable_batch_user_input_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}}); + auto r = mm->add_instruction(migraphx::make_op("identity"), l0); + mm->add_return({r}); + + migraphx::onnx_options options; + options.default_dim_value = 2; + + auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + + EXPECT(p == prog); +} + +TEST_CASE(variable_batch_leq_zero_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto prog = optimize_onnx("variable_batch_leq_zero_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(where_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto lc = mm->add_parameter("c", migraphx::shape{migraphx::shape::bool_type, {2}}); + auto lx = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto ly = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 1, 2, 2}}); + + auto lccm = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lc); + auto lxm = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx); + auto lym = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly); + + auto r = mm->add_instruction(migraphx::make_op("where"), lccm, lxm, lym); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("where_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/pad_3arg_test.onnx b/test/onnx/pad_3arg_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e5db8cb1299e06f3860e791b8197b8964c16bb57 Binary files /dev/null and b/test/onnx/pad_3arg_test.onnx differ diff --git a/test/onnx/pad_reflect_multiaxis_test.onnx b/test/onnx/pad_reflect_multiaxis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d425a108017917413fe7d08c0828f3bb0b47f7fa Binary files /dev/null and b/test/onnx/pad_reflect_multiaxis_test.onnx differ diff --git a/test/onnx/pad_reflect_test.onnx b/test/onnx/pad_reflect_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c25f00f0cf2c59a46c74d44f1afd0ccaaa9f4284 Binary files /dev/null and b/test/onnx/pad_reflect_test.onnx differ diff --git a/test/onnx/pow_fp32_i64_test.onnx b/test/onnx/pow_fp32_i64_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e273bb6187935b8bf472db1044c8dc4623283821 --- /dev/null +++ b/test/onnx/pow_fp32_i64_test.onnx @@ -0,0 +1,22 @@ +pow_fp32_i64_test:~ + +0 +1out"Powpow_fp32_i64_testZ +0 + + + + +Z +1 + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/pow_i64_fp32_test.onnx b/test/onnx/pow_i64_fp32_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..927141df7f33c5b1fb7f339a630635f00312d06d --- /dev/null +++ b/test/onnx/pow_i64_fp32_test.onnx @@ -0,0 +1,22 @@ +pow_i64_fp32_test:~ + +0 +1out"Powpow_i64_fp32_testZ +0 + + + + +Z +1 + + + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/prefix_scan_sum_test.onnx b/test/onnx/prefix_scan_sum_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c1aaa2d8f2ea4709b9861383b6b2ceb5f21c7f41 Binary files /dev/null and b/test/onnx/prefix_scan_sum_test.onnx differ diff --git a/test/onnx/prelu_brcst_test.onnx b/test/onnx/prelu_brcst_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8534c38ce470bd6416e957d16d06f46ddce20539 --- /dev/null +++ b/test/onnx/prelu_brcst_test.onnx @@ -0,0 +1,20 @@ +prelu_brcst_test:w + +0 +1out"PReluprelu_brcst_testZ +0 + + + + +Z +1 +  + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_axis_test.onnx b/test/onnx/quantizelinear_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..445a6eca287b94b2f7db7483c2af98150dbb3fee --- /dev/null +++ b/test/onnx/quantizelinear_axis_test.onnx @@ -0,0 +1,26 @@ +quantizelinear_axis_test:¥ ++ +0 +1 +2out"QuantizeLinear* +axis quantizelinear_axis_testZ +0 + + + + +Z +1 + + +Z +2 + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_int32_test.onnx b/test/onnx/quantizelinear_int32_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4db80d9d1d2f9d7ee15f31dab74b4a0c02a97d66 --- /dev/null +++ b/test/onnx/quantizelinear_int32_test.onnx @@ -0,0 +1,16 @@ +quantizelinear_int32_test:m + +0 +1out"QuantizeLinearquantizelinear_int32_testZ +0 + + +Z +1 + + +b +out + + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_neg_axis_test.onnx b/test/onnx/quantizelinear_neg_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..396bce8a65d07cc4c137b7ee5b05fb4fcd4ce8eb --- /dev/null +++ b/test/onnx/quantizelinear_neg_axis_test.onnx @@ -0,0 +1,26 @@ +quantizelinear_neg_axis_test:² +4 +0 +1 +2out"QuantizeLinear* +axisþÿÿÿÿÿÿÿÿ quantizelinear_neg_axis_testZ +0 + + + + +Z +1 + + +Z +2 + + +b +out + + + + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_test.onnx b/test/onnx/quantizelinear_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ce4f3b427e5a384c36478650de8c2e431bd638d5 --- /dev/null +++ b/test/onnx/quantizelinear_test.onnx @@ -0,0 +1,16 @@ +quantizelinear_test:g + +0 +1out"QuantizeLinearquantizelinear_testZ +0 + + +Z +1 + + +b +out + + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_zero_point_test.onnx b/test/onnx/quantizelinear_zero_point_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2aa5eb22f4e52aa6071caac6c570da8aeaf08332 --- /dev/null +++ b/test/onnx/quantizelinear_zero_point_test.onnx @@ -0,0 +1,21 @@ +quantizelinear_zero_point_test:† + +0 +1 +2out"QuantizeLinearquantizelinear_zero_point_testZ +0 + + +Z +1 + + +Z +2 + + +b +out + + +B \ No newline at end of file diff --git a/test/onnx/randomnormal_dtype_error_test.onnx b/test/onnx/randomnormal_dtype_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..283b80a917c9db3df77fc5df9df6bddcf0c58fd0 --- /dev/null +++ b/test/onnx/randomnormal_dtype_error_test.onnx @@ -0,0 +1,9 @@ +randomnormal_dtype_error_test:u +6output" RandomNormal* +dtype * +shape@@@ randomnormal_dtype_error_testb +output + + + +B \ No newline at end of file diff --git a/test/onnx/randomnormal_generated_seed_test.onnx b/test/onnx/randomnormal_generated_seed_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..42d1a036735c657cc15328bf96bc70e144378f53 --- /dev/null +++ b/test/onnx/randomnormal_generated_seed_test.onnx @@ -0,0 +1,15 @@ + randomnormal_generated_seed_test:ˆ +1 +inputoutput" RandomNormal* + sample_size +  randomnormal_generated_seed_testZ +input +  + + +b +output +  + + +B \ No newline at end of file diff --git a/test/onnx/randomnormal_shape_error_test.onnx b/test/onnx/randomnormal_shape_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..252c0542329c164c63bf6f5422373eaf6cd900ca --- /dev/null +++ b/test/onnx/randomnormal_shape_error_test.onnx @@ -0,0 +1,8 @@ +randomnormal_shape_error_test:c +$output" RandomNormal* +dtype randomnormal_shape_error_testb +output + + + +B \ No newline at end of file diff --git a/test/onnx/randomnormal_test.onnx b/test/onnx/randomnormal_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5df7409746c55fbf0f9c92917274472e1534f1ad Binary files /dev/null and b/test/onnx/randomnormal_test.onnx differ diff --git a/test/onnx/randomnormallike_dtype_fallback_test.onnx b/test/onnx/randomnormallike_dtype_fallback_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5f3ea49b80139615d7c331e22e752c54391d8e8d Binary files /dev/null and b/test/onnx/randomnormallike_dtype_fallback_test.onnx differ diff --git a/test/onnx/randomnormallike_test.onnx b/test/onnx/randomnormallike_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..864dcff1a20ca93f58da0b834739103036a9fe4c Binary files /dev/null and b/test/onnx/randomnormallike_test.onnx differ diff --git a/test/onnx/randomnormallike_type_error_test.onnx b/test/onnx/randomnormallike_type_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7d489d30621f56f5bab686082b02b7cab3934cd3 Binary files /dev/null and b/test/onnx/randomnormallike_type_error_test.onnx differ diff --git a/test/onnx/randomuniform_dtype_error_test.onnx b/test/onnx/randomuniform_dtype_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e59d368c61d47a390f60d4810b290fc944bc7acd --- /dev/null +++ b/test/onnx/randomuniform_dtype_error_test.onnx @@ -0,0 +1,9 @@ +randomuniform_dtype_error_test:w +7output" RandomUniform* +dtype * +shape@@@ randomuniform_dtype_error_testb +output + + + +B \ No newline at end of file diff --git a/test/onnx/randomuniform_generated_seed_test.onnx b/test/onnx/randomuniform_generated_seed_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..24d964fe481fa6892c2656071a5f4c7512bee711 --- /dev/null +++ b/test/onnx/randomuniform_generated_seed_test.onnx @@ -0,0 +1,15 @@ +!randomuniform_generated_seed_test:Š +2 +inputoutput" RandomUniform* + sample_size + !randomuniform_generated_seed_testZ +input +  + + +b +output +  + + +B \ No newline at end of file diff --git a/test/onnx/randomuniform_shape_error_test.onnx b/test/onnx/randomuniform_shape_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..053c4afaf771666edd45c02827b739f038368c9a --- /dev/null +++ b/test/onnx/randomuniform_shape_error_test.onnx @@ -0,0 +1,8 @@ +randomuniform_shape_error_test:e +%output" RandomUniform* +dtype randomuniform_shape_error_testb +output + + + +B \ No newline at end of file diff --git a/test/onnx/randomuniform_test.onnx b/test/onnx/randomuniform_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fa641e227a339335c704c97d31d8e75cf47cc938 Binary files /dev/null and b/test/onnx/randomuniform_test.onnx differ diff --git a/test/onnx/randomuniformlike_dtype_fallback_test.onnx b/test/onnx/randomuniformlike_dtype_fallback_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..36393e2f2bbb4f6f8f9d6fc7795f75e2a1851b2b Binary files /dev/null and b/test/onnx/randomuniformlike_dtype_fallback_test.onnx differ diff --git a/test/onnx/randomuniformlike_test.onnx b/test/onnx/randomuniformlike_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5fd746027f28b87a20d90c46e4506d9e110ead97 Binary files /dev/null and b/test/onnx/randomuniformlike_test.onnx differ diff --git a/test/onnx/randomuniformlike_type_error_test.onnx b/test/onnx/randomuniformlike_type_error_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d3cc2c8bc8fc2efcc13fd1d1447c211e6bbd5fed Binary files /dev/null and b/test/onnx/randomuniformlike_type_error_test.onnx differ diff --git a/test/onnx/range_float_test.onnx b/test/onnx/range_float_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa2eadc5779cca68065c2fae57a2d90d7fd9c0f9 Binary files /dev/null and b/test/onnx/range_float_test.onnx differ diff --git a/test/onnx/range_test.onnx b/test/onnx/range_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e14647e5dbbc53873373c3642574a173d30ff0e1 --- /dev/null +++ b/test/onnx/range_test.onnx @@ -0,0 +1,19 @@ + +range_test:Ú +/start"Constant* +value*: +B start_val  +/limit"Constant* +value*:B limit_val  +8delta"Constant*% +value*: +ýÿÿÿÿÿÿÿÿB delta_val  + +start +limit +delta1"Range +range_testb +1 + + +B \ No newline at end of file diff --git a/test/onnx/recip_test.onnx b/test/onnx/recip_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..65f01dde0c83c91445c572c2a5f71c47ac28ff7c --- /dev/null +++ b/test/onnx/recip_test.onnx @@ -0,0 +1,14 @@ + +recip_test:B + +xy" +Reciprocal +recip_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/reduce_log_sum_exp_test.onnx b/test/onnx/reduce_log_sum_exp_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7c6e411b1b996cf0bc85c1909a6041a47d6b25c3 --- /dev/null +++ b/test/onnx/reduce_log_sum_exp_test.onnx @@ -0,0 +1,16 @@ +reduce_log_sum_exp_test: +> +xy"ReduceLogSumExp* +axes@üÿÿÿÿÿÿÿÿ * +keepdims reduce_log_sum_exp_testZ +x + + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/reduce_log_sum_test.onnx b/test/onnx/reduce_log_sum_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..73459d9dc22a979bb79f0a2fc0fe385ed11b5233 --- /dev/null +++ b/test/onnx/reduce_log_sum_test.onnx @@ -0,0 +1,17 @@ +reduce_log_sum_test:Œ +; +xy" ReduceLogSum* +axes@ýÿÿÿÿÿÿÿÿ * +keepdims reduce_log_sum_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/reducel1_test.onnx b/test/onnx/reducel1_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b72af05cfdffc9e15748dbc2e3f0525f16ac3176 Binary files /dev/null and b/test/onnx/reducel1_test.onnx differ diff --git a/test/onnx/reducel2_test.onnx b/test/onnx/reducel2_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3d596f5a6bb967d9c9e137e226af18766a31edca Binary files /dev/null and b/test/onnx/reducel2_test.onnx differ diff --git a/test/onnx/reduceprod_test.onnx b/test/onnx/reduceprod_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e51f5fbd1da62964e62fa8ecbe6229241944899b --- /dev/null +++ b/test/onnx/reduceprod_test.onnx @@ -0,0 +1,18 @@ +reduceprod_test:} +0 +xy" +ReduceProd* +axes@ * +keepdims reduceprod_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/reducesum_empty_axes_test.onnx b/test/onnx/reducesum_empty_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..74c8a118f348fd3871c43f4b54ca75562235094f Binary files /dev/null and b/test/onnx/reducesum_empty_axes_test.onnx differ diff --git a/test/onnx/reducesum_noop_test.onnx b/test/onnx/reducesum_noop_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a0f1f812b1cd3c372e37bfee895edb576f8a7040 Binary files /dev/null and b/test/onnx/reducesum_noop_test.onnx differ diff --git a/test/onnx/reducesum_square_test.onnx b/test/onnx/reducesum_square_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa921999a341607354ab10c0090deca3e86c7498 Binary files /dev/null and b/test/onnx/reducesum_square_test.onnx differ diff --git a/test/onnx/reducesum_test.onnx b/test/onnx/reducesum_test.onnx index e1ad0e823fb47d914c110f846d7f7003fed2534b..73e6c53d0756e9181cb025b3ff1f565980468650 100644 Binary files a/test/onnx/reducesum_test.onnx and b/test/onnx/reducesum_test.onnx differ diff --git a/test/onnx/reshape_non_standard_test.onnx b/test/onnx/reshape_non_standard_test.onnx index 606286cdaacf5199372d0d1a5ae135342274c8af..d35feb5bbaf2931d077ad3a0e03a1144b28eb7c5 100644 Binary files a/test/onnx/reshape_non_standard_test.onnx and b/test/onnx/reshape_non_standard_test.onnx differ diff --git a/test/onnx/resize_downsample_c_test.onnx b/test/onnx/resize_downsample_c_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d659ca25e9f6cbf7fd2c2e8e5d5c16121249200c Binary files /dev/null and b/test/onnx/resize_downsample_c_test.onnx differ diff --git a/test/onnx/resize_downsample_f_test.onnx b/test/onnx/resize_downsample_f_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5cf4c227b02a3e6b00c89bbdf6b3730c8807eaeb Binary files /dev/null and b/test/onnx/resize_downsample_f_test.onnx differ diff --git a/test/onnx/resize_downsample_linear_test.onnx b/test/onnx/resize_downsample_linear_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e566118bab880b12eef138eae2f88f7adf928867 Binary files /dev/null and b/test/onnx/resize_downsample_linear_test.onnx differ diff --git a/test/onnx/resize_nonstd_input_test.onnx b/test/onnx/resize_nonstd_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..821d3ebef5c02725127224130959f04562086423 Binary files /dev/null and b/test/onnx/resize_nonstd_input_test.onnx differ diff --git a/test/onnx/resize_outsize_test.onnx b/test/onnx/resize_outsize_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f39d8b4cf99e9973956a1ba46cd3501dbe1b6f5b Binary files /dev/null and b/test/onnx/resize_outsize_test.onnx differ diff --git a/test/onnx/resize_upsample_linear_ac_test.onnx b/test/onnx/resize_upsample_linear_ac_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..27ab414fc76a3d4f094986681ced90d89ec4aa33 Binary files /dev/null and b/test/onnx/resize_upsample_linear_ac_test.onnx differ diff --git a/test/onnx/resize_upsample_linear_test.onnx b/test/onnx/resize_upsample_linear_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..80e5ebc0d99a8dc2665c86b42a5c6f63b4c8c6ea Binary files /dev/null and b/test/onnx/resize_upsample_linear_test.onnx differ diff --git a/test/onnx/resize_upsample_pc_test.onnx b/test/onnx/resize_upsample_pc_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0415d30ad4ede5277ec83631ebb817b9e81a5659 Binary files /dev/null and b/test/onnx/resize_upsample_pc_test.onnx differ diff --git a/test/onnx/resize_upsample_pf_test.onnx b/test/onnx/resize_upsample_pf_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0611aedc8fef8573d620992c7f72bdfc4cf03aa9 Binary files /dev/null and b/test/onnx/resize_upsample_pf_test.onnx differ diff --git a/test/onnx/reversesequence_4D_test.onnx b/test/onnx/reversesequence_4D_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7fd4e02f35ed2f2403344366337f95885a142b79 Binary files /dev/null and b/test/onnx/reversesequence_4D_test.onnx differ diff --git a/test/onnx/reversesequence_batch_axis_err_test.onnx b/test/onnx/reversesequence_batch_axis_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4c80e8e3679842f71b7c7249bd0aef5de3d38c80 Binary files /dev/null and b/test/onnx/reversesequence_batch_axis_err_test.onnx differ diff --git a/test/onnx/reversesequence_batch_test.onnx b/test/onnx/reversesequence_batch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..eaf5f42e5e5382f932bd5a3ea09653d80c63c5a8 Binary files /dev/null and b/test/onnx/reversesequence_batch_test.onnx differ diff --git a/test/onnx/reversesequence_rank_err_test.onnx b/test/onnx/reversesequence_rank_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0aec563af40f1eec0c64e89dc8f18b660f18016e --- /dev/null +++ b/test/onnx/reversesequence_rank_err_test.onnx @@ -0,0 +1,12 @@ +reversesequence_rank_err_test:v +3 +xy"ReverseSequence* + sequence_lens@@@@ reversesequence_rank_err_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_same_axis_err_test.onnx b/test/onnx/reversesequence_same_axis_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..10cd74336b0b8453961cb1c15e6a7520b5b72d4f --- /dev/null +++ b/test/onnx/reversesequence_same_axis_err_test.onnx @@ -0,0 +1,15 @@ +"reversesequence_same_axis_err_test:¨ +X +xy"ReverseSequence* + +batch_axis * + sequence_lens@@@@ * + time_axis "reversesequence_same_axis_err_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx b/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3f08cf6a406e8be548d0f15ac0a82bb29d70a40b --- /dev/null +++ b/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx @@ -0,0 +1,12 @@ +,reversesequence_sequence_lens_shape_err_test:‹ +1 +xy"ReverseSequence* + sequence_lens@@@ ,reversesequence_sequence_lens_shape_err_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_time_axis_err_test.onnx b/test/onnx/reversesequence_time_axis_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5768922d8283580043cc8b69b98a01757220f3a2 Binary files /dev/null and b/test/onnx/reversesequence_time_axis_err_test.onnx differ diff --git a/test/onnx/reversesequence_time_test.onnx b/test/onnx/reversesequence_time_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..bd904da891d9896d01e70d0e8741083d90d76f6f Binary files /dev/null and b/test/onnx/reversesequence_time_test.onnx differ diff --git a/test/onnx/roialign_default_test.onnx b/test/onnx/roialign_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4421e17be6015bd1c6f6af7b64eee39f43785865 --- /dev/null +++ b/test/onnx/roialign_default_test.onnx @@ -0,0 +1,26 @@ +roialign_default_test:¥ +! +x +rois + batch_indy"RoiAlignroialign_default_testZ +x + + + + + +Z +rois +  + +Z + batch_ind + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/roialign_test.onnx b/test/onnx/roialign_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f39485530c4758b3fadd6c7d5fe5ad180cc75a73 Binary files /dev/null and b/test/onnx/roialign_test.onnx differ diff --git a/test/onnx/scatter_add_test.onnx b/test/onnx/scatter_add_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0faeaec3b46cbbbe46c6f39d267c32d86c25a4db --- /dev/null +++ b/test/onnx/scatter_add_test.onnx @@ -0,0 +1,31 @@ +scatter_add_test:ì +V +data +indices +updatey"ScatterElements* +axisþÿÿÿÿÿÿÿÿ * + reduction"add scatter_add_testZ +data + + + + +Z! +indices + + + + +Z +update + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/scatter_mul_test.onnx b/test/onnx/scatter_mul_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0022932b2db1c5878b875108811dfdc59f0c92f9 --- /dev/null +++ b/test/onnx/scatter_mul_test.onnx @@ -0,0 +1,31 @@ +scatter_mul_test:ì +V +data +indices +updatey"ScatterElements* +axisþÿÿÿÿÿÿÿÿ * + reduction"mul scatter_mul_testZ +data + + + + +Z! +indices + + + + +Z +update + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/scatter_none_test.onnx b/test/onnx/scatter_none_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0278a3b2383b507266237b748e11cb73ad6b216d --- /dev/null +++ b/test/onnx/scatter_none_test.onnx @@ -0,0 +1,31 @@ +scatter_none_test:î +W +data +indices +updatey"ScatterElements* +axisþÿÿÿÿÿÿÿÿ * + reduction"none scatter_none_testZ +data + + + + +Z! +indices + + + + +Z +update + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/scatternd_add_test.onnx b/test/onnx/scatternd_add_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..17d79a8407da8a8a0823182d2f7d0352df21aabf --- /dev/null +++ b/test/onnx/scatternd_add_test.onnx @@ -0,0 +1,26 @@ +scatternd_add_test:Î +@ +data +indices +updatesoutput" ScatterND* + reduction"add scatternd_add_testZ +data + + + +Z +indices + + + +Z +updates + + + +b +output + + + +B \ No newline at end of file diff --git a/test/onnx/scatternd_mul_test.onnx b/test/onnx/scatternd_mul_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a5a101003469f8b7599939cf56478a01851ff628 --- /dev/null +++ b/test/onnx/scatternd_mul_test.onnx @@ -0,0 +1,26 @@ +scatternd_mul_test:Î +@ +data +indices +updatesoutput" ScatterND* + reduction"mul scatternd_mul_testZ +data + + + +Z +indices + + + +Z +updates + + + +b +output + + + +B \ No newline at end of file diff --git a/test/onnx/scatternd_test.onnx b/test/onnx/scatternd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..496602d9cd2f727dc677eb87619a845b65fe2638 --- /dev/null +++ b/test/onnx/scatternd_test.onnx @@ -0,0 +1,25 @@ +scatternd_test:µ ++ +data +indices +updatesoutput" ScatterNDscatternd_testZ +data + + + +Z +indices + + + +Z +updates + + + +b +output + + + +B \ No newline at end of file diff --git a/test/onnx/selu_test.onnx b/test/onnx/selu_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..25c705cd2edec8fb0ec5a9b0b33c1b7ab3384bdf Binary files /dev/null and b/test/onnx/selu_test.onnx differ diff --git a/test/onnx/shape_gather_test.onnx b/test/onnx/shape_gather_test.onnx index ad48baaec2648c1f7e0e27d86f1b2bbdac4a6de7..22c35d5155d1fc16450108bae801f4d49b7cec93 100644 Binary files a/test/onnx/shape_gather_test.onnx and b/test/onnx/shape_gather_test.onnx differ diff --git a/test/onnx/size_float_test.onnx b/test/onnx/size_float_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e4f8387875b133b914bdafaeae9663a8ad17dde4 --- /dev/null +++ b/test/onnx/size_float_test.onnx @@ -0,0 +1,12 @@ +size_float_test:I + +xy"Sizesize_float_testZ +x + + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/size_half_test.onnx b/test/onnx/size_half_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c853f0f388bf3c7461456da984c33559263ca0c1 --- /dev/null +++ b/test/onnx/size_half_test.onnx @@ -0,0 +1,12 @@ +size_half_test:D + +xy"Sizesize_half_testZ +x +  + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/size_int_test.onnx b/test/onnx/size_int_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a7b892b28231a2d6ff3be289c3e3fdeaec1a871b --- /dev/null +++ b/test/onnx/size_int_test.onnx @@ -0,0 +1,12 @@ + size_int_test:G + +xy"Size size_int_testZ +x + + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/size_verify_test.onnx b/test/onnx/size_verify_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..44ffe3a3dc72a915895373e6e261f95372f7104b --- /dev/null +++ b/test/onnx/size_verify_test.onnx @@ -0,0 +1,12 @@ +size_verify_test:J + +xy"Sizesize_verify_testZ +x + + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/slice_3arg_test.onnx b/test/onnx/slice_3arg_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d829ad6483dcc90cfff261350d9deebace997f9b Binary files /dev/null and b/test/onnx/slice_3arg_test.onnx differ diff --git a/test/onnx/slice_5arg_reverse_test.onnx b/test/onnx/slice_5arg_reverse_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..01e8ab0d8489716ee3761223813ed81523059fb3 --- /dev/null +++ b/test/onnx/slice_5arg_reverse_test.onnx @@ -0,0 +1,23 @@ +slice_5arg_reverse_test: +9arg_step"Constant*# +value** ÿÿÿÿÿÿÿÿÿBstep  +Barg_axis"Constant*, +value* *ÿÿÿÿÿÿÿÿÿþÿÿÿÿÿÿÿÿBaxis  +@arg_end"Constant*+ +value**ûÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿBend  +D arg_start"Constant*- +value*!*ÿÿÿÿÿÿÿÿÿýÿÿÿÿÿÿÿÿBstart  +5 +0 + arg_start +arg_end +arg_axis +arg_step1"Sliceslice_5arg_reverse_testZ +0 +  + +b +1 +  + +B \ No newline at end of file diff --git a/test/onnx/slice_5arg_step_test.onnx b/test/onnx/slice_5arg_step_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7aff08abae18d184edf1ff2a77f5490957f24438 --- /dev/null +++ b/test/onnx/slice_5arg_step_test.onnx @@ -0,0 +1,23 @@ +slice_5arg_step_test:þ +9arg_step"Constant*# +value** þÿÿÿÿÿÿÿÿBstep  +Barg_axis"Constant*, +value* *ÿÿÿÿÿÿÿÿÿþÿÿÿÿÿÿÿÿBaxis  +@arg_end"Constant*+ +value**ûÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿBend  +D arg_start"Constant*- +value*!*ÿÿÿÿÿÿÿÿÿýÿÿÿÿÿÿÿÿBstart  +5 +0 + arg_start +arg_end +arg_axis +arg_step1"Sliceslice_5arg_step_testZ +0 +  + +b +1 +  + +B \ No newline at end of file diff --git a/test/onnx/slice_5arg_test.onnx b/test/onnx/slice_5arg_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f727cac5d0ec12304dbe5397843e96f52451d841 --- /dev/null +++ b/test/onnx/slice_5arg_test.onnx @@ -0,0 +1,23 @@ +slice_5arg_test:ð +0arg_step"Constant* +value**Bstep  +Barg_axis"Constant*, +value* *ÿÿÿÿÿÿÿÿÿþÿÿÿÿÿÿÿÿBaxis  +@arg_end"Constant*+ +value**ÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿÿBend  +D arg_start"Constant*- +value*!*ûÿÿÿÿÿÿÿÿýÿÿÿÿÿÿÿÿBstart  +5 +0 + arg_start +arg_end +arg_axis +arg_step1"Sliceslice_5arg_testZ +0 +  + +b +1 +  + +B \ No newline at end of file diff --git a/test/onnx/slice_max_end_test.onnx b/test/onnx/slice_max_end_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..95fbe0081983c96a93adbacfc96e9b4b6e9d7702 Binary files /dev/null and b/test/onnx/slice_max_end_test.onnx differ diff --git a/test/onnx/softmax_nonstd_input_test.onnx b/test/onnx/softmax_nonstd_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e1a671367563c6fc11b67cc54076d88a87535d9b Binary files /dev/null and b/test/onnx/softmax_nonstd_input_test.onnx differ diff --git a/test/onnx/softplus_nd_test.onnx b/test/onnx/softplus_nd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f4575d1528e82241143b5c1f6867fd1b9a8af6e9 --- /dev/null +++ b/test/onnx/softplus_nd_test.onnx @@ -0,0 +1,15 @@ +softplus_nd_test:V + +xy"Softplussoftplus_nd_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/softplus_test.onnx b/test/onnx/softplus_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..395da94e819676597c07b0b68e3f949e410ddf00 --- /dev/null +++ b/test/onnx/softplus_test.onnx @@ -0,0 +1,11 @@ + softplus_test:C + +xy"Softplus softplus_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/softsign_nd_test.onnx b/test/onnx/softsign_nd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f86c76432019329236806ff844f55ccac5f85e32 --- /dev/null +++ b/test/onnx/softsign_nd_test.onnx @@ -0,0 +1,15 @@ +softsign_nd_test:V + +xy"Softsignsoftsign_nd_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/softsign_test.onnx b/test/onnx/softsign_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b9a5fbbcb4d8d133f82a380879183d0574bf8831 --- /dev/null +++ b/test/onnx/softsign_test.onnx @@ -0,0 +1,11 @@ + softsign_test:C + +xy"Softsign softsign_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/spacetodepth_invalid_blocksize_test.onnx b/test/onnx/spacetodepth_invalid_blocksize_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..10eaa3d0dd931b509fcb16b52c51f0fefe49ac45 --- /dev/null +++ b/test/onnx/spacetodepth_invalid_blocksize_test.onnx @@ -0,0 +1,16 @@ +#spacetodepth_invalid_blocksize_test:Š +) +xy" SpaceToDepth* + blocksizeš™™> #spacetodepth_invalid_blocksize_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/spacetodepth_nondivisibility_test.onnx b/test/onnx/spacetodepth_nondivisibility_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1af5aac7499c545cc84488fe2e9f1ba015cfab58 --- /dev/null +++ b/test/onnx/spacetodepth_nondivisibility_test.onnx @@ -0,0 +1,16 @@ +!spacetodepth_nondivisibility_test:… +& +xy" SpaceToDepth* + blocksize !spacetodepth_nondivisibility_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/spacetodepth_simple_test.onnx b/test/onnx/spacetodepth_simple_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fb858d2de69caeb8d3bfad7bfe77009bc731aa03 --- /dev/null +++ b/test/onnx/spacetodepth_simple_test.onnx @@ -0,0 +1,16 @@ +spacetodepth_simple_test:| +& +xy" SpaceToDepth* + blocksize spacetodepth_simple_testZ +x + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/spacetodepth_test.onnx b/test/onnx/spacetodepth_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..614c1a87a5017402f1ac769ea2efdb1cdb5ef2a9 --- /dev/null +++ b/test/onnx/spacetodepth_test.onnx @@ -0,0 +1,18 @@ +spacetodepth_test:u +& +xy" SpaceToDepth* + blocksize spacetodepth_testZ +x + + + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/split_minus_axis_test.onnx b/test/onnx/split_minus_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c04a8d43f23a67d0801aecff0e179f4e23d5e6ac --- /dev/null +++ b/test/onnx/split_minus_axis_test.onnx @@ -0,0 +1,24 @@ +split_minus_axis_test:œ +, +xy1y2y3"Split* +axisÿÿÿÿÿÿÿÿÿ split_minus_axis_testZ +x +  + + +b +y1 +  + + +b +y2 +  + + +b +y3 +  + + +B \ No newline at end of file diff --git a/test/onnx/split_test.onnx b/test/onnx/split_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0d157e16df01fbe9271653e97b8180b6300bb8ab --- /dev/null +++ b/test/onnx/split_test.onnx @@ -0,0 +1,27 @@ + +split_test:š +5 +xy1y2y3"Split* +axis * +split@@@  +split_testZ +x +  + + +b +y1 +  + + +b +y2 +  + + +b +y3 +  + + +B \ No newline at end of file diff --git a/test/onnx/split_test_default.onnx b/test/onnx/split_test_default.onnx new file mode 100644 index 0000000000000000000000000000000000000000..530682a86ea4fe1400bd44cbc3735a8f1b0209f1 --- /dev/null +++ b/test/onnx/split_test_default.onnx @@ -0,0 +1,16 @@ +split_test_default:i + +xy1y2"Splitsplit_test_defaultZ +x +  + + +b +y1 +  + +b +y2 +  + +B \ No newline at end of file diff --git a/test/onnx/squeeze_axes_input_test.onnx b/test/onnx/squeeze_axes_input_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9dec528e22b4f1136dec269042d9c2267b88630e --- /dev/null +++ b/test/onnx/squeeze_axes_input_test.onnx @@ -0,0 +1,14 @@ +squeeze_axes_input_test:r + +x +axesy"Squeezesqueeze_axes_input_test*:BaxesZ +x + + + + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/squeeze_empty_axes_test.onnx b/test/onnx/squeeze_empty_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fc657f1fbe793996365d6727bff4561283dd881c Binary files /dev/null and b/test/onnx/squeeze_empty_axes_test.onnx differ diff --git a/test/onnx/squeeze_unsqueeze_test.onnx b/test/onnx/squeeze_unsqueeze_test.onnx index 2841a122248e1d10bb1d8a4826a42f4e3ec153f6..3f170b6fbd13120ba566fa811fe24886097c403c 100644 Binary files a/test/onnx/squeeze_unsqueeze_test.onnx and b/test/onnx/squeeze_unsqueeze_test.onnx differ diff --git a/test/onnx/sub_scalar_test.onnx b/test/onnx/sub_scalar_test.onnx index 19dee1f86f6d5f31ea936003e89b603488da7106..0c20f4a627e1db6378c383351d9ea6afab807531 100644 Binary files a/test/onnx/sub_scalar_test.onnx and b/test/onnx/sub_scalar_test.onnx differ diff --git a/test/onnx/sum_int_test.onnx b/test/onnx/sum_int_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ca0049a28b3c61948097a627bf33be11cc9490d0 --- /dev/null +++ b/test/onnx/sum_int_test.onnx @@ -0,0 +1,27 @@ + sum_int_test:› + +0c0"Cast* +to   + +1c1"Cast* +to   + +c0 +c1 +23"Sum sum_int_testZ +0 + + +Z +1 + + +Z +2 + +  +b +3 + +  +B \ No newline at end of file diff --git a/test/onnx/sum_type_test.onnx b/test/onnx/sum_type_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2384b4543031246138f9962ab80cea7f6e47ebae Binary files /dev/null and b/test/onnx/sum_type_test.onnx differ diff --git a/test/onnx/thresholdedrelu_default_test.onnx b/test/onnx/thresholdedrelu_default_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ae6e7fe309c7172d1c5cfdd9674b04abbde8bbe8 --- /dev/null +++ b/test/onnx/thresholdedrelu_default_test.onnx @@ -0,0 +1,13 @@ +thresholdedrelu_default_test:i + +xy"ThresholdedReluthresholdedrelu_default_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/thresholdedrelu_int_test.onnx b/test/onnx/thresholdedrelu_int_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d03b0cf07a74b2a672813868ebd0694845c38a9e Binary files /dev/null and b/test/onnx/thresholdedrelu_int_test.onnx differ diff --git a/test/onnx/thresholdedrelu_test.onnx b/test/onnx/thresholdedrelu_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e951fdc6975cd53cc8431893e41b6e909be11783 Binary files /dev/null and b/test/onnx/thresholdedrelu_test.onnx differ diff --git a/test/onnx/tile_test.onnx b/test/onnx/tile_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ce97d992fa816ede2cc6bc2feac33851cde6d778 --- /dev/null +++ b/test/onnx/tile_test.onnx @@ -0,0 +1,16 @@ + tile_test:d + +x +yz"Tile tile_test* :ByZ +x +  + +Z +y + + +b +z +  + +B diff --git a/test/onnx/tile_test_3x2.onnx b/test/onnx/tile_test_3x2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ffb9690172949e122c1911d31cf718c6cb8c8149 --- /dev/null +++ b/test/onnx/tile_test_3x2.onnx @@ -0,0 +1,16 @@ + tile_test_3x2:h + +x +yz"Tile tile_test_3x2* :ByZ +x +  + +Z +y + + +b +z +  + +B diff --git a/test/onnx/topk_attrk_test.onnx b/test/onnx/topk_attrk_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..77d47ae20468b33e294b846142c6260a6d32cf57 --- /dev/null +++ b/test/onnx/topk_attrk_test.onnx @@ -0,0 +1,22 @@ +topk_attrk_test:™ +$ +datavalindices"TopK* +k topk_attrk_testZ +data + + + + +b +val + + + + +b! +indices + + + + +B \ No newline at end of file diff --git a/test/onnx/topk_neg_axis_test.onnx b/test/onnx/topk_neg_axis_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa281280332afc2968a1df747ee8a17f07ed3460 Binary files /dev/null and b/test/onnx/topk_neg_axis_test.onnx differ diff --git a/test/onnx/topk_test.onnx b/test/onnx/topk_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4e9bdf1938af7e1a8fd1e33cbbcbdbf244a4083b Binary files /dev/null and b/test/onnx/topk_test.onnx differ diff --git a/test/onnx/transpose_default_perm_test.onnx b/test/onnx/transpose_default_perm_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4a91c536e8e5abb838e7639ebb6fbf6106dec4eb --- /dev/null +++ b/test/onnx/transpose_default_perm_test.onnx @@ -0,0 +1,15 @@ +transpose_default_perm_test:j + +01" Transposetranspose_default_perm_testZ +0 + + + + +b +1 + + + + +B \ No newline at end of file diff --git a/test/onnx/transpose_invalid_perm_test.onnx b/test/onnx/transpose_invalid_perm_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..de059e936a7eab785fb08251cf5bf6acd9d6686a Binary files /dev/null and b/test/onnx/transpose_invalid_perm_test.onnx differ diff --git a/test/onnx/undefined_test.onnx b/test/onnx/undefined_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f8c38e2470080d283c83cdb4a9b34210c938a5b5 Binary files /dev/null and b/test/onnx/undefined_test.onnx differ diff --git a/test/onnx/unknown_aten_test.onnx b/test/onnx/unknown_aten_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..319beb98fc12c852279797729fc5dbc44d0fea32 --- /dev/null +++ b/test/onnx/unknown_aten_test.onnx @@ -0,0 +1,21 @@ +unknown_aten_test:‹ +' +0 +12"ATen* +operator"unknown unknown_aten_testZ +0 + + + + +Z +1 +  + +b +3 + + + + +B \ No newline at end of file diff --git a/test/onnx/upsample_linear_test.onnx b/test/onnx/upsample_linear_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c6ec8f52b4ae2ac7722294a459a91ad83e57fa41 Binary files /dev/null and b/test/onnx/upsample_linear_test.onnx differ diff --git a/test/onnx/upsample_test.onnx b/test/onnx/upsample_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..251eb09cd7cb746e75f7af3791ff4399b25b10e3 Binary files /dev/null and b/test/onnx/upsample_test.onnx differ diff --git a/test/onnx/variable_batch_leq_zero_test.onnx b/test/onnx/variable_batch_leq_zero_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..31abac331a08e85d631a840dfb95af74ed33ca48 Binary files /dev/null and b/test/onnx/variable_batch_leq_zero_test.onnx differ diff --git a/test/onnx/variable_batch_test.onnx b/test/onnx/variable_batch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..db987a118032989ffcc4b31429cf68eeae896e48 Binary files /dev/null and b/test/onnx/variable_batch_test.onnx differ diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp new file mode 100644 index 0000000000000000000000000000000000000000..20731476d54edc1d8fd74c79fd4f8d2a54a8114c --- /dev/null +++ b/test/onnx/verify_onnx.cpp @@ -0,0 +1,1008 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +TEST_CASE(averagepool_notset_test) +{ + auto p = migraphx::parse_onnx("averagepool_notset_test.onnx"); + p.compile(migraphx::ref::target{}); + std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + migraphx::shape s_x{migraphx::shape::float_type, {1, 1, 5, 5}}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_x, data_x.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {12}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(averagepool_nt_cip_test) +{ + auto p = migraphx::parse_onnx("averagepool_nt_cip_test.onnx"); + p.compile(migraphx::ref::target{}); + std::vector data_x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + migraphx::shape s_x{migraphx::shape::float_type, {1, 1, 5, 5}}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_x, data_x.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {8.33333}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(celu_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("celu_verify_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data = {-5.5, 2.0, 100., 7.0, 0., -1.}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector correct(6); + float alpha = 0.5; + std::transform(data.begin(), data.end(), correct.begin(), [&](auto x) { + return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha)); + }); + EXPECT(migraphx::verify_range(result_vector, correct)); +} + +TEST_CASE(clip_args_type_mismatch) +{ + auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s_0{migraphx::shape::float_type, {3, 3}}; + migraphx::parameter_map pp; + std::vector data_0 = {0.9, 1.2, 1.7, 1.9, 2.2, 2.7, 2.9, 3.2, 3.7}; + pp["0"] = migraphx::argument(s_0, data_0.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(depthtospace_simple_test) +{ + auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx"); + p.compile(migraphx::ref::target{}); + std::vector data_in(48); + std::iota(std::begin(data_in), std::end(data_in), 0); + migraphx::shape s_x{migraphx::shape::float_type, {1, 8, 2, 3}}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_x, data_in.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16, + 5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42, + 31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(spacetodepth_simple_test) +{ + auto p = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); + p.compile(migraphx::ref::target{}); + std::vector data_in(48); + std::iota(std::begin(data_in), std::end(data_in), 0); + migraphx::shape s_x{migraphx::shape::float_type, {1, 2, 4, 6}}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_x, data_in.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13, + 15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32, + 34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(spacetodepth_depthtospace_test) +{ + // space to depth + auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); + p1.compile(migraphx::ref::target{}); + std::vector data_in(48); + std::iota(std::begin(data_in), std::end(data_in), 0); + migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}}; + migraphx::parameter_map pp1; + pp1["x"] = migraphx::argument(s_x_1, data_in.data()); + auto result1 = p1.eval(pp1).back(); + // depth to space + auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx"); + p2.compile(migraphx::ref::target{}); + migraphx::parameter_map pp2; + pp2["x"] = result1; + auto result2 = p2.eval(pp2).back(); + std::vector result_vector2; + result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(result_vector2, data_in)); +} + +TEST_CASE(eyelike_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("eyelike_verify_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {3, 4}}; + std::vector data{12, 0}; + migraphx::parameter_map pp; + pp["T1"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.}; + EXPECT(migraphx::verify_range(result_vector, eyelike_mat)); +} + +TEST_CASE(eyelike_verify_negk_test) +{ + migraphx::program p = migraphx::parse_onnx("eyelike_verify_negk_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {3, 4}}; + std::vector data{12, 0}; + migraphx::parameter_map pp; + pp["T1"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.}; + EXPECT(migraphx::verify_range(result_vector, eyelike_mat)); +} + +TEST_CASE(gather_elements) +{ + migraphx::program p = migraphx::parse_onnx("gather_elements_axis0_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s_data{migraphx::shape::float_type, {3, 4}}; + std::vector data = { + 0.25, 0.75, 0.9375, 0.4375, 0.6875, 0.5625, -0.875, 0.1875, -0.125, 0.5, -0.9375, -0.0625}; + + migraphx::shape s_ind{migraphx::shape::int32_type, {2, 3}}; + std::vector ind = {2, 1, 2, 0, 1, 0}; + + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(s_data, data.data()); + pp["indices"] = migraphx::argument(s_ind, ind.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(greaterorequal_test) +{ + migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data1 = {0.25, 0.75, 0.9375}; + std::vector data2 = {0.25, 0.74, 0.9411}; + + migraphx::parameter_map pp; + pp["x1"] = migraphx::argument(s, data1.data()); + pp["x2"] = migraphx::argument(s, data2.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.0, 1.0, 0.0}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(hardsigmoid_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("hardsigmoid_verify_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {2, 5}}; + std::vector data = {-10.0, -2.5, -1.0, -0.5, 0, 1.0, 2.0, 2.5, 2.6, 100.0}; + + float alpha = 0.2; + float beta = 0.5; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold(10); + std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) { + return std::max(0.0f, std::min(x * alpha + beta, 1.0f)); + }); + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(if_else_test) +{ + migraphx::program p = migraphx::parse_onnx("if_else_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; + std::vector data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_data, data.data()); + pp["y"] = migraphx::argument(s_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + -0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(if_literal_test) +{ + auto run_prog = [](bool cond) { + migraphx::program p = migraphx::parse_onnx("if_literal_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s_data{migraphx::shape::bool_type}; + std::vector data = {static_cast(cond)}; + + migraphx::parameter_map pp; + pp["cond"] = migraphx::argument(s_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + return result_vector; + }; + + // then branch + { + auto result_vector = run_prog(true); + std::vector gold = {1, 2, 3, 4, 5}; + EXPECT(migraphx::verify_range(result_vector, gold)); + } + + // else branch + { + auto result_vector = run_prog(false); + std::vector gold = {5, 4, 3, 2, 1}; + EXPECT(migraphx::verify_range(result_vector, gold)); + } +} + +TEST_CASE(if_pl_test) +{ + auto run_prog = [](bool cond) { + migraphx::program p = migraphx::parse_onnx("if_pl_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + migraphx::shape cond_s{migraphx::shape::bool_type}; + + std::vector x_data(xs.elements(), 1.0f); + std::vector y_data(ys.elements(), 2.0f); + std::vector cond_data{static_cast(cond)}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(xs, x_data.data()); + pp["y"] = migraphx::argument(ys, y_data.data()); + pp["cond"] = migraphx::argument(cond_s, cond_data.data()); + + auto result = p.eval(pp).back(); + std::vector ret; + result.visit([&](auto output) { ret.assign(output.begin(), output.end()); }); + + return ret; + }; + + // then branch + { + auto result_vector = run_prog(true); + std::vector gold = {2, 3, 4, 5, 6, 7}; + EXPECT(migraphx::verify_range(result_vector, gold)); + } + + // else branch + { + auto result_vector = run_prog(false); + std::vector gold = {1, 2, 3, 4, 5, 6}; + EXPECT(migraphx::verify_range(result_vector, gold)); + } +} + +TEST_CASE(if_tuple_test) +{ + auto run_prog = [](bool cond) { + migraphx::program p = migraphx::parse_onnx("if_tuple_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape xs{migraphx::shape::float_type, {1, 4}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 4}}; + migraphx::shape cond_s{migraphx::shape::bool_type}; + + std::vector x_data(xs.elements(), 1.0f); + std::vector y_data(ys.elements(), 2.0f); + std::vector cond_data{static_cast(cond)}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(xs, x_data.data()); + pp["y"] = migraphx::argument(ys, y_data.data()); + pp["cond"] = migraphx::argument(cond_s, cond_data.data()); + + auto results = p.eval(pp); + std::vector> rets; + for(const auto& arg : results) + { + std::vector vec; + arg.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); + rets.push_back(vec); + } + + return rets; + }; + + // then branch + { + auto results = run_prog(true); + std::vector gold0(4, 2.0f); + std::vector gold1(12, 4.0f); + EXPECT(migraphx::verify_range(results.at(0), gold0)); + EXPECT(migraphx::verify_range(results.at(1), gold1)); + } + + // else branch + { + auto results = run_prog(false); + std::vector gold0(4, 3.0f); + std::vector gold1(12, 5.0f); + EXPECT(migraphx::verify_range(results.at(0), gold0)); + EXPECT(migraphx::verify_range(results.at(1), gold1)); + } +} + +TEST_CASE(instance_norm_test) +{ + migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx"); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vector(9); + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {-1.54919, + -1.16189, + -0.774596, + -0.387298, + 0, + 0.387298, + 0.774596, + 1.16189, + 1.54919, + -2.09838, + -1.32379, + -0.549192, + 0.225404, + 1, + 1.7746, + 2.54919, + 3.32379, + 4.09838}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(instance_norm_3d_test) +{ + migraphx::program p = migraphx::parse_onnx("instance_norm_val_3d_test.onnx"); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vector(16); + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {-1.52752, + -1.09109, + -0.654653, + -0.218218, + 0.218218, + 0.654653, + 1.09109, + 1.52752, + -2.05505, + -1.18218, + -0.309306, + 0.563565, + 1.43644, + 2.30931, + 3.18218, + 4.05505}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(lessorequal_test) +{ + migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data1 = {0.25, 0.75, 0.9375}; + std::vector data2 = {0.25, 0.74, 0.9411}; + + migraphx::parameter_map pp; + pp["x1"] = migraphx::argument(s, data1.data()); + pp["x2"] = migraphx::argument(s, data2.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1, 0, 1}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(lpnormalization_1norm) +{ + migraphx::program p = migraphx::parse_onnx("lpnormalization_l1_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s{migraphx::shape::float_type, {3, 4}}; + std::vector data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold{0.f, + 2.f / 5.f, + -2.f / 5.f, + 1.f / 5.f, + 1.f / 10.f, + -5.f / 10.f, + 3.f / 10.f, + -1.f / 10.f, + -4.f / 7.f, + 3.f / 7.f, + 0.f, + 0.f}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(lpnormalization_2norm) +{ + migraphx::program p = migraphx::parse_onnx("lpnormalization_l2_test.onnx"); + p.compile(migraphx::ref::target{}); + migraphx::shape s{migraphx::shape::float_type, {3, 4}}; + std::vector data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f}; + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector correct{0.f, + 2.f / 3.f, + -2.f / 3.f, + 1.f / 3.f, + 1.f / 6.f, + -5.f / 6.f, + 3.f / 6.f, + -1.f / 6.f, + -4.f / 5.f, + 3.f / 5.f, + 0.f, + 0.f}; + EXPECT(migraphx::verify_range(result_vector, correct)); +} + +TEST_CASE(mean_broadcast_test) +{ + migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s0{migraphx::shape::float_type, {1, 3, 4}}; + std::vector data0(12, 1); + migraphx::shape s1{migraphx::shape::float_type, {1, 2, 3, 4}}; + std::vector data1(24, 2); + migraphx::shape s2{migraphx::shape::float_type, {4}}; + std::vector data2(4, 3); + migraphx::shape s3{migraphx::shape::float_type, {1}}; + std::vector data3(1, 4); + migraphx::shape s4{migraphx::shape::float_type, {2, 3, 1}}; + std::vector data4(6, 5); + + migraphx::parameter_map pp; + pp["0"] = migraphx::argument(s0, data0.data()); + pp["1"] = migraphx::argument(s1, data1.data()); + pp["2"] = migraphx::argument(s2, data2.data()); + pp["3"] = migraphx::argument(s3, data3.data()); + pp["4"] = migraphx::argument(s4, data4.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold(24, 3); + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(mean_test) +{ + migraphx::program p = migraphx::parse_onnx("mean_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::double_type, {2, 2, 2}}; + const int num_elms = 8; + const int num_data = 10; + const std::vector scalars{1.0, 2.0, -2.5, 3.3, 10.7, -1.0, 100.0, 7.9, 0.01, -56.8}; + std::vector> data; + std::transform(scalars.begin(), scalars.end(), std::back_inserter(data), [&](const auto& i) { + return std::vector(num_elms, i); + }); + + migraphx::parameter_map pp; + for(std::size_t i = 0; i < num_data; ++i) + pp[std::to_string(i)] = migraphx::argument(s, data[i].data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data; + std::vector gold(num_elms, mean); + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(mean_integral_test) +{ + migraphx::program p = migraphx::parse_onnx("mean_integral_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}}; + const int num_elms = 8; + const int num_data = 10; + const std::vector scalars{1, 5, 14, 2, 6, 21, 101, 0, -4, -11}; + std::vector> data; + std::transform(scalars.begin(), scalars.end(), std::back_inserter(data), [&](const auto i) { + return std::vector(num_elms, i); + }); + + migraphx::parameter_map pp; + for(std::size_t i = 0; i < num_data; ++i) + pp[std::to_string(i)] = migraphx::argument(s, data[i].data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data; + std::vector gold(num_elms, mean); + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(nonzero_test) +{ + migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::bool_type, {2, 2}}; + std::vector data = {1, 1, 1, 0}; + + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0, 0, 1, 0, 0, 1, 0, 0}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(resize_downsample_f_test) +{ + migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; + std::vector dx(sx.elements()); + std::iota(dx.begin(), dx.end(), 0.0f); + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.0f, 3.0f}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(resize_upsample_linear_ac_test) +{ + migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_ac_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + std::vector dx = {1.0f, 2.0f, 3.0f, 4.0f}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1, + 4.0f / 3, + 5.0f / 3, + 2, + 5.0f / 3, + 2, + 7.0f / 3, + 8.0f / 3, + 7.0f / 3, + 8.0f / 3, + 3, + 10.0f / 3, + 3, + 10.0f / 3, + 11.0f / 3, + 4}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(resize_upsample_linear_test) +{ + migraphx::program p = migraphx::parse_onnx("resize_upsample_linear_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + std::vector dx = {1.0f, 2.0f, 3.0f, 4.0f}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(resize_upsample_pf_test) +{ + migraphx::program p = migraphx::parse_onnx("resize_upsample_pf_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + std::vector dx = {1.0f, 2.0f, 3.0f, 4.0f}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, dx.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(reversesequence_4D_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_4D_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x_data = { + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(reversesequence_batch_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; + std::vector x_data = { + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(reversesequence_time_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_time_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; + std::vector x_data = { + 0.0, 4.0, 8.0, 12.0, 1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(selu_test) +{ + migraphx::program p = migraphx::parse_onnx("selu_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::double_type, {2, 3}}; + std::vector x_data = {1.1, 2.1, 0.0, -1.3, -5.3, 12.0}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(size_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {2, 5, 3}}; + std::vector data(30, 1.); + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + auto size_result = result.at(); + EXPECT(size_result == int64_t{30}); +} + +TEST_CASE(slice_test) +{ + migraphx::program p = migraphx::parse_onnx("slice_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sh_data{migraphx::shape::float_type, {3, 2}}; + std::vector data = {0, 1, 2, 3, 4, 5}; + + migraphx::parameter_map pp; + pp["0"] = migraphx::argument(sh_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {2, 3}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(slice_5arg_test) +{ + migraphx::program p = migraphx::parse_onnx("slice_5arg_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start + std::vector data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + + migraphx::parameter_map pp; + pp["0"] = migraphx::argument(sh_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {10, 11, 12, 13, 15, 16, 17, 18}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(slice_reverse_test) +{ + migraphx::program p = migraphx::parse_onnx("slice_5arg_reverse_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start + std::vector data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + + migraphx::parameter_map pp; + pp["0"] = migraphx::argument(sh_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {14, 13, 12, 11, 19, 18, 17, 16}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(slice_step_test) +{ + migraphx::program p = migraphx::parse_onnx("slice_5arg_step_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start + std::vector data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + + migraphx::parameter_map pp; + pp["0"] = migraphx::argument(sh_data, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {14, 12}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(softplus_test) +{ + migraphx::program p = migraphx::parse_onnx("softplus_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + std::vector data = {0, 1, 2, 3, 4}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold(5); + std::transform( + data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); }); + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(softsign_test) +{ + migraphx::program p = migraphx::parse_onnx("softsign_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + std::vector data = {0, 1, 2, 3, 4}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold(5); + std::transform( + data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); }); + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(upsample_test) +{ + migraphx::program p = migraphx::parse_onnx("upsample_test.onnx"); + + std::vector x_data = {1, 2, 3, 4}; + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + + migraphx::parameter_map pp; + pp["X"] = migraphx::argument(sx, x_data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(where_test) +{ + migraphx::program p = migraphx::parse_onnx("where_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape c_shape{migraphx::shape::bool_type, {2}}; + std::vector c_data = {1, 0}; + + migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}}; + std::vector x_data(8, 1.0f); + + migraphx::shape y_shape{migraphx::shape::float_type, {2, 1, 2, 2}}; + std::vector y_data(8, 2.0f); + + migraphx::parameter_map pp; + pp["c"] = migraphx::argument(c_shape, c_data.data()); + pp["x"] = migraphx::argument(x_shape, x_data.data()); + pp["y"] = migraphx::argument(y_shape, y_data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/onnx/where_test.onnx b/test/onnx/where_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f04b36c4cc8ddfb6ae1685a5181877a2db2e2fda --- /dev/null +++ b/test/onnx/where_test.onnx @@ -0,0 +1,28 @@ + +where_test:… + +c +x +yz"Where +where_testZ +c + +  +Z +x + + + +Z +y + + + + +b +z + + + + +B \ No newline at end of file diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index b79c96cbb57a529c0c403f333f87a4f4e60b0f35..8eb25cf88dc1500e28753734aee482fe8e4504ed 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -3,21 +3,26 @@ #include #include #include +#include + +#include + #include "test.hpp" template void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs) { migraphx::program p; + auto* mm = p.get_main_module(); std::vector shapes{xs...}; std::vector args(shapes.size()); std::transform( - shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); - p.add_instruction(op, args); - if(p.get_shape() != expected) + shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); }); + mm->add_instruction(op, args); + if(p.get_output_shapes().back() != expected) { - std::cout << "FAILED: Incorrect shape for " << op.name() << ": "; - std::cout << expected << " != " << p.get_shape() << std::endl; + std::cout << "FAILED: Incorrect shape for " << op << ": "; + std::cout << expected << " != " << p.get_output_shapes().back() << std::endl; for(auto&& s : shapes) std::cout << " " << s << std::endl; } @@ -27,11 +32,12 @@ template void throws_shape(const migraphx::operation& op, Ts... xs) { migraphx::program p; + auto* mm = p.get_main_module(); std::vector shapes{xs...}; std::vector args(shapes.size()); std::transform( - shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); - bool thrown = test::throws([&] { p.add_instruction(op, args); }); + shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return mm->add_outline(s); }); + bool thrown = test::throws([&] { mm->add_instruction(op, args); }); if(not thrown) { std::cout << "FAILED: No error found for " << op.name() << ": "; @@ -57,236 +63,161 @@ TEST_CASE(batch_norm_inference_shape) const size_t channels = 3; migraphx::shape s{migraphx::shape::float_type, {4, channels, 3, 3}}; migraphx::shape vars{migraphx::shape::float_type, {channels}}; - expect_shape(s, migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars); - throws_shape(migraphx::op::batch_norm_inference{}, s); - throws_shape(migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars); + expect_shape(s, migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars); + throws_shape(migraphx::make_op("batch_norm_inference"), s); + throws_shape(migraphx::make_op("batch_norm_inference"), s, vars, vars, vars, vars, vars); } -TEST_CASE(convolution_shape) +TEST_CASE(broadcast) { - migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}}; - migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; - migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}}; - expect_shape(output, migraphx::op::convolution{}, input, weights); - throws_shape(migraphx::op::convolution{}, input); + { + std::vector lens{1, 1}; + migraphx::shape input{migraphx::shape::float_type, {1}, {0}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", lens}}), + input); + } - migraphx::shape input2{migraphx::shape::float_type, {3, 3}}; - migraphx::shape weights2{migraphx::shape::float_type, {3, 3}}; - throws_shape(migraphx::op::convolution{}, input2, weights2); - throws_shape(migraphx::op::convolution{}, input2, weights); -} + { + std::vector lens{1, 1}; + migraphx::shape input{migraphx::shape::float_type, {2}}; + throws_shape(migraphx::op::broadcast{1, lens}, input); + } -TEST_CASE(quant_convolution_shape) -{ - migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}}; - migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}}; - migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}}; - expect_shape(output, migraphx::op::quant_convolution{}, input, weights); - throws_shape(migraphx::op::quant_convolution{}, input); + { + std::vector lens{2, 2}; + migraphx::shape input{migraphx::shape::float_type, {1, 2}}; + throws_shape(migraphx::op::broadcast{1, lens}, input); + } - migraphx::shape input2{migraphx::shape::int32_type, {3, 3}}; - migraphx::shape weights2{migraphx::shape::float_type, {3, 3}}; - throws_shape(migraphx::op::quant_convolution{}, input2, weights2); - throws_shape(migraphx::op::quant_convolution{}, input2, weights); + { + std::vector lens{3, 2, 4, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 3}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}}, + migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), + input); + } - migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}}; - migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}}; - throws_shape(migraphx::op::quant_convolution{}, input3, weights); - throws_shape(migraphx::op::quant_convolution{}, input, weight3); - throws_shape(migraphx::op::quant_convolution{}, input3, weight3); + { + std::vector lens{3, 2, 4, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 4}}; + throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input); + } } -TEST_CASE(transpose_shape) +TEST_CASE(convolution_shape) { - migraphx::shape input{migraphx::shape::float_type, {2, 2}}; - migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}}; - expect_shape(input, migraphx::op::transpose{{0, 1}}, input); - expect_shape(output, migraphx::op::transpose{{1, 0}}, input); - throws_shape(migraphx::op::transpose{{1, 2}}, input); + migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}}; + migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; + migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}}; + expect_shape(output, migraphx::make_op("convolution"), input, weights); + throws_shape(migraphx::make_op("convolution"), input); + throws_shape( + migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input); + + migraphx::shape input2{migraphx::shape::float_type, {3, 3}}; + migraphx::shape weights2{migraphx::shape::float_type, {3, 3}}; + throws_shape(migraphx::make_op("convolution"), input2, weights2); + throws_shape(migraphx::make_op("convolution"), input2, weights); + + migraphx::shape output_1d{migraphx::shape::float_type, {4, 4, 1}}; + migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}}; + expect_shape( + output_1d, + migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input_1d, + weights_1d); + + migraphx::shape output_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}}; + migraphx::shape input_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; + migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; + expect_shape( + output_3d, + migraphx::make_op("convolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + input_3d, + weights_3d); + + throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d); } TEST_CASE(contiguous_shape) { migraphx::shape output{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}}; - expect_shape(output, migraphx::op::contiguous{}, input); - throws_shape(migraphx::op::contiguous{}, input, input); + expect_shape(output, migraphx::make_op("contiguous"), input); + throws_shape(migraphx::make_op("contiguous"), input, input); migraphx::shape single{migraphx::shape::float_type, {2}}; - expect_shape(single, migraphx::op::contiguous{}, single); + expect_shape(single, migraphx::make_op("contiguous"), single); } -TEST_CASE(reshape_shape) +TEST_CASE(contiguous_shape_scalar) { - migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}}; - for(auto&& new_shape : - std::vector>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}}) - { - std::vector lens(new_shape.size()); - std::copy(new_shape.begin(), new_shape.end(), lens.begin()); - migraphx::shape output{migraphx::shape::float_type, lens}; - expect_shape(output, migraphx::op::reshape{new_shape}, input); - } - - for(auto&& new_shape : - std::vector>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}}) - { - throws_shape(migraphx::op::reshape{new_shape}, input); - } - - std::vector, migraphx::shape>> minus1_tests{ - {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}}, - {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}}, - {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}}, - {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}}, - {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}}, - {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}}, - {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}}, - {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}}, - {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}}; + migraphx::shape output{migraphx::shape::float_type}; + migraphx::shape input{migraphx::shape::float_type}; + expect_shape(output, migraphx::make_op("contiguous"), input); +} - for(auto& it : minus1_tests) - { - expect_shape(it.second, migraphx::op::reshape{it.first}, input); - } +TEST_CASE(deconvolution_shape) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}}; + migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}}; + migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}}; + expect_shape(output, migraphx::make_op("deconvolution"), input, weights); + throws_shape(migraphx::make_op("deconvolution"), input); + throws_shape( + migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input); + + migraphx::shape input_1d{migraphx::shape::float_type, {4, 4, 1}}; + migraphx::shape output_1d{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}}; + expect_shape( + output_1d, + migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input_1d, + weights_1d); + + migraphx::shape input_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}}; + migraphx::shape output_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; + migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; + expect_shape( + output_3d, + migraphx::make_op("deconvolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + input_3d, + weights_3d); } TEST_CASE(flatten_shape) { migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}}, - migraphx::op::flatten{0}, + migraphx::make_op("flatten", {{"axis", 0}}), + input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}}, + migraphx::make_op("flatten", {{"axis", -4}}), input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}}, - migraphx::op::flatten{1}, + migraphx::make_op("flatten", {{"axis", 1}}), + input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}}, + migraphx::make_op("flatten", {{"axis", -3}}), input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}}, - migraphx::op::flatten{2}, + migraphx::make_op("flatten", {{"axis", 2}}), input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6, 8}}, - migraphx::op::flatten{3}, + migraphx::make_op("flatten", {{"axis", 3}}), input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6 * 8, 1}}, - migraphx::op::flatten{4}, - input); - throws_shape(migraphx::op::flatten{5}, input); -} - -TEST_CASE(slice_shape) -{ - migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}}; - expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, - migraphx::op::slice{{2}, {1}, {3}}, - input); - expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, - migraphx::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}}, - input); - expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, - migraphx::op::slice{{2}, {2}, {10}}, + migraphx::make_op("flatten", {{"axis", 4}}), input); -} - -TEST_CASE(multibroadcast) -{ - { - std::vector lens{4, 2, 5, 3}; - migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 2, 5, 3}; - migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 2, 5, 3}; - migraphx::shape input{migraphx::shape::float_type, {5, 1}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 2, 5, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 2, 5, 3}; - migraphx::shape input{migraphx::shape::float_type, {3}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 4, 1, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 1, 1, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}}, - migraphx::op::multibroadcast{lens}, - input); - } - { - std::vector lens{4, 1, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; - throws_shape(migraphx::op::multibroadcast{lens}, input); - } - { - std::vector lens{4, 1, 3}; - migraphx::shape input{migraphx::shape::float_type, {}}; - throws_shape(migraphx::op::multibroadcast{lens}, input); - } - { - std::vector lens{2, 3, 4, 5}; - migraphx::shape input{migraphx::shape::float_type, {3, 4}}; - throws_shape(migraphx::op::multibroadcast{lens}, input); - } - { - std::vector lens{2, 3, 4, 5}; - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; - throws_shape(migraphx::op::multibroadcast{lens}, input); - } -} - -TEST_CASE(broadcast) -{ - { - std::vector lens{1, 1}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, - migraphx::op::broadcast{0, lens}, - input); - } - { - std::vector lens{1, 1}; - migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; - throws_shape(migraphx::op::broadcast{1, lens}, input); - } - - { - std::vector lens{3, 2, 4, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 3}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}}, - migraphx::op::broadcast{2, lens}, - input); - } - - { - std::vector lens{3, 2, 4, 3}; - migraphx::shape input{migraphx::shape::float_type, {4, 4}}; - throws_shape(migraphx::op::broadcast{2, lens}, input); - } + throws_shape(migraphx::make_op("flatten", {{"axis", 5}}), input); + throws_shape(migraphx::make_op("flatten", {{"axis", -5}}), input); } TEST_CASE(gather) @@ -296,7 +227,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; int axis = 1; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -306,7 +237,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; int axis = -4; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -316,7 +247,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type, {1}}; int axis = -4; expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -326,7 +257,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type}; int axis = -4; expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -336,7 +267,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type}; int axis = 3; expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -346,7 +277,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type}; int axis = 0; expect_shape(migraphx::shape{migraphx::shape::float_type}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -356,7 +287,7 @@ TEST_CASE(gather) migraphx::shape indices{migraphx::shape::int32_type, {1}}; int axis = 0; expect_shape(migraphx::shape{migraphx::shape::float_type, {1}}, - migraphx::op::gather{axis}, + migraphx::make_op("gather", {{"axis", axis}}), input, indices); } @@ -365,720 +296,863 @@ TEST_CASE(gather) migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; int axis = 4; - throws_shape(migraphx::op::gather{axis}, input, indices); + throws_shape(migraphx::make_op("gather", {{"axis", axis}}), input, indices); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; int axis = -5; - throws_shape(migraphx::op::gather{axis}, input, indices); + throws_shape(migraphx::make_op("gather", {{"axis", axis}}), input, indices); } } -template -void test_softmax_variations() +// 3 input arguments +TEST_CASE(gemm) { { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); + migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input); + migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input); + migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - int axis = 4; - throws_shape(T{axis}, input); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}}; + migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } } -TEST_CASE(softmax) { test_softmax_variations(); } - -TEST_CASE(logsoftmax) { test_softmax_variations(); } +TEST_CASE(get_tuple_elem_test) +{ + migraphx::shape s0{migraphx::shape::bool_type, {1, 1}}; + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {5, 6}}; + migraphx::shape s_tuple({s0, s1, s2}); + + expect_shape(s0, migraphx::make_op("get_tuple_elem", {{"index", 0}}), s_tuple); + expect_shape(s1, migraphx::make_op("get_tuple_elem", {{"index", 1}}), s_tuple); + expect_shape(s2, migraphx::make_op("get_tuple_elem", {{"index", 2}}), s_tuple); + throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 3}}), s_tuple); + throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s0); + throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 1}}), s1); + throws_shape(migraphx::make_op("get_tuple_elem", {{"index", 0}}), s2); +} -TEST_CASE(test_argmax) +TEST_CASE(gru) { { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, - migraphx::op::argmax{0}, - input); - } + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; - { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, - migraphx::op::argmax{1}, - input); - } + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, - migraphx::op::argmax{2}, - input); + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, - migraphx::op::argmax{3}, - input); - } + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; - { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - throws_shape(migraphx::op::argmax{4}, input); - } -} + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; -TEST_CASE(test_argmin) -{ - { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, - migraphx::op::argmin{0}, - input); + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, - migraphx::op::argmin{1}, - input); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, - migraphx::op::argmin{2}, - input); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size + 1}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, - migraphx::op::argmin{3}, - input); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - throws_shape(migraphx::op::argmin{4}, input); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } } +TEST_CASE(inconsistent_attr_shape) +{ + migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; + migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}}; + throws_shape(migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), + input, + weights); + throws_shape(migraphx::make_op("deconvolution", + {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), + input, + weights); + throws_shape(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {1}}, + {"stride", {0}}, + {"lengths", {1, 1}}}), + input); +} + template -void test_reduce_ops() +void test_softmax_variations() { { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape( - migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); - } - { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input); } + { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input); } + { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input); } + { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - throws_shape(T{{4}}, input); + int axis = 4; + throws_shape(T{axis}, input); } } +TEST_CASE(logsoftmax) { test_softmax_variations(); } -TEST_CASE(reduce_sum) { test_reduce_ops(); } -TEST_CASE(reduce_mean) { test_reduce_ops(); } - -// 2 inputs arguments -TEST_CASE(matmul) +TEST_CASE(lstm) { { - migraphx::shape s_m1{migraphx::shape::float_type, {5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); - } + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; - { - migraphx::shape s_m1{migraphx::shape::float_type, {5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); - } + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; - { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}}; + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + expect_shape( - migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::op::dot{}, s_m1, s_m2); + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); - } + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; - { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); - } + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - { - migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}}, - migraphx::op::dot{}, - s_m1, - s_m2); + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}}, - migraphx::op::dot{}, - s_m1, - s_m2); - } + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; - { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - expect_shape( - migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size + 1}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}}; - expect_shape( - migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::op::dot{}, s_m1, s_m2); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}}, - migraphx::op::dot{}, - s_m1, - s_m2); + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + throws_shape( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } +} +// 2 inputs arguments +TEST_CASE(matmul) +{ { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); + migraphx::shape s_m1{migraphx::shape::float_type, {5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2); + migraphx::shape s_m1{migraphx::shape::float_type, {5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } -} -// 3 input arguments -TEST_CASE(gemm) -{ { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {1}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {8}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {4}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}}, - migraphx::op::dot{}, + migraphx::make_op("dot"), s_m1, - s_m2, - s_m3); + s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}}, - migraphx::op::dot{}, + migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}}, + migraphx::make_op("dot"), s_m1, - s_m2, - s_m3); + s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}}, + migraphx::make_op("dot"), + s_m1, + s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } { - migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; - migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; - migraphx::shape s_m3{migraphx::shape::float_type}; - throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3); + migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}}; + migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}}; + throws_shape(migraphx::make_op("dot"), s_m1, s_m2); } } -// quant_dot -TEST_CASE(quant_dot_2args) +TEST_CASE(multibroadcast) { { - migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; - expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}}, - migraphx::op::quant_dot{}, - s_m1, - s_m2); + std::vector lens{4, 2, 5, 3}; + migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); } - { - migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}}; - expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}}, - migraphx::op::quant_dot{1, 0}, - s_m1, - s_m2); + std::vector lens{4, 2, 5, 3}; + migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); } - { - migraphx::shape s_m1{migraphx::shape::int8_type, {2, 3}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {3, 8}}; - throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2); + std::vector lens{4, 2, 5, 3}; + migraphx::shape input{migraphx::shape::float_type, {5, 1}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); } - { - migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}}; - throws_shape(migraphx::op::quant_dot{}, s_m1, s_m2); + std::vector lens{4, 2, 5, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); } -} - -TEST_CASE(quant_dot_3args) -{ { - migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; - migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}}; - expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}}, - migraphx::op::quant_dot{}, - s_m1, - s_m2, - s_m3); + std::vector lens{4, 2, 5, 3}; + migraphx::shape input{migraphx::shape::float_type, {3}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); } - { - migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; - migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; - migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}}; - throws_shape(migraphx::op::quant_dot{1, 2}, s_m1, s_m2, s_m3); + std::vector lens{4, 4, 1, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); + } + { + std::vector lens{4, 1, 1, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}}, + migraphx::make_op("multibroadcast", {{"out_lens", lens}}), + input); + } + { + std::vector lens{4, 1, 3}; + migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; + throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); + } + { + std::vector lens{4, 1, 3}; + migraphx::shape input{migraphx::shape::float_type, {}}; + throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); + } + { + std::vector lens{2, 3, 4, 5}; + migraphx::shape input{migraphx::shape::float_type, {3, 4}}; + throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); + } + { + std::vector lens{2, 3, 4, 5}; + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; + throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input); } } -TEST_CASE(rnn) +TEST_CASE(multinomial) { - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape s{migraphx::shape::float_type, {2, 5}}; + int dtype = 0; - expect_shape( - migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::rnn{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); - } + throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s); +} - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; +TEST_CASE(pooling_shape) +{ + migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}}; + migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; + throws_shape(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {1}}, + {"stride", {0}}, + {"lengths", {1}}}), + input); + expect_shape(output, + migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0}}, + {"stride", {3, 3}}, + {"lengths", {1, 1}}}), + input); - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}}; + expect_shape(output1, + migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0}}, + {"stride", {3, 3}}, + {"lengths", {1, 1}}, + {"ceil_mode", true}}), + input); +} - expect_shape( - migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::rnn{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); +TEST_CASE(prefix_scan_sum) +{ + { + migraphx::shape s{migraphx::shape::float_type, {1, 2, 3}}; + throws_shape( + migraphx::make_op("prefix_scan_sum", {{"axis", 3}, {"exclusive", 0}, {"reverse", 0}}), + s); } { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - - expect_shape(migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape s{migraphx::shape::float_type, {1, 2}}; + throws_shape( + migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", 0}, {"reverse", 0}}), + s); } +} - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; +TEST_CASE(quant_convolution_shape) +{ + migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}}; + migraphx::shape input{migraphx::shape::int8_type, {4, 3, 3, 3}}; + migraphx::shape weights{migraphx::shape::int8_type, {4, 3, 3, 3}}; + expect_shape(output, migraphx::make_op("quant_convolution"), input, weights); + throws_shape(migraphx::make_op("quant_convolution"), input); + throws_shape(migraphx::make_op("quant_convolution", + {{"padding", {0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + input, + weights); + throws_shape(migraphx::make_op("quant_convolution", + {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input, + weights); - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape input2{migraphx::shape::int32_type, {3, 3}}; + migraphx::shape weights2{migraphx::shape::float_type, {3, 3}}; + throws_shape(migraphx::make_op("quant_convolution"), input2, weights2); + throws_shape(migraphx::make_op("quant_convolution"), input2, weights); - throws_shape(migraphx::op::rnn{hidden_size + 1, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); - } + migraphx::shape input3{migraphx::shape::int32_type, {4, 3, 3, 3}}; + migraphx::shape weight3{migraphx::shape::float_type, {4, 3, 3, 3}}; + throws_shape(migraphx::make_op("quant_convolution"), input3, weights); + throws_shape(migraphx::make_op("quant_convolution"), input, weight3); + throws_shape(migraphx::make_op("quant_convolution"), input3, weight3); +} +// quant_dot +TEST_CASE(quant_dot_2args) +{ { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - - throws_shape(migraphx::op::rnn{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; + migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}}; + expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}}, + migraphx::make_op("quant_dot"), + s_m1, + s_m2); } { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}}; + migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}}; + expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}}, + migraphx::make_op("quant_dot"), + s_m1, + s_m2); + } - throws_shape( - migraphx::op::rnn{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + { + migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}}; + migraphx::shape s_m2{migraphx::shape::int8_type, {8, 8}}; + throws_shape(migraphx::make_op("quant_dot"), s_m1, s_m2); } } -TEST_CASE(gru) +template +void test_reduce_ops() { { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - expect_shape( - migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::gru{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); } { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; expect_shape( - migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::gru{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); } - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - expect_shape(migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::gru{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); } - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; - - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - - throws_shape(migraphx::op::gru{hidden_size + 1, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); } - { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 1; - float clip = 0.0f; + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); + } + { + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + throws_shape(T{{4}}, input); + } +} - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; +TEST_CASE(reduce_mean) { test_reduce_ops(); } +TEST_CASE(reduce_sum) { test_reduce_ops(); } - throws_shape(migraphx::op::gru{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); +TEST_CASE(reshape_shape) +{ + migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}}; + for(auto&& new_shape : + std::vector>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}}) + { + std::vector lens(new_shape.size()); + std::copy(new_shape.begin(), new_shape.end(), lens.begin()); + migraphx::shape output{migraphx::shape::float_type, lens}; + expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input); } + for(auto&& new_shape : + std::vector>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}}) { - std::size_t batch_size = 2; - std::size_t seq_len = 2; - std::size_t hidden_size = 4; - std::size_t input_size = 3; - std::size_t num_dirct = 2; - float clip = 0.0f; + throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); + } - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector, migraphx::shape>> minus1_tests{ + {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}}, + {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}}, + {{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}}, + {{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}}, + {{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}}, + {{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}}, + {{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}}, + {{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}}, + {{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}}; - throws_shape( - migraphx::op::gru{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + for(auto& it : minus1_tests) + { + expect_shape(it.second, migraphx::make_op("reshape", {{"dims", it.first}}), input); } } -TEST_CASE(lstm) +TEST_CASE(rnn) { { std::size_t batch_size = 2; @@ -1089,19 +1163,26 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; expect_shape( migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::lstm{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), in_shape, w_shape, - r_shape); + r_shape, + b_shape, + ih_shape); } { @@ -1113,18 +1194,21 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; expect_shape( migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::lstm{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), in_shape, w_shape, r_shape, @@ -1141,24 +1225,26 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, - {seq_len, num_dirct, batch_size, hidden_size}}, - migraphx::op::lstm{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + expect_shape( + migraphx::shape{migraphx::shape::float_type, + {seq_len, num_dirct, batch_size, hidden_size}}, + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { @@ -1170,22 +1256,24 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - throws_shape(migraphx::op::lstm{hidden_size + 1, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::forward, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + throws_shape( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size + 1}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { @@ -1197,22 +1285,24 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; - throws_shape(migraphx::op::lstm{hidden_size, - {migraphx::op::tanh{}}, - migraphx::op::rnn_direction::bidirectional, - clip}, - in_shape, - w_shape, - r_shape, - b_shape, - ih_shape); + throws_shape( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + in_shape, + w_shape, + r_shape, + b_shape, + ih_shape); } { @@ -1224,16 +1314,19 @@ TEST_CASE(lstm) float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, - {num_dirct, 3 * hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; throws_shape( - migraphx::op::lstm{ - hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), in_shape, w_shape, r_shape, @@ -1242,4 +1335,333 @@ TEST_CASE(lstm) } } +TEST_CASE(slice_shape) +{ + migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}}; + expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), + input); + expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, + migraphx::make_op( + "slice", {{"axes", {0, 1, 2}}, {"starts", {0, 0, 1}}, {"ends", {2, 2, 3}}}), + input); + expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}), + input); +} + +TEST_CASE(softmax) { test_softmax_variations(); } + +TEST_CASE(test_argmax) +{ + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, + migraphx::make_op("argmax", {{"axis", 0}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, + migraphx::make_op("argmax", {{"axis", 1}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, + migraphx::make_op("argmax", {{"axis", 2}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, + migraphx::make_op("argmax", {{"axis", 3}}), + input); + } + + { + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input); + } +} + +TEST_CASE(test_argmin) +{ + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, + migraphx::make_op("argmin", {{"axis", 0}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, + migraphx::make_op("argmin", {{"axis", 1}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, + migraphx::make_op("argmin", {{"axis", 2}}), + input); + } + + { + migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; + expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, + migraphx::make_op("argmin", {{"axis", 3}}), + input); + } + + { + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + throws_shape(migraphx::make_op("argmin", {{"axis", 4}}), input); + } +} + +TEST_CASE(test_scalar) +{ + migraphx::shape s1{migraphx::shape::float_type, {1}, {1}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 5}, {0, 0, 0, 0}}; + expect_shape(s2, migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), s1); +} + +TEST_CASE(test_scalar_nelemnts) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; + throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input); +} + +TEST_CASE(test_scatternd) +{ + { + // k > r + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {4, 2}}; + migraphx::shape us{dtype, {4}}; + throws_shape(migraphx::make_op("scatternd_none"), ds, is, us); + } + + { + // update.lens != indices.lens[0:q-1] ++ data.lens[k:r-1] + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {2, 2}}; + throws_shape(migraphx::make_op("scatternd_none"), ds, is, us); + } +} + +TEST_CASE(test_squeeze) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1); +} + +TEST_CASE(test_squeeze_all) +{ + migraphx::shape s1{migraphx::shape::float_type, {1}}; + migraphx::shape s2{migraphx::shape::float_type}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); +} + +TEST_CASE(test_squeeze_transpose) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 4}, {4, 1}}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_squeeze_multibroadcast) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_squeeze_slice) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 6, 1}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_squeeze_negative_axis) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; + expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1); +} + +TEST_CASE(test_squeeze_wrong_axis) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); +} + +TEST_CASE(test_unsqueeze) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_unsqueeze_negative_axis) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); +} + +TEST_CASE(test_unsqueeze_scalar) +{ + migraphx::shape s1{migraphx::shape::float_type, {1}, {0}}; + migraphx::shape s2{migraphx::shape::float_type, {1}, {1}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1); +} + +TEST_CASE(test_unsqueeze_scalar_tensor1) +{ + migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}}; + throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); +} + +TEST_CASE(test_unsqueeze_scalar_tensor2) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}}; + throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s); +} + +TEST_CASE(test_unsqueeze_transpose) +{ + migraphx::shape s1{migraphx::shape::float_type, {4, 4, 3}, {12, 1, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 4, 1, 3}, {12, 1, 1, 4}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_unsqueeze_multibroadcast) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {0, 1, 0}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_unsqueeze_slice) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}, {108, 36, 1}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 1, 4}, {108, 36, 36, 1}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); +} + +TEST_CASE(test_unsqueeze_axis_zero) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0}}}), s1); +} + +TEST_CASE(test_unsqueeze_axis_last) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 3, 4, 1}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1); +} + +TEST_CASE(test_unsqueeze_multiple_axes_1) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 2, 3, 4, 1}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, -1}}}), s1); +} + +TEST_CASE(test_unsqueeze_multiple_axes_2) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 3, 4}}; + expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), s1); +} + +TEST_CASE(transpose_shape) +{ + migraphx::shape input{migraphx::shape::float_type, {2, 2}}; + migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}}; + expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input); + expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input); + throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input); +} + +TEST_CASE(step_test) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4}}; + { + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2}, {8, 8, 3}}; + expect_shape(s2, migraphx::make_op("step", {{"axes", {1, 2}}, {"steps", {2, 3}}}), s1); + } + + { + migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}}; + throws_shape(migraphx::make_op("step", {{"axes", {1, 2}}, {"steps", {1}}}), s1); + } + + { + migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}}; + throws_shape(migraphx::make_op("step", {{"axes", {2, 3}}, {"steps", {2, 3}}}), s1); + } +} + +TEST_CASE(unary_scalar_input) +{ + migraphx::shape ss{migraphx::shape::half_type}; + expect_shape(ss, migraphx::make_op("sin"), ss); + + migraphx::shape s{migraphx::shape::float_type, {1}}; + expect_shape(s, migraphx::make_op("sin"), s); +} + +TEST_CASE(unary_broadcast_input) +{ + migraphx::shape ss{migraphx::shape::half_type, {2, 3}, {1, 0}}; + migraphx::shape s{migraphx::shape::half_type, {2, 3}}; + expect_shape(s, migraphx::make_op("sin"), ss); +} + +TEST_CASE(where_broadcast_input) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {3, 0}}; + migraphx::shape s2{migraphx::shape::float_type, {2, 2}}; + migraphx::shape s3{migraphx::shape::bool_type, {2, 2}}; + expect_shape(s2, migraphx::make_op("where"), s3, s1, s2); +} + +TEST_CASE(roialign_test) +{ + migraphx::shape sx{migraphx::shape::float_type, {3, 4, 5, 6}}; + migraphx::shape srois{migraphx::shape::float_type, {2, 4}}; + migraphx::shape sbi{migraphx::shape::int64_type, {2}}; + migraphx::shape sout{migraphx::shape::float_type, {2, 4, 1, 1}}; + + expect_shape(sout, migraphx::make_op("roialign"), sx, srois, sbi); + + migraphx::shape sbi1{migraphx::shape::int64_type, {2, 3}}; + throws_shape(migraphx::make_op("roialign"), sx, srois, sbi1); + + migraphx::shape sbi2{migraphx::shape::int64_type, {3}}; + throws_shape(migraphx::make_op("roialign"), sx, srois, sbi2); + + migraphx::shape srois1{migraphx::shape::float_type, {2, 4, 3}}; + throws_shape(migraphx::make_op("roialign"), sx, srois1, sbi); + + migraphx::shape srois2{migraphx::shape::float_type, {2, 3}}; + throws_shape(migraphx::make_op("roialign"), sx, srois2, sbi); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/operation.cpp b/test/operation.cpp index 556211282d05a023ff71127ac00a46734cab7cd6..fafa50710616f95faa46dac3b60ff0411aa689c3 100644 --- a/test/operation.cpp +++ b/test/operation.cpp @@ -46,6 +46,33 @@ struct simple_operation_no_print } }; +struct compilable_op +{ + std::string name() const { return "compilable"; } + migraphx::argument + compute(migraphx::context&, const migraphx::shape&, std::vector args) const + { + if(args.empty()) + return {}; + return args.front(); + } + + migraphx::shape compute_shape(std::vector inputs) const + { + if(inputs.empty()) + return {}; + return inputs.front(); + } + + int output_alias(const std::vector&) const { return 0; } + + migraphx::value + compile(migraphx::context&, const migraphx::shape&, const std::vector&) + { + return {{"compiled", true}}; + } +}; + TEST_CASE(operation_copy_test) { simple_operation s{}; @@ -57,6 +84,15 @@ TEST_CASE(operation_copy_test) EXPECT(op2 == op1); } +TEST_CASE(operation_copy_assign_test) +{ + simple_operation s{}; + migraphx::operation op; + op = s; + // cppcheck-suppress duplicateExpression + EXPECT(s == op); +} + TEST_CASE(operation_equal_test) { simple_operation s{}; @@ -164,4 +200,51 @@ TEST_CASE(check_run_finalize_throw) EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); })); } +TEST_CASE(check_to_value1) +{ + migraphx::operation op = simple_operation{}; + auto v = op.to_value(); + EXPECT(v == migraphx::value{{"data", 1}}); +} + +TEST_CASE(check_to_value2) +{ + migraphx::operation op = simple_operation{}; + auto v = migraphx::to_value(op); + EXPECT(v == migraphx::value{{"name", "simple"}, {"operator", {{"data", 1}}}}); +} + +TEST_CASE(check_from_value1) +{ + migraphx::operation op1 = simple_operation{}; + migraphx::operation op2 = simple_operation{3}; + + op1.from_value({{"data", 3}}); + EXPECT(op1 == op2); +} + +TEST_CASE(check_from_value2) +{ + migraphx::operation op1 = migraphx::from_value({{"data", 3}}); + migraphx::operation op2 = simple_operation{3}; + + EXPECT(op1 == op2); +} + +TEST_CASE(compile) +{ + migraphx::operation op = compilable_op{}; + migraphx::context ctx{}; + auto v = op.compile(ctx, {}, {}); + EXPECT(v.at("compiled").to() == true); +} + +TEST_CASE(compile_non_compilable) +{ + migraphx::operation op = simple_operation{}; + migraphx::context ctx{}; + auto v = op.compile(ctx, {}, {}); + EXPECT(v.empty()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/operators.cpp b/test/operators.cpp new file mode 100755 index 0000000000000000000000000000000000000000..90af2577b99d76267067ee5d4f3b544c75d87e35 --- /dev/null +++ b/test/operators.cpp @@ -0,0 +1,151 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "test.hpp" + +TEST_CASE(load_op) +{ + for(const auto& name : migraphx::get_operators()) + { + auto op = migraphx::load_op(name); + CHECK(op.name() == name); + } +} + +TEST_CASE(make_op) +{ + for(const auto& name : migraphx::get_operators()) + { + auto op = migraphx::load_op(name); + CHECK(op == migraphx::make_op(name)); + } +} + +TEST_CASE(save_op) +{ + for(const auto& name : migraphx::get_operators()) + { + auto op1 = migraphx::load_op(name); + auto v = migraphx::to_value(op1); + auto op2 = migraphx::from_value(v); + CHECK(op1 == op2); + } +} + +TEST_CASE(make_op_from_value1) +{ + migraphx::operation x = migraphx::make_op( + "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}); + migraphx::operation y = migraphx::make_op( + "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}}); + EXPECT(x == y); +} + +TEST_CASE(make_op_from_value2) +{ + migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}}); + migraphx::operation y = migraphx::make_op("convolution", {{"padding", {1, 1}}}); + EXPECT(x == y); +} + +TEST_CASE(make_rnn_op_from_value) +{ + migraphx::op::rnn_direction dirct = migraphx::op::rnn_direction::reverse; + migraphx::operation x = migraphx::make_op( + "rnn_var_sl_shift_output", {{"output_name", "hidden_states"}, {"direction", dirct}}); + migraphx::operation y = migraphx::make_op( + "rnn_var_sl_shift_output", + {{"output_name", "hidden_states"}, {"direction", migraphx::to_value(dirct)}}); + EXPECT(x == y); +} + +TEST_CASE(make_op_invalid_key) +{ + EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); })); +} + +TEST_CASE(load_offset) +{ + migraphx::shape s{migraphx::shape::float_type, {4}}; + migraphx::shape bs{migraphx::shape::int8_type, {32}}; + auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}}); + EXPECT(op.compute_shape({bs}) == s); + + migraphx::argument a{bs}; + EXPECT(op.compute(bs, {a}).data() == a.data() + 4); +} + +TEST_CASE(load_out_of_bounds) +{ + migraphx::shape s{migraphx::shape::float_type, {4}}; + migraphx::shape bs{migraphx::shape::int8_type, {16}}; + auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}}); + + migraphx::argument a{bs}; + EXPECT(test::throws([&] { op.compute(bs, {a}); })); +} + +TEST_CASE(load_tuple) +{ + migraphx::shape s{{migraphx::shape{migraphx::shape::int8_type, {3}}, + migraphx::shape{migraphx::shape::float_type, {4}}}}; + migraphx::shape bs{migraphx::shape::int8_type, {32}}; + auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}}); + EXPECT(op.compute_shape({bs}) == s); + + migraphx::argument a{bs}; + auto r = op.compute(bs, {a}); + EXPECT(r.get_sub_objects().size() == 2); + auto* start = a.data() + 4; + EXPECT(r.get_sub_objects()[0].data() == start + 16); + EXPECT(r.get_sub_objects()[1].data() == start); +} + +TEST_CASE(ops) +{ + auto names = migraphx::get_operators(); + EXPECT(names.size() > 1); +} + +TEST_CASE(rnn) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1}}; + std::vector data1(2, 2.0f); + std::vector data2(2, 3.0f); + migraphx::argument a1(s, data1.data()); + migraphx::argument a2(s, data2.data()); + + auto op = migraphx::make_op("rnn"); + + EXPECT(test::throws([&] { op.compute(s, {a1, a2}); })); +} + +TEST_CASE(if_op) +{ + migraphx::shape s{migraphx::shape::bool_type, {1}}; + std::vector data = {1}; + migraphx::argument cond(s, data.data()); + migraphx::shape sd{migraphx::shape::float_type, {2, 1}}; + std::vector data1(2, 2.0f); + std::vector data2(2, 3.0f); + migraphx::argument a1(sd, data1.data()); + migraphx::argument a2(sd, data2.data()); + + migraphx::module m("name"); + auto l = m.add_literal(migraphx::literal(sd, data1)); + m.add_return({l}); + + auto op = migraphx::make_op("add"); + EXPECT(test::throws([&] { op.compute(s, {cond, a1, a2}, {&m, &m}, {}); })); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/output_alias.cpp b/test/output_alias.cpp index 464af92eb3bf999f3553e4acfd7ba664484c4d6d..9189b05b44effc61a693562fc07fe3fffaf5928c 100644 --- a/test/output_alias.cpp +++ b/test/output_alias.cpp @@ -6,8 +6,9 @@ TEST_CASE(simple_alias) { migraphx::program p; - auto l = p.add_literal(1); - auto p1 = p.add_instruction(pass_op{}, l); + auto* mm = p.get_main_module(); + auto l = mm->add_literal(1); + auto p1 = mm->add_instruction(pass_op{}, l); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l}); } @@ -15,10 +16,11 @@ TEST_CASE(simple_alias) TEST_CASE(cascade_alias) { migraphx::program p; - auto l = p.add_literal(1); - auto p1 = p.add_instruction(pass_op{}, l); - auto p2 = p.add_instruction(pass_op{}, p1); - auto p3 = p.add_instruction(pass_op{}, p2); + auto* mm = p.get_main_module(); + auto l = mm->add_literal(1); + auto p1 = mm->add_instruction(pass_op{}, l); + auto p2 = mm->add_instruction(pass_op{}, p1); + auto p3 = mm->add_instruction(pass_op{}, p2); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p2) == l}); @@ -28,9 +30,10 @@ TEST_CASE(cascade_alias) TEST_CASE(no_alias) { migraphx::program p; - auto x = p.add_literal(1); - auto y = p.add_literal(2); - auto sum = p.add_instruction(sum_op{}, x, y); + auto* mm = p.get_main_module(); + auto x = mm->add_literal(1); + auto y = mm->add_literal(2); + auto sum = mm->add_instruction(sum_op{}, x, y); EXPECT(bool{migraphx::instruction::get_output_alias(sum) == sum}); } diff --git a/test/perf_report.cpp b/test/perf_report.cpp index c16130523999ac0ec875c3ec0a2d595802f6ce4f..a82c80f4f5f855ff1ef349b6610d320371dcf37a 100644 --- a/test/perf_report.cpp +++ b/test/perf_report.cpp @@ -1,22 +1,25 @@ #include -#include -#include +#include #include +#include + #include "test.hpp" TEST_CASE(perf_report) { migraphx::program p; + auto* mm = p.get_main_module(); std::stringstream ss; - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(migraphx::op::add{}, one, two); - p.compile(migraphx::cpu::target{}); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(migraphx::make_op("add"), one, two); + p.compile(migraphx::ref::target{}); p.perf_report(ss, 2, {}); std::string output = ss.str(); EXPECT(migraphx::contains(output, "Summary:")); + EXPECT(migraphx::contains(output, "Batch size:")); EXPECT(migraphx::contains(output, "Rate:")); EXPECT(migraphx::contains(output, "Total time:")); EXPECT(migraphx::contains(output, "Total instructions time:")); diff --git a/test/print_graph_test.cpp b/test/print_graph_test.cpp index 6aef08d737a66f8a0dae2276fefc950a39c50280..fd1217bfffd1d206aadcfd5a40031ffb6640b741 100644 --- a/test/print_graph_test.cpp +++ b/test/print_graph_test.cpp @@ -8,12 +8,14 @@ migraphx::program create_program() { migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::int64_type}); - auto y = p.add_parameter("y", {migraphx::shape::int64_type}); + auto* mm = p.get_main_module(); - auto sum = p.add_instruction(sum_op{}, x, y); - auto one = p.add_literal(1); - p.add_instruction(sum_op{}, sum, one); + auto x = mm->add_parameter("x", {migraphx::shape::int64_type}); + auto y = mm->add_parameter("y", {migraphx::shape::int64_type}); + + auto sum = mm->add_instruction(sum_op{}, x, y); + auto one = mm->add_literal(1); + mm->add_instruction(sum_op{}, sum, one); return p; } @@ -21,20 +23,23 @@ migraphx::program create_program() TEST_CASE(basic_graph_test) { migraphx::program p = create_program(); + std::stringstream ss; p.print_graph(ss); std::string test = ss.str(); + std::cout << "test = " << test << std::endl; + EXPECT(migraphx::contains(test, "digraph")); EXPECT(migraphx::contains(test, "rankdir=LR")); - EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]")); + EXPECT(migraphx::contains(test, "\"main:@0\"[label=\"@literal\"]")); EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]")); EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]")); - EXPECT(migraphx::contains(test, "\"@1\"[label=\"sum\"]")); - EXPECT(migraphx::contains(test, "\"@2\"[label=\"sum\"]")); - EXPECT(migraphx::contains(test, "\"x\" -> \"@1\"")); - EXPECT(migraphx::contains(test, "\"y\" -> \"@1\"")); - EXPECT(migraphx::contains(test, "\"@1\" -> \"@2\"")); - EXPECT(migraphx::contains(test, "\"@0\" -> \"@2\"")); + EXPECT(migraphx::contains(test, "\"main:@3\"[label=\"sum\"]")); + EXPECT(migraphx::contains(test, "\"main:@4\"[label=\"sum\"]")); + EXPECT(migraphx::contains(test, "\"x\" -> \"main:@3\"")); + EXPECT(migraphx::contains(test, "\"y\" -> \"main:@3\"")); + EXPECT(migraphx::contains(test, "\"main:@3\" -> \"main:@4\"")); + EXPECT(migraphx::contains(test, "\"main:@0\" -> \"main:@4\"")); EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]")); } diff --git a/test/program_test.cpp b/test/program_test.cpp index 79f31ec81ba6429fb24a4a3023a930cb4769a283..6a7fc6600d7890fcaeb9df0f605f74de7f71e2a8 100644 --- a/test/program_test.cpp +++ b/test/program_test.cpp @@ -2,24 +2,25 @@ #include #include #include -#include -#include -#include -#include +#include #include +#include #include "test.hpp" +#include + #include migraphx::program create_program() { migraphx::program p; + auto* mm = p.get_main_module(); - auto x = p.add_parameter("x", {migraphx::shape::int64_type}); - auto y = p.add_parameter("y", {migraphx::shape::int64_type}); + auto x = mm->add_parameter("x", {migraphx::shape::int64_type}); + auto y = mm->add_parameter("y", {migraphx::shape::int64_type}); - auto sum = p.add_instruction(sum_op{}, x, y); - auto one = p.add_literal(1); - p.add_instruction(sum_op{}, sum, one); + auto sum = mm->add_instruction(sum_op{}, x, y); + auto one = mm->add_literal(1); + mm->add_instruction(sum_op{}, sum, one); return p; } @@ -28,6 +29,8 @@ TEST_CASE(program_equality) { migraphx::program x = create_program(); migraphx::program y = create_program(); + + EXPECT(x.size() == 1); EXPECT(x == y); } @@ -56,18 +59,54 @@ TEST_CASE(program_default_copy_construct) EXPECT(x == y); } +TEST_CASE(program_print) +{ + migraphx::program p = create_program(); + auto* mm = p.get_main_module(); + auto in1 = mm->end(); + + // print end instruction + p.debug_print(in1); + + // print instruction not in the program + auto p2 = p; + auto* mm2 = p2.get_main_module(); + auto in2 = mm2->begin(); + p.debug_print(in2); + + // print last instruction + auto in3 = std::prev(in1); + p.debug_print(in3); +} + +TEST_CASE(program_annotate) +{ + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + std::stringstream ss1; + p1.annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; }); + + std::stringstream ss2; + p2.annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; }); + + EXPECT(ss1.str() == ss2.str()); +} + TEST_CASE(program_copy) { auto create_program_1 = [] { migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}}; std::vector data(3 * 4 * 5); std::iota(data.begin(), data.end(), 1.0f); - auto l2 = p.add_literal(migraphx::literal(s, data)); - auto p1 = p.add_parameter("x", s); - auto po = p.add_outline(s); - auto sum = p.add_instruction(migraphx::op::add{}, l2, po); - p.add_instruction(migraphx::op::mul{}, sum, p1); + auto l2 = mm->add_literal(migraphx::literal(s, data)); + auto p1 = mm->add_parameter("x", s); + auto po = mm->add_outline(s); + auto sum = mm->add_instruction(migraphx::make_op("add"), l2, po); + mm->add_instruction(migraphx::make_op("mul"), sum, p1); return p; }; @@ -77,11 +116,13 @@ TEST_CASE(program_copy) migraphx::program p2{}; p2 = p1; - p2.compile(migraphx::cpu::target{}); + p2.compile(migraphx::ref::target{}); EXPECT(p1 != p2); - p1.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); EXPECT(p1 == p2); + + EXPECT(p1.get_parameter_names() == p2.get_parameter_names()); } { @@ -89,7 +130,7 @@ TEST_CASE(program_copy) auto p2(p1); EXPECT(p1 == p2); - p1.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); EXPECT(p1 != p2); p2 = p1; @@ -104,28 +145,30 @@ TEST_CASE(program_copy) p2 = p1; EXPECT(p1 == p2); - p1.compile(migraphx::cpu::target{}); - p2.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); EXPECT(p1 == p2); } { migraphx::program p1; + auto* mm1 = p1.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; migraphx::shape s3{migraphx::shape::float_type, {2, 6}}; - auto para1 = p1.add_parameter("m1", s1); - auto para2 = p1.add_parameter("m2", s2); - auto para3 = p1.add_parameter("m3", s3); - p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3); - + auto para1 = mm1->add_parameter("m1", s1); + auto para2 = mm1->add_parameter("m2", s2); + auto para3 = mm1->add_parameter("m3", s3); + migraphx::add_apply_alpha_beta( + *mm1, {para1, para2, para3}, migraphx::make_op("dot"), 0.31f, 0.28f); migraphx::program p2{}; p2 = p1; EXPECT(p2 == p1); - p1.compile(migraphx::cpu::target{}); - p2.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); EXPECT(p2 == p1); } } diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index 8a9678cdfb8199d55ad7ea50f5d2b60e4a786215..30246e447470b7e638ad58c87e629ebabdcef774 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -1,111 +1,143 @@ #include #include #include -#include -#include -#include #include +#include + #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); } TEST_CASE(const_add) { - migraphx::program p1; - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum = p1.add_instruction(migraphx::op::add{}, one, two); - p1.add_instruction(pass_op{}, sum); - run_pass(p1); - - migraphx::program p2; - auto total = p2.add_literal(3); - p2.add_instruction(pass_op{}, total); - EXPECT(p1 == p2); + migraphx::module m1; + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); + m1.add_instruction(pass_op{}, sum); + run_pass(m1); + + migraphx::module m2; + auto total = m2.add_literal(3); + m2.add_instruction(pass_op{}, total); + EXPECT(m1 == m2); } TEST_CASE(const_add_parameter) { - migraphx::program p1; - auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}}); - auto two = p1.add_literal(2); - auto sum = p1.add_instruction(migraphx::op::add{}, one, two); - p1.add_instruction(pass_op{}, sum); - run_pass(p1); - - migraphx::program p2; - auto total = p2.add_literal(3); - p2.add_instruction(pass_op{}, total); - EXPECT(p1 != p2); + migraphx::module m1; + auto one = m1.add_parameter("one", {migraphx::shape::int32_type, {1}}); + auto two = m1.add_literal(2); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); + m1.add_instruction(pass_op{}, sum); + run_pass(m1); + + migraphx::module m2; + auto total = m2.add_literal(3); + m2.add_instruction(pass_op{}, total); + EXPECT(m1 != m2); } TEST_CASE(const_multiadd) { - migraphx::program p1; - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two); - p1.add_instruction(pass_op{}, sum2); - run_pass(p1); - - migraphx::program p2; - auto total = p2.add_literal(5); - p2.add_instruction(pass_op{}, total); - EXPECT(p1 == p2); + migraphx::module m1; + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two); + m1.add_instruction(pass_op{}, sum2); + run_pass(m1); + + migraphx::module m2; + auto total = m2.add_literal(5); + m2.add_instruction(pass_op{}, total); + EXPECT(m1 == m2); } TEST_CASE(const_add_mul) { - migraphx::program p1; - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto mul = p1.add_instruction(migraphx::op::mul{}, two, two); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul); - auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two); - p1.add_instruction(pass_op{}, sum2); - run_pass(p1); - - migraphx::program p2; - auto total = p2.add_literal(7); - p2.add_instruction(pass_op{}, total); - EXPECT(p1 == p2); + migraphx::module m1; + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto mul = m1.add_instruction(migraphx::make_op("mul"), two, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, mul); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, two); + m1.add_instruction(pass_op{}, sum2); + run_pass(m1); + + migraphx::module m2; + auto total = m2.add_literal(7); + m2.add_instruction(pass_op{}, total); + EXPECT(m1 == m2); } TEST_CASE(const_add_scalar) { - migraphx::program p1; - auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1)); - auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2)); - auto sum = p1.add_instruction(migraphx::op::add{}, one, two); - p1.add_instruction(pass_op{}, sum); - run_pass(p1); - - migraphx::program p2; + migraphx::module m1; + auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), + m1.add_literal(1)); + auto two = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), + m1.add_literal(2)); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, two); + m1.add_instruction(pass_op{}, sum); + run_pass(m1); + + migraphx::module m2; auto total = - p2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}}); - p2.add_instruction(pass_op{}, total); - EXPECT(p1 == p2); + m2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}}); + m2.add_instruction(pass_op{}, total); + EXPECT(m1 == m2); } TEST_CASE(const_scalar) { - migraphx::program p1; + migraphx::module m1; { - auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1)); - p1.add_instruction(pass_op{}, one); + auto one = m1.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), + m1.add_literal(1)); + m1.add_instruction(pass_op{}, one); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto one = p2.add_instruction(migraphx::op::scalar{{2, 2}}, p2.add_literal(1)); - p2.add_instruction(pass_op{}, one); + auto one = m2.add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 2}}}), + m2.add_literal(1)); + m2.add_instruction(pass_op{}, one); + } + EXPECT(m1 == m2); +} + +TEST_CASE(const_dot) +{ + migraphx::module m1; + { + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector vec = {1.0f, 2.0f, 1.0f, 2.0f}; + + auto l = m1.add_literal(migraphx::literal(s, vec)); + auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l); + auto x = m1.add_parameter("x", s); + auto r = m1.add_instruction(migraphx::make_op("add"), dl, x); + m1.add_return({r}); + } + + run_pass(m1); + + migraphx::module m2; + { + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector vec = {3.0f, 6.0f, 3.0f, 6.0f}; + + auto x = m2.add_parameter("x", s); + auto l = m2.add_literal(migraphx::literal(s, vec)); + auto r = m2.add_instruction(migraphx::make_op("add"), l, x); + m2.add_return({r}); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/py/CMakeLists.txt b/test/py/CMakeLists.txt old mode 100644 new mode 100755 index f8d09284891d1a641373fc10a1a41623ed7db9bc..8057a007bf325b27ebf265ab438967162095c17d --- a/test/py/CMakeLists.txt +++ b/test/py/CMakeLists.txt @@ -1,25 +1,36 @@ -find_package(PythonInterp) +include(PythonModules) function(add_py_test NAME SCRIPT) - set (ENV_COMMAND ${CMAKE_COMMAND} -E env - "PYTHONPATH=$" - "PYTHONMALLOC=debug" - "MALLOC_CHECK_=3" - ) - add_test( - NAME test_py_${NAME} - COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}) - add_custom_target(test_py_${NAME} - COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN} - COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}") + foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) + set (ENV_COMMAND ${CMAKE_COMMAND} -E env + "PYTHONPATH=$" + "PYTHONMALLOC=debug" + "MALLOC_CHECK_=3" + ) + set(PYTHON_EXECUTABLE ${PYTHON_${PYTHON_VERSION}_EXECUTABLE}) + add_test( + NAME test_py_${PYTHON_VERSION}_${NAME} + COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}) + add_custom_target(test_py_${PYTHON_VERSION}_${NAME} + COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN} + COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}") + + endforeach() endfunction() -add_dependencies(tests migraphx_py) -add_dependencies(check migraphx_py) +foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) + add_dependencies(tests migraphx_py_${PYTHON_VERSION}) + add_dependencies(check migraphx_py_${PYTHON_VERSION}) +endforeach() -add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) if(MIGRAPHX_ENABLE_GPU) add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(backend onnx_backend_test.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) endif() diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py new file mode 100755 index 0000000000000000000000000000000000000000..527cbf408eeac302d8b05df64026b8b52748db93 --- /dev/null +++ b/test/py/onnx_backend_test.py @@ -0,0 +1,340 @@ +import sys +if sys.version_info < (3, 0): + sys.exit() + +import argparse +import os +import unittest +import onnx +import onnx.backend.test +import numpy as np +from onnx_migraphx.backend import MIGraphXBackend as c2 +from packaging import version + +pytest_plugins = 'onnx.backend.test.report', + + +class MIGraphXBackendTest(onnx.backend.test.BackendTest): + def __init__(self, backend, parent_module=None): + super(MIGraphXBackendTest, self).__init__(backend, parent_module) + + @classmethod + def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): + prog_string = c2.get_program() + np.testing.assert_equal(len(ref_outputs), + len(outputs), + err_msg=prog_string) + for i in range(len(outputs)): + np.testing.assert_equal(ref_outputs[i].dtype, + outputs[i].dtype, + err_msg=prog_string) + if ref_outputs[i].dtype == np.object: + np.testing.assert_array_equal(ref_outputs[i], + outputs[i], + err_msg=prog_string) + else: + np.testing.assert_allclose(ref_outputs[i], + outputs[i], + rtol=1e-3, + atol=1e-5, + err_msg=prog_string) + + +def disabled_tests_onnx_1_7_0(backend_test): + backend_test.exclude(r'test_logsoftmax_axis_0_cpu') + backend_test.exclude(r'test_logsoftmax_axis_1_cpu') + backend_test.exclude(r'test_logsoftmax_default_axis_cpu') + backend_test.exclude(r'test_softmax_axis_0_cpu') + backend_test.exclude(r'test_softmax_axis_1_cpu') + backend_test.exclude(r'test_softmax_default_axis_cpu') + + +def disabled_tests_onnx_1_8_1(backend_test): + backend_test.exclude(r'test_if_seq_cpu') + backend_test.exclude(r'test_if_seq_cpu') + backend_test.exclude(r'test_reduce_sum_default_axes_keepdims_example_cpu') + backend_test.exclude(r'test_reduce_sum_default_axes_keepdims_random_cpu') + backend_test.exclude(r'test_reduce_sum_do_not_keepdims_example_cpu') + backend_test.exclude(r'test_reduce_sum_do_not_keepdims_random_cpu') + backend_test.exclude(r'test_reduce_sum_empty_axes_input_noop_example_cpu') + backend_test.exclude(r'test_reduce_sum_empty_axes_input_noop_random_cpu') + backend_test.exclude(r'test_reduce_sum_keepdims_example_cpu') + backend_test.exclude(r'test_reduce_sum_keepdims_random_cpu') + backend_test.exclude(r'test_reduce_sum_negative_axes_keepdims_example_cpu') + backend_test.exclude(r'test_reduce_sum_negative_axes_keepdims_random_cpu') + backend_test.exclude(r'test_unsqueeze_axis_0_cpu') + backend_test.exclude(r'test_unsqueeze_axis_1_cpu') + backend_test.exclude(r'test_unsqueeze_axis_2_cpu') + backend_test.exclude(r'test_unsqueeze_negative_axes_cpu') + backend_test.exclude(r'test_unsqueeze_three_axes_cpu') + backend_test.exclude(r'test_unsqueeze_two_axes_cpu') + backend_test.exclude(r'test_unsqueeze_unsorted_axes_cpu') + + +def create_backend_test(testname=None, target_device=None): + if target_device is not None: + c2.set_device(target_device) + backend_test = MIGraphXBackendTest(c2, __name__) + + if testname: + backend_test.include(testname + '.*') + else: + # Include all of the nodes that we support. + # Onnx native node tests + backend_test.include(r'.*test_abs.*') + backend_test.include(r'.*test_acos.*') + backend_test.include(r'.*test_acosh.*') + backend_test.include(r'.*test_add.*') + backend_test.include(r'.*test_and.*') + backend_test.include(r'.*test_argmax.*') + backend_test.include(r'.*test_argmin.*') + backend_test.include(r'.*test_asin.*') + backend_test.include(r'.*test_asinh.*') + backend_test.include(r'.*test_atan.*') + backend_test.include(r'.*test_atanh.*') + backend_test.include(r'.*test_averagepool.*') + backend_test.include(r'.*test_AvgPool.*') + backend_test.include(r'.*test_BatchNorm.*eval.*') + backend_test.include(r'.*test_ceil.*') + backend_test.include(r'.*test_celu.*') + backend_test.include(r'.*test_clip.*') + backend_test.include(r'.*test_concat.*') + backend_test.include(r'.*test_constant.*') + backend_test.include(r'.*test_Conv[1-3]d*') + backend_test.include(r'.*test_cos.*') + backend_test.include(r'.*test_cosh.*') + backend_test.include(r'.*test_depthtospace.*') + backend_test.include(r'.*test_dequantizelinear') + backend_test.include(r'.*test_div.*') + backend_test.include(r'.*test_dropout.*') + backend_test.include(r'.*test_ELU*') + backend_test.include(r'.*test_elu.*') + backend_test.include(r'.*test_equal.*') + backend_test.include(r'.*test_Embedding*') + backend_test.include(r'.*test_exp.*') + backend_test.include(r'.*test_eyelike.*') + backend_test.include(r'.*test_flatten.*') + backend_test.include(r'.*test_floor.*') + backend_test.include(r'.*test_gather.*') + backend_test.include(r'.*test_gemm.*') + backend_test.include(r'.*test_globalaveragepool.*') + backend_test.include(r'.*test_globalmaxpool.*') + backend_test.include(r'.*test_greater.*') + backend_test.include(r'.*test_hardsigmoid.*') + backend_test.include(r'.*test_hardswish.*') + backend_test.include(r'.*test_identity.*') + backend_test.include(r'.*test_if.*') + backend_test.include(r'.*test_isnan.*') + backend_test.include(r'.*test_LeakyReLU*') + backend_test.include(r'.*test_leakyrelu.*') + backend_test.include(r'.*test_less.*') + backend_test.include(r'.*test_Linear.*') + backend_test.include(r'.*test_log.*') + backend_test.include(r'.*test_logsoftmax.*') + backend_test.include(r'.*test_LogSoftmax.*') + backend_test.include(r'.*test_log_softmax.*') + backend_test.include(r'.*test_lrn.*') + backend_test.include(r'.*test_matmul.*') + backend_test.include(r'.*test_max.*') + backend_test.include(r'.*test_MaxPool[1-9]d.*') + backend_test.include(r'.*test_mean.*') + backend_test.include(r'.*test_min.*') + backend_test.include(r'.*test_mul.*') + backend_test.include(r'.*test_multinomial.*') + backend_test.include(r'.*test_Multinomial.*') + backend_test.include(r'.*test_neg.*') + backend_test.include(r'.*test_not.*') + backend_test.include(r'.*test_operator_addmm.*') + backend_test.include(r'.*test_operator_basic.*') + backend_test.include(r'.*test_operator_chunk.*') + backend_test.include(r'.*test_operator_clip.*') + backend_test.include(r'.*test_operator_concat2.*') + backend_test.include(r'.*test_operator_conv_.*') + backend_test.include(r'.*test_operator_exp.*') + backend_test.include(r'.*test_operator_flatten.*') + backend_test.include(r'.*test_operator_index.*') + backend_test.include(r'.*test_operator_max_.*') + backend_test.include(r'.*test_operator_maxpool.*') + backend_test.include(r'.*test_operator_min.*') + backend_test.include(r'.*test_operator_mm.*') + backend_test.include(r'.*test_operator_non_float_params.*') + backend_test.include(r'.*test_operator_params.*') + backend_test.include(r'.*test_operator_permute2.*') + backend_test.include(r'.*test_operator_pow.*') + backend_test.include(r'.*test_operator_reduced_mean_.*') + backend_test.include(r'.*test_operator_reduced_mean_keepdim.*') + backend_test.include(r'.*test_operator_reduced_sum_.*') + backend_test.include(r'.*test_operator_reduced_sum_keepdim.*') + backend_test.include(r'.*test_operator_selu.*') + backend_test.include(r'.*test_operator_sqrt.*') + backend_test.include(r'.*test_operator_symbolic_override.*') + backend_test.include(r'.*test_operator_symbolic_override_nested.*') + backend_test.include(r'.*test_operator_view.*') + backend_test.include(r'.*test_or.*') + backend_test.include(r'.*test_pow.*') + backend_test.include(r'.*test_PoissonNLLLLoss_no_reduce*') + backend_test.include(r'.*test_quantizelinear') + backend_test.include(r'.*test_reciprocal.*') + backend_test.include(r'.*test_reduce.*') + backend_test.include(r'.*test_ReLU*') + backend_test.include(r'.*test_relu.*') + #backend_test.include(r'.*test_reversesequence.*') + backend_test.include(r'.*test_RoiAlign*') + backend_test.include(r'.*test_roialign.*') + backend_test.include(r'.*test_scatter.*') + backend_test.include(r'.*test_Scatter.*') + backend_test.include(r'.*test_selu.*') + backend_test.include(r'.*test_shape.*') + backend_test.include(r'.*test_Sigmoid*') + backend_test.include(r'.*test_sigmoid.*') + backend_test.include(r'.*test_sin.*') + backend_test.include(r'.*test_sinh.*') + backend_test.include(r'.*test_size.*') + backend_test.include(r'.*test_Softmax*') + backend_test.include(r'.*test_softmax.*') + backend_test.include(r'.*test_Softmin*') + backend_test.include(r'.*test_Softplus*') + backend_test.include(r'.*test_softplus.*') + backend_test.include(r'.*test_softsign.*') + backend_test.include(r'.*test_sqrt.*') + backend_test.include(r'.*test_squeeze_cuda') + backend_test.include(r'.*test_sub.*') + backend_test.include(r'.*test_sum.*') + backend_test.include(r'.*test_tan.*') + backend_test.include(r'.*test_Tanh*') + backend_test.include(r'.*test_tanh.*') + backend_test.include(r'.*test_thresholdedrelu.*') + backend_test.include(r'.*test_topk.*') + backend_test.include(r'.*test_Topk.*') + backend_test.include(r'.*test_transpose.*') + backend_test.include(r'.*test_unsqueeze.*') + backend_test.include(r'.*test_where*') + backend_test.include(r'.*test_where.*') + backend_test.include(r'.*test_xor.*') + backend_test.include(r'.*test_ZeroPad2d*') + + # # Onnx native model tests + backend_test.include(r'.*test_bvlc_alexnet.*') + backend_test.include(r'.*test_densenet121.*') + backend_test.include(r'.*test_inception_v1.*') + backend_test.include(r'.*test_inception_v2.*') + backend_test.include(r'.*test_resnet50.*') + backend_test.include(r'.*test_shufflenet.*') + backend_test.include(r'.*test_squeezenet.*') + backend_test.include(r'.*test_vgg19.*') + backend_test.include(r'.*test_zfnet512.*') + + # exclude unenabled ops get pulled in with wildcards + # test_constant_pad gets pulled in with the test_constant* wildcard. Explicitly disable padding tests for now. + # Operator MATMULINTEGER is not supported by TRT + backend_test.exclude(r'.*test_matmulinteger.*') + backend_test.exclude(r'.*test_maxunpool.*') + # Absolute diff failed because + # numpy compares the difference between actual and desired to atol + rtol * abs(desired) + + # failed test cases + backend_test.exclude( + r'test_argmax_keepdims_example_select_last_index_cpu') + backend_test.exclude( + r'test_argmax_negative_axis_keepdims_example_select_last_index_cpu' + ) + backend_test.exclude( + r'test_argmax_no_keepdims_example_select_last_index_cpu') + backend_test.exclude( + r'test_argmin_keepdims_example_select_last_index_cpu') + backend_test.exclude( + r'test_argmin_negative_axis_keepdims_example_select_last_index_cpu' + ) + backend_test.exclude( + r'test_argmin_no_keepdims_example_select_last_index_cpu') + backend_test.exclude(r'test_lrn_cpu') + backend_test.exclude(r'test_lrn_default_cpu') + backend_test.exclude(r'test_maxpool_2d_dilations_cpu') + backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu') + backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu') + backend_test.exclude( + r'test_maxpool_with_argmax_2d_precomputed_pads_cpu') + backend_test.exclude( + r'test_maxpool_with_argmax_2d_precomputed_strides_cpu') + + # error cases + backend_test.exclude(r'test_constant_pad_cpu') + backend_test.exclude(r'test_constantofshape_float_ones_cpu') + backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu') + backend_test.exclude(r'test_constantofshape_int_zeros_cpu') + backend_test.exclude(r'test_expand_dim_changed_cpu') + backend_test.exclude(r'test_expand_dim_unchanged_cpu') + backend_test.exclude(r'test_expand_shape_model1_cpu') + backend_test.exclude(r'test_expand_shape_model2_cpu') + backend_test.exclude(r'test_expand_shape_model3_cpu') + backend_test.exclude(r'test_expand_shape_model4_cpu') + backend_test.exclude(r'test_identity_sequence_cpu') + backend_test.exclude(r'test_maxpool_2d_uint8_cpu') + backend_test.exclude(r'test_negative_log_likelihood_loss_*') + + # all reduce ops have dynamic axes inputs + backend_test.exclude(r'test_softmax_cross_entropy_*') + backend_test.exclude(r'test_Embedding_cpu') + + # real model tests + backend_test.exclude(r'test_inception_v1_cpu') + backend_test.exclude(r'test_resnet50_cpu') + backend_test.exclude(r'test_squeezenet_cpu') + + # additional cases disabled for a specific onnx version + if version.parse(onnx.__version__) <= version.parse("1.7.0"): + disabled_tests_onnx_1_7_0(backend_test) + + if version.parse(onnx.__version__) >= version.parse("1.8.0"): + disabled_tests_onnx_1_8_1(backend_test) + + +# import all test cases at global scope to make +# them visible to python.unittest. + globals().update(backend_test.enable_report().test_cases) + + return backend_test + + +def parse_args(): + parser = argparse.ArgumentParser( + os.path.basename(__file__), + description='Run the ONNX backend tests using MIGraphX.') + + # Add an argument to match a single test name, by adding the name to the 'include' filter. + # Using -k with python unittest (https://docs.python.org/3/library/unittest.html#command-line-options) + # doesn't work as it filters on the test method name (Runner._add_model_test) rather than inidividual + # test case names. + parser.add_argument( + '-t', + '--test-name', + dest='testname', + type=str, + help= + "Only run tests that match this value. Matching is regex based, and '.*' is automatically appended" + ) + parser.add_argument('-d', + '--device', + dest='device', + type=str, + help="Specify the device to run test on") + + # parse just our args. python unittest has its own args and arg parsing, and that runs inside unittest.main() + args, left = parser.parse_known_args() + sys.argv = sys.argv[:1] + left + + if args.device is not None: + print("run on {} device....".format(args.device)) + else: + print("Default GPU device is used ....") + + return args + + +if __name__ == '__main__': + if sys.version_info < (3, 0): + sys.exit() + + args = parse_args() + backend_test = create_backend_test(args.testname, args.device) + unittest.main() diff --git a/test/py/test_array.py b/test/py/test_array.py index 31c064d20f7e9855f59c82b0f429159c27cc113a..13aec7e935458b7cbd5d71560b23ed2233d46b42 100644 --- a/test/py/test_array.py +++ b/test/py/test_array.py @@ -82,8 +82,8 @@ def test_output(): p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") p.compile(migraphx.get_target("gpu")) - r1 = run(p) - r2 = run(p) + r1 = run(p)[-1] + r2 = run(p)[-1] assert_eq(r1, r2) assert_eq(r1.tolist(), r2.tolist()) diff --git a/test/py/test_cpu.py b/test/py/test_cpu.py old mode 100644 new mode 100755 index 5c204461ffcae0b754777bf04c9a340a2861ac25..1f737e08bc079f2048d654ebe032cd1be5425e48 --- a/test/py/test_cpu.py +++ b/test/py/test_cpu.py @@ -1,18 +1,66 @@ -import migraphx - -p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") -print(p) -s1 = p.get_shape() -print("Compiling ...") -p.compile(migraphx.get_target("cpu")) -print(p) -s2 = p.get_shape() -assert s1 == s2 -params = {} - -for key, value in p.get_parameter_shapes().items(): - print("Parameter {} -> {}".format(key, value)) - params[key] = migraphx.generate_argument(value) - -r = p.run(params) -print(r) +import migraphx, array, sys + + +def test_conv_relu(): + p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") + print(p) + s1 = p.get_output_shapes()[-1] + print("Compiling ...") + p.compile(migraphx.get_target("ref")) + print(p) + s2 = p.get_output_shapes()[-1] + assert s1 == s2 + params = {} + + for key, value in p.get_parameter_shapes().items(): + print("Parameter {} -> {}".format(key, value)) + params[key] = migraphx.generate_argument(value) + + r = p.run(params)[-1] + print(r) + + +def create_buffer(t, data, shape): + a = array.array(t, data) + if sys.version_info >= (3, 0): + m = memoryview(a.tobytes()) + return m.cast(t, shape) + else: + m = memoryview(a.tostring()) + return m + + +def test_add_scalar(): + p = migraphx.parse_onnx("add_scalar_test.onnx") + print(p) + s1 = p.get_output_shapes()[-1] + print("Compiling ...") + p.compile(migraphx.get_target("ref")) + print(p) + s2 = p.get_output_shapes()[-1] + assert s1 == s2 + + d0 = list(range(120)) + arg0 = create_buffer("B", d0, [2, 3, 4, 5]) + d1 = [1] + arg1 = create_buffer("B", d1, ()) + + params = {} + params["0"] = migraphx.argument(arg0) + params["1"] = migraphx.argument(arg1) + + r = p.run(params)[-1] + print(r) + + +def test_module(): + p = migraphx.parse_onnx("add_scalar_test.onnx") + mm = p.get_main_module() + print(p) + print(mm) + + +test_conv_relu() +test_module() +if sys.version_info >= (3, 0): + test_add_scalar() diff --git a/test/py/test_gpu.py b/test/py/test_gpu.py index f38ab2fa1bc208221085dd9c65d0d7342d58da78..63fc72771d4cfb920990665f700b7262c92b567d 100644 --- a/test/py/test_gpu.py +++ b/test/py/test_gpu.py @@ -1,15 +1,117 @@ +import sys import migraphx +try: + import numpy as np +except: + sys.exit() -p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") -print(p) -print("Compiling ...") -p.compile(migraphx.get_target("gpu")) -print(p) -params = {} -for key, value in p.get_parameter_shapes().items(): - print("Parameter {} -> {}".format(key, value)) - params[key] = migraphx.generate_argument(value) +def test_conv_relu(): + p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") + print(p) + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + params = {} -r = p.run(params) -print(r) + for key, value in p.get_parameter_shapes().items(): + print("Parameter {} -> {}".format(key, value)) + params[key] = migraphx.generate_argument(value) + + r = p.run(params) + print(r) + + +def test_sub_uint64(): + p = migraphx.parse_onnx("implicit_sub_bcast_test.onnx") + print(p) + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + params = {} + + shapes = p.get_parameter_shapes() + params["0"] = np.arange(120).reshape(shapes["0"].lens()).astype(np.uint64) + params["1"] = np.arange(20).reshape(shapes["1"].lens()).astype(np.uint64) + + r = p.run(params) + print(r) + + +def test_neg_int64(): + p = migraphx.parse_onnx("neg_test.onnx") + print(p) + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + params = {} + + shapes = p.get_parameter_shapes() + params["0"] = np.arange(6).reshape(shapes["0"].lens()).astype(np.int64) + + r = p.run(params) + print(r) + + +def test_nonzero(): + p = migraphx.parse_onnx("nonzero_dynamic_test.onnx") + print(p) + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + params = {} + + shapes = p.get_parameter_shapes() + params["data"] = np.array([1, 1, 0, 1]).reshape( + shapes["data"].lens()).astype(np.bool) + + r = p.run(params) + print(r) + + +def test_fp16_imagescaler(): + p = migraphx.parse_onnx("imagescaler_half_test.onnx") + print(p) + s1 = p.get_output_shapes()[-1] + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + s2 = p.get_output_shapes()[-1] + assert s1 == s2 + + params = {} + shapes = p.get_parameter_shapes() + params["0"] = np.random.randn(768).reshape(shapes["0"].lens()).astype( + np.float16) + + r = p.run(params)[-1] + print(r) + + +def test_if_pl(): + p = migraphx.parse_onnx("if_pl_test.onnx") + print(p) + s1 = p.get_output_shapes()[-1] + print("Compiling ...") + p.compile(migraphx.get_target("gpu")) + print(p) + s2 = p.get_output_shapes()[-1] + assert s1 == s2 + + params = {} + shapes = p.get_parameter_shapes() + params["x"] = np.ones(6).reshape(shapes["x"].lens()).astype(np.float32) + params["y"] = np.array([2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0 + ]).reshape(shapes["y"].lens()).astype(np.float32) + params["cond"] = np.array([1]).reshape(()).astype(np.bool) + + r = p.run(params)[-1] + print(r) + + +test_conv_relu() +test_sub_uint64() +test_neg_int64() +test_fp16_imagescaler() +test_if_pl() +test_nonzero() diff --git a/test/py/test_gpu_offload.py b/test/py/test_gpu_offload.py index 3616a077ac38185651d5cd51f19f86c29e53909c..9388df250d028c68160d50df9c46a9431151304a 100644 --- a/test/py/test_gpu_offload.py +++ b/test/py/test_gpu_offload.py @@ -11,5 +11,5 @@ for key, value in p.get_parameter_shapes().items(): print("Parameter {} -> {}".format(key, value)) params[key] = migraphx.to_gpu(migraphx.generate_argument(value)) -r = migraphx.from_gpu(p.run(params)) +r = migraphx.from_gpu(p.run(params)[-1]) print(r) diff --git a/test/py/test_module_construct.py b/test/py/test_module_construct.py new file mode 100644 index 0000000000000000000000000000000000000000..1002062d018668b07a2c5d1c419456bdcb8385b0 --- /dev/null +++ b/test/py/test_module_construct.py @@ -0,0 +1,66 @@ +import migraphx, array, sys + + +def create_buffer(t, data, shape): + a = array.array(t, data) + m = memoryview(a.tobytes()) + return m.cast(t, shape) + + +def test_add_op(): + p = migraphx.program() + mm = p.get_main_module() + x = mm.add_literal(create_buffer('f', [1.0] * 9, (3, 3))) + y = mm.add_literal(create_buffer('f', [2.0] * 9, (3, 3))) + add_op = mm.add_instruction(migraphx.op("add"), [x, y]) + mm.add_return([add_op]) + p.compile(migraphx.get_target("ref")) + params = {} + output = p.run(params)[-1].tolist() + assert output == list([3.0] * 9) + + +def test_if_then_else(): + param_shape = migraphx.shape(lens=[3, 3], type="float") + cond_shape = migraphx.shape(type="bool", lens=[1], strides=[0]) + + def create_program(): + p = migraphx.program() + mm = p.get_main_module() + cond = mm.add_parameter("cond", cond_shape) + x = mm.add_parameter("x", param_shape) + y = mm.add_parameter("y", param_shape) + then_mod = p.create_module("If_0_if") + x_identity = then_mod.add_instruction(migraphx.op("identity"), [x]) + then_mod.add_return([x_identity]) + + else_mod = p.create_module("If_0_else") + y_identity = else_mod.add_instruction(migraphx.op("identity"), [y]) + else_mod.add_return([y_identity]) + + if_ins = mm.add_instruction(migraphx.op("if"), [cond], + [then_mod, else_mod]) + ret = mm.add_instruction(migraphx.op("get_tuple_elem", **{"index": 0}), + [if_ins]) + mm.add_return([ret]) + return p + + params = {} + params["x"] = migraphx.generate_argument(param_shape) + params["y"] = migraphx.generate_argument(param_shape) + + def run_prog(cond): + p = create_program() + p.compile(migraphx.get_target("ref")) + params["cond"] = migraphx.fill_argument(cond_shape, cond) + output = p.run(params)[-1] + return output + + assert run_prog(True) == params["x"] + assert run_prog(False) == params["y"] + + +if __name__ == "__main__": + if sys.version_info >= (3, 0): + test_add_op() + test_if_then_else() diff --git a/test/py/test_numpy.py b/test/py/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..a2cdbcc001452f36d72950ee5a28a083e6ad8381 --- /dev/null +++ b/test/py/test_numpy.py @@ -0,0 +1,22 @@ +import migraphx, sys +try: + import numpy as np +except: + sys.exit() + + +def test_add_op(): + p = migraphx.program() + mm = p.get_main_module() + x = mm.add_literal(np.ones((3, 3), dtype='float32')) + y = mm.add_literal(2 * np.ones((3, 3), dtype='float32')) + add_op = mm.add_instruction(migraphx.op("add"), [x, y]) + mm.add_return([add_op]) + p.compile(migraphx.get_target("ref")) + params = {} + output = p.run(params)[-1].tolist() + assert output == list(3 * np.ones((9), dtype='float32')) + + +if __name__ == "__main__": + test_add_op() diff --git a/test/py/test_op.py b/test/py/test_op.py new file mode 100755 index 0000000000000000000000000000000000000000..85bb757f58737eb17e82740e3d861c184d1349f8 --- /dev/null +++ b/test/py/test_op.py @@ -0,0 +1,19 @@ +import migraphx + + +def test_add_op(): + add_op = migraphx.op("add") + name = add_op.name() + + assert name == "add" + + +def test_reduce_mean(): + reduce_mean_op = migraphx.op("reduce_mean", **{"axes": [1, 2, 3, 4]}) + name = reduce_mean_op.name() + + assert name == "reduce_mean" + + +test_add_op() +test_reduce_mean() diff --git a/test/py/test_save_load.py b/test/py/test_save_load.py new file mode 100644 index 0000000000000000000000000000000000000000..dfac40c031be215b9507ebfe3368b492edc28508 --- /dev/null +++ b/test/py/test_save_load.py @@ -0,0 +1,22 @@ +import migraphx, tempfile + + +def test_conv_relu(format): + p1 = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") + print(p1) + + s1 = p1.get_output_shapes()[-1] + + with tempfile.NamedTemporaryFile() as t: + migraphx.save(p1, t.name, format=format) + + p2 = migraphx.load(t.name, format=format) + print(p2) + s2 = p2.get_output_shapes()[-1] + + assert s1 == s2 + assert p1.sort() == p2.sort() + + +test_conv_relu('msgpack') +test_conv_relu('json') diff --git a/test/py/test_shape.py b/test/py/test_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1c8f26c2b388b8c6168647201bad010994b843 --- /dev/null +++ b/test/py/test_shape.py @@ -0,0 +1,32 @@ +import migraphx + + +def test_create_shape(): + s = migraphx.shape(lens=[1, 64, 3, 3]) + assert s.standard() + assert s.packed() + assert s.lens() == [1, 64, 3, 3] + + +def test_create_shape_broadcast(): + s = migraphx.shape(lens=[1, 64, 3, 3], strides=[0, 1, 0, 0]) + assert s.broadcasted() + assert s.lens() == [1, 64, 3, 3] + assert s.strides() == [0, 1, 0, 0] + + +def test_create_shape_type(): + s = migraphx.shape(type='int64_t') + assert s.type_string() == 'int64_type' + assert s.type_size() == 8 + s = migraphx.shape(type='uint8_t') + assert s.type_string() == "uint8_type" + assert s.type_size() == 1 + s = migraphx.shape(type='float') + assert s.type_size() == 4 + + +if __name__ == "__main__": + test_create_shape() + test_create_shape_broadcast() + test_create_shape_type() diff --git a/test/quantization.cpp b/test/quantization.cpp index 6218977b90b91e4c66a10309a4268a6235ea1215..181dce5f9a6e3400a49db05b2290b4ed4cdcf423 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -4,37 +4,69 @@ #include #include #include -#include +#include #include +#include #include +#include +#include #include +#include +#include #include +#include #include #include +#include +#include +#include +#include +#include #include "test.hpp" #include +static void optimize_prog_int8(migraphx::program& prog) +{ + migraphx::run_passes(prog, + {migraphx::simplify_qdq{}, + migraphx::eliminate_common_subexpression{}, + migraphx::dead_code_elimination{}}); +} + TEST_CASE(param_add) { - auto create_program_float = [] { + auto create_program_float = [](bool add_return = false) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - p.add_instruction(migraphx::op::add{}, p1, p2); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2); + if(add_return) + { + mm->add_return({sum}); + } return p; }; - auto create_program_half = [] { + auto create_program_half = [](bool add_return = false) { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto hp1 = p.insert_instruction(std::next(p1), migraphx::op::convert{}, p1); - auto p2 = p.add_parameter("y", s); - auto hp2 = p.insert_instruction(std::next(p2), migraphx::op::convert{}, p2); - auto hs = p.add_instruction(migraphx::op::add{}, hp1, hp2); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1); + auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2); + auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); + auto res = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + hs); + if(add_return) + { + mm->add_return({res}); + } return p; }; @@ -54,71 +86,80 @@ TEST_CASE(param_add) migraphx::quantize_fp16(p1, {"add"}); EXPECT(p1 == p2); } + + { + auto p1 = create_program_float(true); + auto p2 = create_program_half(true); + + migraphx::quantize_fp16(p1); + EXPECT(p1 == p2); + } + + { + auto p1 = create_program_float(true); + auto p2 = create_program_half(true); + + migraphx::quantize_fp16(p1, {"add"}); + EXPECT(p1 == p2); + } } TEST_CASE(param_add_sub) { auto create_program_float = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); - auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); - p.add_instruction(migraphx::op::add{}, diff, p1); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); + auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1); + mm->add_return({r}); return p; }; auto create_program_half_add = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto hp1 = p.insert_instruction( - std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1); - auto p2 = p.add_parameter("y", s); - auto hp2 = p.insert_instruction( - std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2); - auto hsum = p.add_instruction(migraphx::op::add{}, hp1, hp2); - auto sum = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hsum); - auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); - auto hdiff = p.add_instruction( - migraphx::op::convert{migraphx::op::convert{migraphx::shape::half_type}}, diff); - auto res = p.add_instruction(migraphx::op::add{}, hdiff, hp1); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, res); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto hp1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1); + auto hp2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2); + auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); + auto sum = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum); + auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); + auto hdiff = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff); + auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); + auto r = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res); + mm->add_return({r}); return p; }; auto create_program_half_sub = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - auto hp2 = p.insert_instruction( - std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2); - auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); - auto hsum = p.add_instruction(migraphx::op::convert{migraphx::shape::half_type}, sum); - auto hdiff = p.add_instruction(migraphx::op::sub{}, hsum, hp2); - auto diff = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hdiff); - p.add_instruction(migraphx::op::add{}, diff, p1); - - return p; - }; - - auto create_program_half_all = [] { - migraphx::program p; - migraphx::shape s{migraphx::shape::float_type, {2, 3}}; - auto p1 = p.add_parameter("x", s); - auto hp1 = p.insert_instruction( - std::next(p1), migraphx::op::convert{migraphx::shape::half_type}, p1); - auto p2 = p.add_parameter("y", s); - auto hp2 = p.insert_instruction( - std::next(p2), migraphx::op::convert{migraphx::shape::half_type}, p2); - auto hsum = p.add_instruction(migraphx::op::add{}, hp1, hp2); - auto hdiff = p.add_instruction(migraphx::op::sub{}, hsum, hp2); - auto hres = p.add_instruction(migraphx::op::add{}, hdiff, hp1); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hres); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto hsum = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum); + auto hp2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2); + auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2); + auto diff = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hdiff); + auto r = mm->add_instruction(migraphx::make_op("add"), diff, p1); + mm->add_return({r}); return p; }; @@ -136,17 +177,70 @@ TEST_CASE(param_add_sub) auto p2 = create_program_half_sub(); migraphx::quantize_fp16(p1, {"sub"}); + EXPECT(p1 == p2); } { - auto p1 = create_program_float(); - auto p2 = create_program_half_all(); + auto create_program_fp16 = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto hp1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1); + auto hp2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2); + auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); + auto sum = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hsum); + auto hsum1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), sum); + auto p3 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2); + auto diff = mm->add_instruction(migraphx::make_op("sub"), hsum1, p3); + auto fdiff = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), diff); + auto hdiff1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), fdiff); + auto p4 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1); + auto res = mm->add_instruction(migraphx::make_op("add"), hdiff1, p4); + auto r = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res); + mm->add_return({r}); + + return p; + }; + + auto create_program_quant_fp16 = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto hp1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p1); + auto hp2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), p2); + auto hsum = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); + auto hdiff = mm->add_instruction(migraphx::make_op("sub"), hsum, hp2); + auto hres = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); + auto r = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hres); + mm->add_return({r}); + + return p; + }; + + auto p0 = create_program_float(); + migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}}); + EXPECT(p0 == create_program_fp16()); + auto p1 = create_program_float(); migraphx::quantize_fp16(p1); - migraphx::run_passes(p1, {migraphx::dead_code_elimination{}}); - - EXPECT(p1 == p2); + EXPECT(p1 == create_program_quant_fp16()); } } @@ -154,25 +248,30 @@ TEST_CASE(literal_add) { auto create_program_float = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 3}}; std::vector data(2 * 3); std::iota(data.begin(), data.end(), 1.0f); - auto l1 = p.add_literal(migraphx::literal(s, data)); - auto l2 = p.add_literal(migraphx::literal(s, data)); - p.add_instruction(migraphx::op::add{}, l1, l2); + auto l1 = mm->add_literal(migraphx::literal(s, data)); + auto l2 = mm->add_literal(migraphx::literal(s, data)); + mm->add_instruction(migraphx::make_op("add"), l1, l2); return p; }; auto create_program_half = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::half_type, {2, 3}}; std::vector data(2 * 3); std::iota(data.begin(), data.end(), 1.0f); - auto l1 = p.add_literal(migraphx::literal(s, data)); - auto l2 = p.add_literal(migraphx::literal(s, data)); - auto hs = p.add_instruction(migraphx::op::add{}, l1, l2); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, hs); + auto l1 = mm->add_literal(migraphx::literal(s, data)); + auto l2 = mm->add_literal(migraphx::literal(s, data)); + auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + hs); return p; }; @@ -182,9 +281,9 @@ TEST_CASE(literal_add) auto p2 = create_program_half(); migraphx::quantize_fp16(p1, {"all"}); - migraphx::run_passes(p1, + migraphx::run_passes(*p1.get_main_module(), {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); - migraphx::run_passes(p2, + migraphx::run_passes(*p2.get_main_module(), {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); EXPECT(p1 == p2); @@ -195,665 +294,666 @@ TEST_CASE(literal_add) auto p2 = create_program_half(); migraphx::quantize_fp16(p1, {"add"}); - migraphx::run_passes(p1, + migraphx::run_passes(*p1.get_main_module(), {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); - migraphx::run_passes(p2, + migraphx::run_passes(*p2.get_main_module(), {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); EXPECT(p1 == p2); } } -TEST_CASE(op_capture) +TEST_CASE(fp16_subgraph) { - auto test_func = [&](std::size_t ins_index, const std::vector& args) { - (void)ins_index; - (void)args; + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + + auto* then_mod = p.create_module("If_6_if"); + auto m1 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1); + auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); + auto m2 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2); + auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); + auto mfp16 = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), mul0); + then_mod->add_return({add0, mul0, mfp16}); + + auto* else_mod = p.create_module("If_6_else"); + auto me1 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3); + auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); + auto me2 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3); + auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); + auto afp16 = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), add1); + else_mod->add_return({mul1, add1, afp16}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + auto r16 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), ret); + mm->add_return({r0, r1, r16}); + + return p; + }; + + auto create_fp16_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {1}}; + auto l1 = mm->add_literal(migraphx::literal(sd, {1})); + auto l2 = mm->add_literal(migraphx::literal(sd, {2})); + auto l3 = mm->add_literal(migraphx::literal(sd, {3})); + migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; + migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sy); + auto* then_mod = p.create_module("If_6_if"); + auto hl1 = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l1); + auto mhl1 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1); + auto hx = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); + auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1); + auto fad = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad); + auto hl2 = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2); + auto mhl2 = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2); + auto hy1 = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y); + auto mu = then_mod->add_instruction(migraphx::make_op("mul"), hy1, mhl2); + auto fmu = then_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu); + then_mod->add_return({fad, fmu, mu}); + + auto* else_mod = p.create_module("If_6_else"); + auto hl3 = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l3); + auto mhl3 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3); + auto hx2 = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); + auto mu1 = else_mod->add_instruction(migraphx::make_op("mul"), hx2, mhl3); + auto fmu1 = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), mu1); + auto mhl4 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl3); + auto hy = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), y); + auto ad1 = else_mod->add_instruction(migraphx::make_op("add"), hy, mhl4); + auto fad1 = else_mod->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad1); + else_mod->add_return({fmu1, fad1, ad1}); + + auto iff = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), iff); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), iff); + auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), iff); + mm->add_return({r0, r1, r2}); + + return p; }; + auto p1 = create_program(); + migraphx::quantize_fp16(p1); + + auto p2 = create_fp16_program(); + + EXPECT(p1 == p2); +} + +TEST_CASE(op_capture) +{ auto create_program_float = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; - auto p1 = p.add_parameter("x", s1); - auto p2 = p.add_parameter("y", s1); - auto pb = p.add_parameter("b", s2); - auto pc = p.add_parameter("c", s2); - auto pa = p.add_instruction(migraphx::op::add{}, p1, p2); - auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc); - p.add_instruction(migraphx::op::dot{}, pa, ps); + auto p1 = mm->add_parameter("x", s1); + auto p2 = mm->add_parameter("y", s1); + auto pb = mm->add_parameter("b", s2); + auto pc = mm->add_parameter("c", s2); + auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto ps = + migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f); + mm->add_instruction(migraphx::make_op("dot"), pa, ps); return p; }; auto create_program_op = [&] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; - auto p1 = p.add_parameter("x", s1); - auto p2 = p.add_parameter("y", s1); - auto pb = p.add_parameter("b", s2); - auto pc = p.add_parameter("c", s2); - auto pa = p.add_instruction(migraphx::op::add{}, p1, p2); - auto opb = p.insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb); - auto opc = p.insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc); - auto opa = p.add_instruction(migraphx::op::capture{0, test_func}, pa); - auto ps = p.add_instruction(migraphx::op::dot{}, opa, opb, opc); - auto ops = p.add_instruction(migraphx::op::capture{3, test_func}, ps); - p.add_instruction(migraphx::op::dot{}, opa, ops); + auto p1 = mm->add_parameter("x", s1); + auto p2 = mm->add_parameter("y", s1); + auto pb = mm->add_parameter("b", s2); + auto pc = mm->add_parameter("c", s2); + auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa); + auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb); + auto ps = migraphx::add_apply_alpha_beta( + *mm, {opa, opb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f); + auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pa); + auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), ps); + mm->add_instruction(migraphx::make_op("dot"), opm, ops); return p; }; { - auto p = create_program_float(); - auto op_capture_p = create_program_op(); - migraphx::target t = migraphx::cpu::target{}; - migraphx::capture_arguments(p, t, {"dot", "convolution"}); + auto p = create_program_float(); + auto op_capture_p = create_program_op(); + migraphx::target t = migraphx::ref::target{}; + std::size_t param_index = 0; + migraphx::run_passes( + p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, ¶m_index}}); EXPECT(p == op_capture_p); } } -TEST_CASE(dot_float) +TEST_CASE(op_capture_subgraph) { auto create_program = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - - p.add_instruction(migraphx::op::dot{2.0f, 1.5f}, pa, pb, pc); + auto* mm = p.get_main_module(); + migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto a = mm->add_parameter("a", sx); + auto b = mm->add_parameter("b", sy); + + migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; + migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}}; + auto x = mm->add_parameter("x", sd); + auto w = mm->add_parameter("w", sw); + + auto* then_mod = p.create_module("If_6_if"); + auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), a, b); + then_mod->add_return({out1}); + + auto* else_mod = p.create_module("If_6_else"); + auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w); + else_mod->add_return({out2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({ret}); return p; }; - auto create_int8_quantized_prog = [] { + auto create_program_op = [&] { migraphx::program p; - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - // quantize parameter a to int8 type, multiply the scale - std::vector vfa(sa.elements(), 0.1f); - auto fa = p.add_literal(migraphx::literal(sa, vfa)); - auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa); - auto ra = p.add_instruction(migraphx::op::round{}, ma); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal(sb, vfb)); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb); - auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); - std::vector v_alpha(fdot->get_shape().elements(), 200.0f); - auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha)); - auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot); - std::vector v_beta(pc->get_shape().elements(), 1.5f); - auto beta = p.add_literal(migraphx::literal(pc->get_shape(), v_beta)); - auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, pc); - p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c); + auto* mm = p.get_main_module(); + migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto a = mm->add_parameter("a", sx); + auto b = mm->add_parameter("b", sy); + + migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; + migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}}; + auto x = mm->add_parameter("x", sd); + auto w = mm->add_parameter("w", sw); + + auto* then_mod = p.create_module("If_6_if"); + auto ca = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), a); + auto cb = then_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), b); + auto out1 = then_mod->add_instruction(migraphx::make_op("dot"), ca, cb); + then_mod->add_return({out1}); + + auto* else_mod = p.create_module("If_6_else"); + auto cx = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), x); + auto cw = else_mod->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), w); + auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), cx, cw); + else_mod->add_return({out2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({ret}); return p; }; - auto p = create_program(); - const std::vector>& quant_params{ - {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); - - auto qp = create_int8_quantized_prog(); + { + auto p = create_program(); + auto op_capture_p = create_program_op(); + migraphx::target t = migraphx::ref::target{}; + std::size_t param_index = 0; + migraphx::run_passes( + p, {migraphx::capture_arguments_pass{{"dot", "convolution"}, {}, ¶m_index}}); - EXPECT(p == qp); + EXPECT(p == op_capture_p); + } } -TEST_CASE(dot_double_2args) +TEST_CASE(dot_float) { auto create_program = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::double_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; + migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; + migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); - p.add_instruction(migraphx::op::dot{2.0f, 1.5f}, pa, pb); + auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot")); + mm->add_return({r}); return p; }; auto create_int8_quantized_prog = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::double_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::double_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - // quantize parameter a to int8 type, multiply the scale - std::vector vfa(sa.elements(), 0.1f); - auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); - auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); - auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa); - auto ra = p.add_instruction(migraphx::op::round{}, ma); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - auto fpb = p.insert_instruction( - insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb); - auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); - std::vector v_alpha(fdot->get_shape().elements(), 200.0f); - auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha)); - auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot); - p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, alpha_ab); - - return p; - }; - - auto p = create_program(); - const std::vector>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); -} - -TEST_CASE(dot_large_alpha_beta_float) -{ - auto create_program = [] { - migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - - p.add_instruction(migraphx::op::dot{20.0f, 50.5f}, pa, pb, pc); + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto zp_a = mm->add_literal(static_cast(0)); + auto scale_a = mm->add_literal(10.0f); + scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); + zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), + zp_a); + auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); + auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a); + + auto zp_b = mm->add_literal(static_cast(0)); + auto scale_b = mm->add_literal(10.0f); + scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); + zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), + zp_b); + auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); + + auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot")); + mm->add_return({r}); return p; }; - auto create_int8_quantized_prog = [] { + auto create_int8_optimized_prog = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - // quantize parameter a to int8 type, multiply the scale - std::vector vfa(sa.elements(), 0.1f); - auto fa = p.add_literal(migraphx::literal(sa, vfa)); - auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa); - // add the shift - std::vector vsa(sa.elements(), 1.0f); - auto sfta = p.add_literal(migraphx::literal(sa, vsa)); - auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); - auto ra = p.add_instruction(migraphx::op::round{}, msa); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal(sb, vfb)); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - // quantize parameter c to int32 type - auto qc = p.insert_instruction( - std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc); - - auto qdot = p.add_instruction(migraphx::op::quant_dot{2000, 51}, qa, qb, qc); - p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto zp = mm->add_literal(static_cast(0)); + auto scale = mm->add_literal(10.0f); + auto scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale); + auto zp_a = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); + auto quant_a = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); + auto scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale); + auto zp_b = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); + auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b); + std::vector vec(sc.elements(), 100.0f); + auto dc = mm->add_literal(100.0f); + auto mdc = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc); + mm->add_return({r}); return p; }; - auto p = create_program(); - const std::vector>& quant_params{ - {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); + const std::vector> quant_params = { + {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; + auto p = create_program(); + std::size_t param_index = 0; + migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, ¶m_index}}); + migraphx::run_passes( + p, + {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); auto qp = create_int8_quantized_prog(); EXPECT(p == qp); + + optimize_prog_int8(p); + auto op = create_int8_optimized_prog(); + EXPECT(p == op); } -TEST_CASE(dot_large_alpha_beta_int32) +TEST_CASE(dot_double_2args) { auto create_program = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::int32_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::int32_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::int32_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - - p.add_instruction(migraphx::op::dot{20.0f, 50.0f}, pa, pb, pc); + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::double_type, {2, 16}}; + migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot")); + mm->add_return({r}); return p; }; auto create_int8_quantized_prog = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::int32_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::int32_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::int32_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - // quantize parameter a to int8 type, multiply the scale - std::vector vfa(sa.elements(), 0.1f); - auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); - auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); - auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a); - - // add the shift - std::vector vsa(sa.elements(), 1.0f); - auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa)); - auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); - auto ra = p.add_instruction(migraphx::op::round{}, msa); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); - auto conv_b = p.insert_instruction( - insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - p.add_instruction(migraphx::op::quant_dot{2000, 50}, qa, qb, pc); - + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::double_type, {2, 16}}; + migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + + auto zp_a = mm->add_literal(static_cast(0)); + auto scale_a = mm->add_literal(10.0); + scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); + zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), + zp_a); + auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); + auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a); + auto zp_b = mm->add_literal(static_cast(0)); + auto scale_b = mm->add_literal(5.0); + scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); + zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), + zp_b); + auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); + auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot")); + mm->add_return({r}); return p; }; - auto p = create_program(); - const std::vector>& quant_params{ - {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); -} - -TEST_CASE(dot_int32_one_arg) -{ - auto create_program = [] { + auto create_int8_optimized_prog = [] { migraphx::program p; - migraphx::shape s{migraphx::shape::int32_type, {16, 16}}; - auto pa = p.add_parameter("a", s); - - p.add_instruction(migraphx::op::dot{20.0f, 50.0f}, pa, pa); - - return p; - }; - - auto create_int8_quantized_prog = [] { - migraphx::program p; - migraphx::shape s{migraphx::shape::int32_type, {16, 16}}; - auto pa = p.add_parameter("a", s); - - // add the shift - auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); - std::vector vsa(s.elements(), 1.0f); - auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, s.lens()}, vsa)); - auto msa = p.add_instruction(migraphx::op::add{}, sfta, fpa); - auto ra = p.add_instruction(migraphx::op::round{}, msa); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - auto q_dot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qa); - auto f_dot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, q_dot); - std::vector v_alpha(f_dot->get_shape().elements(), 20.0f); - auto new_alpha = p.add_literal(migraphx::literal{f_dot->get_shape(), v_alpha}); - auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, f_dot); - p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, alpha_ab); - + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::double_type, {2, 16}}; + migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + + auto scale_a = mm->add_literal(10.0); + auto zp = mm->add_literal(static_cast(0)); + scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); + auto zp_a = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); + auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); + auto scale_b = mm->add_literal(5.0); + scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); + auto zp_b = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); + auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb); + auto scale = mm->add_literal(50.0); + scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale); + mm->add_return({r}); return p; }; auto p = create_program(); - const std::vector>& quant_params{{1.0f, 1.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); + const std::vector>& quant_params{{0.1f, 0.0f}, {0.2f, 0.0f}}; + std::size_t param_index = 0; + migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, ¶m_index}}); + migraphx::run_passes( + p, + {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); + EXPECT(p == create_int8_quantized_prog()); + + optimize_prog_int8(p); + EXPECT(p == create_int8_optimized_prog()); } -TEST_CASE(dot_int32) +TEST_CASE(dot_half_1arg) { auto create_program = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::int32_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::int32_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::int32_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - - p.add_instruction(migraphx::op::dot{2.0f, 5.5f}, pa, pb, pc); + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {9, 9}}; + auto x = mm->add_parameter("x", s); + auto r = mm->add_instruction(migraphx::make_op("dot"), x, x); + mm->add_return({r}); return p; }; auto create_int8_quantized_prog = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::int32_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::int32_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::int32_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - // quantize parameter a to int8 type, multiply the scale - std::vector vfa(sa.elements(), 0.1f); - auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); - auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); - auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a); - - // add the shift - std::vector vsa(sa.elements(), 1.0f); - auto sfta = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vsa)); - auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma); - auto ra = p.add_instruction(migraphx::op::round{}, msa); - auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); - auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); - auto conv_b = p.insert_instruction( - insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb); - auto fr = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); - std::vector v_alpha(fr->get_shape().elements(), 20.0f); - auto new_alpha = p.add_literal(migraphx::literal(fr->get_shape(), v_alpha)); - auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fr); - auto fc = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pc); - std::vector v_beta(fc->get_shape().elements(), 5.5f); - auto beta = p.add_literal(migraphx::literal(fc->get_shape(), v_beta)); - auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, fc); - auto f_res = p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c); - p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, f_res); - - return p; - }; - - auto p = create_program(); - const std::vector>& quant_params{ - {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); -} - -TEST_CASE(dot_float_convert) -{ - auto create_program = [] { - migraphx::program p; - migraphx::shape sa{migraphx::shape::int8_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - - auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); - p.add_instruction(migraphx::op::dot{2.0f, 5.5f}, fpa, pb); - + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::half_type, {9, 9}}; + auto x = mm->add_parameter("x", sa); + + auto zp_a = mm->add_literal(static_cast(0)); + auto scale_a = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); + scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); + zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), + zp_a); + auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_a, zp_a); + auto dqa = mm->add_instruction(migraphx::make_op("dequantizelinear"), qa, scale_a, zp_a); + auto zp_b = mm->add_literal(static_cast(0)); + auto scale_b = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); + scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_b); + zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), + zp_b); + auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b); + auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); + auto r = mm->add_instruction(migraphx::make_op("dot"), dqa, dqb); + mm->add_return({r}); return p; }; - auto create_int8_quantized_prog = [] { + auto create_int8_optimized_prog = [] { migraphx::program p; - migraphx::shape sa{migraphx::shape::int8_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pb); - std::vector vfb(sb.elements(), 0.1f); - auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); - auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb); - auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); - auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); - auto qb = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); - - auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, pa, qb); - auto fr = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); - std::vector v_alpha(fr->get_shape().elements(), 10.0f); - auto new_alpha = p.add_literal(migraphx::literal(fr->get_shape(), v_alpha)); - p.add_instruction(migraphx::op::mul{}, new_alpha, fr); - + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::half_type, {9, 9}}; + auto x = mm->add_parameter("x", sa); + + auto zp = mm->add_literal(static_cast(0)); + auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); + scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), + scale); + zp = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); + auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); + auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); + auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); + dq_scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), + dq_scale); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale); + mm->add_return({r}); return p; }; auto p = create_program(); - const std::vector>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"dot"}); - migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); + const std::vector>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; + std::size_t param_index = 0; + migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, ¶m_index}}); + migraphx::run_passes( + p, + {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); + EXPECT(p == create_int8_quantized_prog()); + + optimize_prog_int8(p); + EXPECT(p == create_int8_optimized_prog()); } TEST_CASE(conv_float) { auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::convolution{}, input, weights); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_return({r}); return p; }; auto create_int8_quantized_prog = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sx{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape sw{migraphx::shape::float_type, {4, 3, 3, 3}}; - auto px = p.add_parameter("x", sx); - auto pw = p.add_parameter("w", sw); - // quantize parameter a to int8 type, multiply the scale - std::vector vfx(sx.elements(), 0.1f); - auto fx = p.add_literal(migraphx::literal(sx, vfx)); - auto mx = p.add_instruction(migraphx::op::mul{}, fx, px); - auto rx = p.add_instruction(migraphx::op::round{}, mx); - auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); - auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pw); - std::vector vfw(sw.elements(), 0.1f); - auto fw = p.add_literal(migraphx::literal(sw, vfw)); - auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, pw); - auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); - auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); - auto qw = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); - - auto q_conv = p.add_instruction(migraphx::op::quant_convolution{}, qx, qw); - auto f_conv = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, q_conv); - std::vector v_adj(f_conv->get_shape().elements(), 100.0f); - auto adj = p.add_literal(migraphx::literal(f_conv->get_shape(), v_adj)); - p.add_instruction(migraphx::op::mul{}, adj, f_conv); + auto px = mm->add_parameter("x", sx); + auto pw = mm->add_parameter("w", sw); + + auto zp = mm->add_literal(static_cast(0)); + auto scale = mm->add_literal(10.0f); + scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), + scale); + zp = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); + auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); + auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp); + + auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); + + migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}}; + std::vector vec(sc.elements(), 100.0f); + migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()}; + auto d_scale = mm->add_literal(100.0f); + d_scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); + mm->add_return({r}); return p; }; auto p = create_program(); const std::vector>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"convolution"}); + std::size_t param_index = 0; + migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, ¶m_index}}); + migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); + optimize_prog_int8(p); auto qp = create_int8_quantized_prog(); EXPECT(p == qp); } -TEST_CASE(conv_int32) +TEST_CASE(conv_float_throw) { auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}}); + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::int32_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::convolution{}, input, weights); - - return p; - }; - - auto create_int8_quantized_prog = [] { - migraphx::program p; - migraphx::shape sx{migraphx::shape::int32_type, {4, 3, 3, 3}}; - migraphx::shape sw{migraphx::shape::int32_type, {4, 3, 3, 3}}; - auto px = p.add_parameter("x", sx); - auto pw = p.add_parameter("w", sw); - // quantize parameter a to int8 type, multiply the scale - auto fpx = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, px); - std::vector vfx(sx.elements(), 0.1f); - auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx)); - auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx); - auto rx = p.add_instruction(migraphx::op::round{}, mx); - auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); - auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pw); - auto fpw = p.insert_instruction( - insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pw); - std::vector vfw(sw.elements(), 0.1f); - auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw)); - auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw); - auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); - auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); - auto qw = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); - - auto q_conv = p.add_instruction(migraphx::op::quant_convolution{}, qx, qw); - std::vector v_adj(q_conv->get_shape().elements(), 100.0f); - auto adj = p.add_literal(migraphx::literal(q_conv->get_shape(), v_adj)); - p.add_instruction(migraphx::op::mul{}, q_conv, adj); + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_return({r}); return p; }; auto p = create_program(); const std::vector>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"convolution"}); - auto qp = create_int8_quantized_prog(); - - EXPECT(p == qp); + test::throws([&] { + migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}}); + }); } TEST_CASE(conv_half) { auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); auto input = - p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); + mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); auto weights = - p.add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - p.add_instruction(migraphx::op::convolution{}, input, weights); + mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); + auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_return({r}); return p; }; auto create_int8_quantized_prog = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sx{migraphx::shape::half_type, {4, 3, 3, 3}}; migraphx::shape sw{migraphx::shape::half_type, {4, 3, 3, 3}}; - auto px = p.add_parameter("x", sx); - auto pw = p.add_parameter("w", sw); - // quantize parameter a to int8 type, multiply the scale - auto fpx = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, px); - std::vector vfx(sx.elements(), 0.1f); - auto fx = p.add_literal(migraphx::literal(fpx->get_shape(), vfx)); - auto mx = p.add_instruction(migraphx::op::mul{}, fx, fpx); - auto rx = p.add_instruction(migraphx::op::round{}, mx); - auto cx = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, rx); - auto qx = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, cx); - - // quantize parameter b to int8 type - auto insert_loc = std::next(pw); - auto fpw = p.insert_instruction( - insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pw); - std::vector vfw(sw.elements(), 0.1f); - auto fw = p.add_literal(migraphx::literal(fpw->get_shape(), vfw)); - auto mw = p.insert_instruction(insert_loc, migraphx::op::mul{}, fw, fpw); - auto rw = p.insert_instruction(insert_loc, migraphx::op::round{}, mw); - auto cw = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rw); - auto qw = - p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cw); - - auto q_conv = p.add_instruction(migraphx::op::quant_convolution{}, qx, qw); - auto f_conv = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, q_conv); - std::vector v_adj(f_conv->get_shape().elements(), 100.0f); - auto adj = p.add_literal(migraphx::literal(f_conv->get_shape(), v_adj)); - auto f_res = p.add_instruction(migraphx::op::mul{}, adj, f_conv); - p.add_instruction(migraphx::op::convert{migraphx::shape::half_type}, f_res); + auto px = mm->add_parameter("x", sx); + auto pw = mm->add_parameter("w", sw); + + auto zp = mm->add_literal(static_cast(0)); + auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0})); + scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), + scale); + zp = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); + auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); + auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp); + + auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); + auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0})); + d_scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); + mm->add_return({r}); return p; }; auto p = create_program(); const std::vector>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; - migraphx::quantize_int8_impl(p, quant_params, {"convolution"}); + std::size_t param_index = 0; + migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, ¶m_index}}); + migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); + optimize_prog_int8(p); auto qp = create_int8_quantized_prog(); EXPECT(p == qp); } +template +auto get_hash(const T& x) +{ + return std::hash{}(x); +} + TEST_CASE(target_copy) { auto run_prog = [](migraphx::program p, const migraphx::target& t, - migraphx::program::parameter_map& m_in, + migraphx::parameter_map& m_in, std::vector& res) { p.compile(t); - migraphx::program::parameter_map m; + migraphx::parameter_map m; for(auto&& x : p.get_parameter_shapes()) { if(m_in.count(x.first) > 0) @@ -866,33 +966,34 @@ TEST_CASE(target_copy) } } - auto result = t.copy_from(p.eval(m)); + auto result = t.copy_from(p.eval(m).back()); result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); }; auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {3, 3}}; - auto p1 = p.add_parameter("x", s); - auto p2 = p.add_parameter("y", s); - p.add_instruction(migraphx::op::add{}, p1, p2); + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("add"), p1, p2); return p; }; { auto p = create_program(); - migraphx::program::parameter_map m; + migraphx::parameter_map m; migraphx::shape s{migraphx::shape::float_type, {3, 3}}; m["x"] = migraphx::generate_argument(s); - std::vector cpu_result; - migraphx::target cpu_t = migraphx::cpu::target{}; - run_prog(p, cpu_t, m, cpu_result); + std::vector ref_result; + migraphx::target ref_t = migraphx::ref::target{}; + run_prog(p, ref_t, m, ref_result); std::vector orig_result; - run_prog(p, cpu_t, m, orig_result); + run_prog(p, ref_t, m, orig_result); - EXPECT(migraphx::verify_range(cpu_result, orig_result)); + EXPECT(migraphx::verify_range(ref_result, orig_result)); } } @@ -900,17 +1001,17 @@ TEST_CASE(int8_quantization_dot) { auto run_prog = [](migraphx::program p, const migraphx::target& t, - migraphx::program::parameter_map& m_in, + migraphx::parameter_map& m_in, std::vector& res, bool b_quantize = false) { if(b_quantize) { - std::vector cali_data; + std::vector cali_data; cali_data.push_back(m_in); migraphx::quantize_int8(p, t, cali_data); } p.compile(t); - migraphx::program::parameter_map m; + migraphx::parameter_map m; for(auto&& x : p.get_parameter_shapes()) { if(m_in.count(x.first) > 0) @@ -923,38 +1024,40 @@ TEST_CASE(int8_quantization_dot) } } - auto result = t.copy_from(p.eval(m)); + auto result = t.copy_from(p.eval(m).back()); result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); }; auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - auto pa = p.add_parameter("a", sa); - auto pb = p.add_parameter("b", sb); - auto pc = p.add_parameter("c", sc); - p.add_instruction(migraphx::op::dot{}, pa, pb, pc); - + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto pc = mm->add_parameter("c", sc); + auto r = + migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f); + mm->add_return({r}); return p; }; { auto p = create_program(); - migraphx::program::parameter_map m; + migraphx::parameter_map m; migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; - migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; - m["a"] = migraphx::generate_argument(sa); - m["c"] = migraphx::generate_argument(sc); + migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; + m["a"] = migraphx::generate_argument(sa, get_hash(std::string("a"))); + m["b"] = migraphx::generate_argument(sb, get_hash(std::string("b"))); std::vector quant_result; - migraphx::target cpu_t = migraphx::cpu::target{}; - run_prog(p, cpu_t, m, quant_result, true); + migraphx::target ref_t = migraphx::ref::target{}; + run_prog(p, ref_t, m, quant_result, true); std::vector no_quant_result; - run_prog(p, cpu_t, m, no_quant_result); + run_prog(p, ref_t, m, no_quant_result); - EXPECT(migraphx::verify_range(quant_result, no_quant_result)); + EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000)); } } @@ -966,24 +1069,26 @@ TEST_CASE(int8_quantization_conv) bool b_quantize = false) { if(b_quantize) { - std::vector cali_data; + std::vector cali_data; migraphx::quantize_int8(p, t, cali_data); } p.compile(t); - migraphx::program::parameter_map m; + migraphx::parameter_map m; - auto result = t.copy_from(p.eval(m)); + auto result = t.copy_from(p.eval(m).back()); result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); }; auto create_program = [] { migraphx::program p; + auto* mm = p.get_main_module(); migraphx::shape sx{migraphx::shape::float_type, {4, 2, 2, 2}}; migraphx::shape sw{migraphx::shape::float_type, {4, 2, 2, 2}}; std::vector v(sx.elements(), 0.5f); - auto input = p.add_literal(migraphx::literal(sx, v)); - auto weights = p.add_literal(migraphx::literal(sw, v)); - p.add_instruction(migraphx::op::convolution{}, input, weights); + auto input = mm->add_literal(migraphx::literal(sx, v)); + auto weights = mm->add_literal(migraphx::literal(sw, v)); + auto r = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_return({r}); return p; }; @@ -991,14 +1096,165 @@ TEST_CASE(int8_quantization_conv) { auto p = create_program(); std::vector quant_result; - migraphx::target cpu_t = migraphx::cpu::target{}; - run_prog(p, cpu_t, quant_result, true); + migraphx::target ref_t = migraphx::ref::target{}; + run_prog(p, ref_t, quant_result, true); std::vector no_quant_result; - run_prog(p, cpu_t, no_quant_result); + run_prog(p, ref_t, no_quant_result); EXPECT(migraphx::verify_range(quant_result, no_quant_result)); } } +TEST_CASE(int8_subgraph) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto a = mm->add_parameter("a", sx); + auto b = mm->add_parameter("b", sy); + + migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; + migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}}; + auto x = mm->add_parameter("x", sd); + auto w = mm->add_parameter("w", sw); + + auto* then_mod = p.create_module("If_6_if"); + auto out1 = migraphx::add_apply_alpha_beta(*then_mod, {a, b}, migraphx::make_op("dot")); + then_mod->add_return({out1}); + + auto* else_mod = p.create_module("If_6_else"); + auto out2 = else_mod->add_instruction(migraphx::make_op("convolution"), x, w); + else_mod->add_return({out2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({ret}); + + return p; + }; + + auto create_int8_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sx{migraphx::shape::float_type, {2, 2, 4, 8}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 8, 6}}; + migraphx::shape sout{migraphx::shape::float_type, {2, 2, 4, 6}}; + migraphx::shape sc{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", sc); + auto a = mm->add_parameter("a", sx); + auto b = mm->add_parameter("b", sy); + + // then submod + auto* then_mod = p.create_module("If_6_if"); + auto zp1 = then_mod->add_literal(static_cast(0)); + auto s1 = then_mod->add_literal(10.0f); + auto sa = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), s1); + auto zpa = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp1); + auto qa = then_mod->add_instruction(migraphx::make_op("quantizelinear"), a, sa, zpa); + auto sb = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1); + auto zpb = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); + auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); + auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb); + auto so = then_mod->add_literal(100.0f); + so = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); + auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); + then_mod->add_return({r}); + + migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; + migraphx::shape sw{migraphx::shape::float_type, {2, 2, 1, 1}}; + auto x = mm->add_parameter("x", sd); + auto w = mm->add_parameter("w", sw); + // else submod + auto* else_mod = p.create_module("If_6_else"); + auto sax = else_mod->add_literal(2.0f); + auto zp = else_mod->add_literal(static_cast(0)); + sax = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax); + auto zpx = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp); + auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx); + auto ssw = else_mod->add_literal(1.66667f); + ssw = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw); + auto zpw = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp); + auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw); + auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw); + auto so1 = else_mod->add_literal(3.33333f); + so1 = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1); + auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); + else_mod->add_return({r1}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + mm->add_return({ret}); + + return p; + }; + + auto p1 = create_program(); + const std::vector>& quant_params{ + {0.5f, 0.0f}, {0.6f, 0.0f}, {0.1f, 0.0f}, {0.1f, 0.0f}}; + std::size_t param_index = 0; + migraphx::run_passes( + p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, ¶m_index}}); + migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}}); + optimize_prog_int8(p1); + + auto p2 = create_int8_program(); + EXPECT(p1 == p2); +} + +TEST_CASE(test_op_capture) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 6}}; + std::vector d1(s1.elements()); + std::vector d2(s2.elements()); + std::iota(d1.begin(), d1.end(), 0.0f); + std::iota(d2.begin(), d2.end(), 0.0f); + + auto p1 = mm->add_literal(s1, d1); + auto p2 = mm->add_literal(s1, d1); + auto pb = mm->add_literal(s2, d2); + auto pc = mm->add_literal(s2, d2); + auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto ps = + migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f); + mm->add_instruction(migraphx::make_op("dot"), pa, ps); + + auto calc = [](std::size_t, const std::vector&) {}; + + migraphx::program capture_p = p; + migraphx::target t = migraphx::ref::target{}; + std::size_t param_index = 0; + migraphx::run_passes(capture_p, + {migraphx::capture_arguments_pass{{"dot"}, calc, ¶m_index}}); + + p.compile(migraphx::ref::target{}); + capture_p.compile(migraphx::ref::target{}); + + auto cap_res = capture_p.eval({}).back(); + auto res = p.eval({}).back(); + + std::vector vec; + std::vector cap_vec; + cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); + res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(vec, cap_vec)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/reduce_dims.cpp b/test/reduce_dims.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c5328ac167eb10ff64ce22229c8ce7c3266ad25 --- /dev/null +++ b/test/reduce_dims.cpp @@ -0,0 +1,141 @@ +#include +#include +#include "test.hpp" + +migraphx::shape make_shape(std::vector lens) +{ + return {migraphx::shape::float_type, std::move(lens)}; +} + +migraphx::shape make_shape(std::vector lens, std::vector strides) +{ + return {migraphx::shape::float_type, std::move(lens), std::move(strides)}; +} + +TEST_CASE(same_standard) +{ + auto is = make_shape({64, 3, 7, 7}); + auto os = make_shape({64 * 3 * 7 * 7}); + std::vector ishapes = {is, is, is}; + std::vector eshapes = {os, os, os}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(same_broadcast1) +{ + auto is = make_shape({64, 3, 7, 7}); + auto os = make_shape({64, 3, 7 * 7}); + std::vector ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is}; + std::vector eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(same_broadcast2) +{ + auto is = make_shape({64, 3, 8, 7, 7}); + auto os = make_shape({64, 8 * 3, 7 * 7}); + std::vector ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is}; + std::vector eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(same_transposed) +{ + auto is = make_shape({64, 3, 7, 7}); + auto os = make_shape({64 * 3, 7, 7}); + std::vector ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is}; + std::vector eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(different_masked1) +{ + auto is = make_shape({64, 3, 7, 7}); + auto os = make_shape({64, 3, 7 * 7}); + std::vector ishapes = {is, make_shape({1, 3, 1, 1}), is}; + std::vector eshapes = {os, make_shape({1, 3, 1}), os}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(different_masked2) +{ + auto is = make_shape({64, 3, 7, 7}); + auto os = make_shape({64, 3, 7 * 7}); + std::vector ishapes = { + is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})}; + std::vector eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(different_incompatible) +{ + auto is = make_shape({64, 3, 7, 7}); + std::vector ishapes = {is, make_shape({1, 3, 2, 1}), is}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(ishapes == rshapes); +} + +TEST_CASE(different_ranks) +{ + auto is = make_shape({64, 3, 7, 7}); + std::vector ishapes = {is, make_shape({1, 3}), is}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(ishapes == rshapes); +} + +TEST_CASE(transposed1) +{ + std::vector ishapes = { + make_shape({8, 28, 4, 56, 56}), + make_shape({8, 28, 4, 56, 56}, {351232, 3136, 87808, 56, 1})}; + std::vector eshapes = { + make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})}; + auto rshapes = migraphx::reduce_dims(ishapes); + + EXPECT(eshapes == rshapes); +} + +TEST_CASE(non_packed_empty1) +{ + std::vector ishapes = {make_shape({1, 12}, {589824, 64})}; + std::vector eshapes = {make_shape({12}, {64})}; + auto rshapes = migraphx::reduce_dims(ishapes); + EXPECT(eshapes == rshapes); +} + +TEST_CASE(non_packed_empty2) +{ + std::vector ishapes = {make_shape({12, 1}, {64, 589824})}; + std::vector eshapes = {make_shape({12}, {64})}; + auto rshapes = migraphx::reduce_dims(ishapes); + EXPECT(eshapes == rshapes); +} + +TEST_CASE(single_dim) +{ + std::vector ishapes = {make_shape({1}, {1})}; + auto rshapes = migraphx::reduce_dims(ishapes); + EXPECT(ishapes == rshapes); +} + +TEST_CASE(empty) +{ + auto rshapes = migraphx::reduce_dims({}); + EXPECT(rshapes.empty()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/ref_dev_examples.cpp b/test/ref_dev_examples.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d6d4a4a689097124e6ef16213c25f885d03ae5db --- /dev/null +++ b/test/ref_dev_examples.cpp @@ -0,0 +1,151 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +/*! + * Example MIGraphX programs for following the Contributor's Guide. + */ + +TEST_CASE(add_two_literals) +{ + /*! + * Simple MIGraphX program to add two literal values. + * Equivalent to adding two constant scalar values together. + */ + // create the program a get a pointer to the main module + migraphx::program p; + auto* mm = p.get_main_module(); + + // add two literals to the program + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + + // make the "add" operation between the two literals and add it to the program + mm->add_instruction(migraphx::make_op("add"), one, two); + + // compile the program on the reference device + p.compile(migraphx::ref::target{}); + + // evaulate the program and retreive the result + auto result = p.eval({}).back(); + std::cout << "add_two_literals: 1 + 2 = " << result << "\n"; + EXPECT(result.at() == 3); +} + +TEST_CASE(add_parameters) +{ + /*! + * Modified version of MIGraphX program seen in add_two_literals to accept a parameter. + * Equivalent to adding a constant scalar value with another scalar input. + */ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {1}}; + + // add a "x" parameter with the shape s + auto x = mm->add_parameter("x", s); + auto two = mm->add_literal(2); + + // add the "add" instruction between the "x" parameter and "two" to the module + mm->add_instruction(migraphx::make_op("add"), x, two); + p.compile(migraphx::ref::target{}); + + // create a parameter_map object for passing a value to the "x" parameter + std::vector data = {4}; + migraphx::parameter_map params; + params["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(params).back(); + std::cout << "add_parameters: 4 + 2 = " << result << "\n"; + EXPECT(result.at() == 6); +} + +TEST_CASE(handling_tensors) +{ + /*! + * This example does a convolution operation over an input tensor using the given weighting + * tensor. This is meant to show an example of working with tensors in MIGraphX. The output + * tensor is compared against a precomputed solution tensor at the end of the program. + */ + migraphx::program p; + auto* mm = p.get_main_module(); + + // create shape objects for the input tensor and weights + migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + + // create the parameters and add the "convolution" operation to the module + auto input = mm->add_parameter("X", input_shape); + auto weights = mm->add_parameter("W", weights_shape); + mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), + input, + weights); + + p.compile(migraphx::ref::target{}); + + // Allocated buffers by the user + std::vector a = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, + -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, + 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, + 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, + -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, + 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, + 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, + 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, + 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, + 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, + -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, + -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector c = { + -0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512, + 0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832, + 0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622, + 0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754, + 0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272, + 0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968, + -0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193, + 0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292}; + + // Solution vector + std::vector sol = {-0.20817225, + 0.87965256, + 0.14958936, + -1.24887264, + -0.06540672, + 0.20778663, + 0.40456355, + -0.99900877, + 0.4917807, + 0.1994698, + 0.64205718, + 0.37798831, + -0.25315839, + 0.44276932, + -0.16138598, + 0.79344082}; + + // Create the arguments in a parameter_map + migraphx::parameter_map params; + params["X"] = migraphx::argument(input_shape, a.data()); + params["W"] = migraphx::argument(weights_shape, c.data()); + + // Evaluate and confirm the result + auto result = p.eval(params).back(); + std::vector results_vector(64); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(results_vector, sol)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/cpu_dot_op_test.cpp b/test/ref_dot_op_test.cpp similarity index 76% rename from test/cpu_dot_op_test.cpp rename to test/ref_dot_op_test.cpp index a63f084cc7fbef193bacbae60388b84a286ce43e..6660338bd43293d08a72e54d17ad51ddfe1f796f 100644 --- a/test/cpu_dot_op_test.cpp +++ b/test/ref_dot_op_test.cpp @@ -1,11 +1,13 @@ #include #include #include -#include #include -#include +#include #include #include +#include +#include + #include "test.hpp" #include @@ -13,6 +15,8 @@ template void matmul_test() { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, @@ -45,12 +49,12 @@ void matmul_test() 2.16294914e+00, -1.48101497e-01}; migraphx::shape a_shape{migraphx::shape::get_type{}, {4, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::get_type{}, {5, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, al, bl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), al, bl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector results_vector; result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(c, results_vector)); @@ -62,6 +66,8 @@ template void matmul_test_ex() { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, @@ -94,12 +100,12 @@ void matmul_test_ex() 2.16294914e+00, -1.48101497e-01}; migraphx::shape a_shape{migraphx::shape::get_type{}, {1, 1, 4, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::get_type{}, {1, 1, 5, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, al, bl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), al, bl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector results_vector; result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(c, results_vector)); @@ -110,6 +116,8 @@ TEST_CASE_REGISTER(matmul_test_ex) TEST_CASE(matmul_mutli_dim_2) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = {-0.76234141, 0.01368910, -0.86343423, @@ -129,12 +137,12 @@ TEST_CASE(matmul_mutli_dim_2) -0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469, 0.41435494, 1.52136843, -0.40699791, -1.59839430}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); - p.add_instruction(migraphx::op::dot{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + mm->add_instruction(migraphx::make_op("dot"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -161,6 +169,8 @@ TEST_CASE(matmul_mutli_dim_2) TEST_CASE(gemm_mutli_dim_2_beta0) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = {-0.76234141, 0.01368910, -0.86343423, @@ -197,14 +207,18 @@ TEST_CASE(gemm_mutli_dim_2_beta0) 0.40245487, 1.80182751}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 2, 4}}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, m3}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); float alpha = 1.0f; float beta = 0.0f; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + migraphx::add_apply_alpha_beta(*mm, + std::vector{l1, l2, l3}, + migraphx::make_op("dot"), + alpha, + beta); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -231,6 +245,8 @@ TEST_CASE(gemm_mutli_dim_2_beta0) TEST_CASE(gemm_beta_0) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = { -0.76234141, 0.01368910, -0.86343423, -0.99465282, 0.76133268, 0.96507140}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; @@ -257,15 +273,19 @@ TEST_CASE(gemm_beta_0) 1.00521735, -0.95536130, 2.27996211}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, m3}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); float alpha = 1.0f; float beta = 0.0f; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + migraphx::add_apply_alpha_beta(*mm, + std::vector{l1, l2, l3}, + migraphx::make_op("dot"), + alpha, + beta); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -284,6 +304,8 @@ TEST_CASE(gemm_beta_0) TEST_CASE(matmul_mutli_dim_2_3) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = { -1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882, 0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693, @@ -300,12 +322,12 @@ TEST_CASE(matmul_mutli_dim_2_3) -0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803, -1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); - p.add_instruction(migraphx::op::dot{}, l1, l2); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + mm->add_instruction(migraphx::make_op("dot"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -321,6 +343,8 @@ TEST_CASE(matmul_mutli_dim_2_3) TEST_CASE(gemm_mutli_dim1_2_3) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = { 1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055, -0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145, @@ -344,18 +368,20 @@ TEST_CASE(gemm_mutli_dim1_2_3) 0.49759611, 0.10021662, 0.00592602, 0.90862000}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, m3}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); float alpha = 0.35; float beta = 0.41; - auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2); - auto l_beta = p.add_literal(beta); - auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape().lens()}, l_beta); - auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3); - p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + auto m12_alpha = migraphx::add_apply_alpha_beta( + *mm, std::vector{l1, l2}, migraphx::make_op("dot"), alpha); + auto l_beta = mm->add_literal(beta); + auto b_beta = mm->add_instruction( + migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta); + auto m3_beta = mm->add_instruction(migraphx::make_op("mul"), b_beta, l3); + mm->add_instruction(migraphx::make_op("add"), m3_beta, m12_alpha); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -371,6 +397,8 @@ TEST_CASE(gemm_mutli_dim1_2_3) TEST_CASE(gemm_mutli_3args) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector m1 = { 1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055, -0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145, @@ -394,14 +422,18 @@ TEST_CASE(gemm_mutli_3args) 0.49759611, 0.10021662, 0.00592602, 0.90862000}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; - auto l1 = p.add_literal(migraphx::literal{m1_shape, m1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, m2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, m3}); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); float alpha = 0.35; float beta = 0.41; - p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + migraphx::add_apply_alpha_beta(*mm, + std::vector{l1, l2, l3}, + migraphx::make_op("dot"), + alpha, + beta); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); @@ -418,6 +450,8 @@ TEST_CASE(gemm_3args) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {-0.86217194, -1.04129542, -0.64850364, @@ -453,12 +487,12 @@ TEST_CASE(gemm_3args) 2.11031439}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}}; - auto cl = p.add_literal(migraphx::literal{c_shape, c}); - p.add_instruction(migraphx::op::dot{}, al, bl, cl); + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + migraphx::add_apply_alpha_beta(*mm, {al, bl, cl}, migraphx::make_op("dot"), 1.0f, 1.0f); std::vector gold = {-1.60947, 0.703083, -5.46156, @@ -468,8 +502,8 @@ TEST_CASE(gemm_3args) -0.835966, 5.74736, 4.22063}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -480,6 +514,8 @@ TEST_CASE(matmul_vv_inner_product) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {0.7481789, 0.02906279, 1.01193836, @@ -498,14 +534,14 @@ TEST_CASE(matmul_vv_inner_product) -0.2342857}; migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); - auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); - p.add_instruction(migraphx::op::dot{}, ual, ubl); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); + auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); + mm->add_instruction(migraphx::make_op("dot"), ual, ubl); std::vector gold = {-1.43461}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -513,6 +549,8 @@ TEST_CASE(matmul_vv_inner_product) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {0.7481789, 0.02906279, 1.01193836, @@ -531,15 +569,16 @@ TEST_CASE(matmul_vv_inner_product) -0.2342857}; migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); - auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); + auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); float alpha = 0.32f; - p.add_instruction(migraphx::op::dot{alpha}, ual, ubl); + migraphx::add_apply_alpha_beta( + *mm, std::vector{ual, ubl}, migraphx::make_op("dot"), alpha); std::vector gold = {-0.4590752}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -550,6 +589,8 @@ TEST_CASE(matmul_vm) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {1.49530002, -0.07181969, 0.44593846, @@ -567,15 +608,15 @@ TEST_CASE(matmul_vm) 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; migraphx::shape a_shape{migraphx::shape::float_type, {8}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, ual, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), ual, bl); std::vector gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -583,6 +624,8 @@ TEST_CASE(matmul_vm) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {1.49530002, -0.07181969, 0.44593846, @@ -600,16 +643,17 @@ TEST_CASE(matmul_vm) 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; migraphx::shape a_shape{migraphx::shape::float_type, {8}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); float alpha = 0.5f; - p.add_instruction(migraphx::op::dot{alpha}, ual, bl); + migraphx::add_apply_alpha_beta( + *mm, std::vector{ual, bl}, migraphx::make_op("dot"), alpha); std::vector gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -617,6 +661,8 @@ TEST_CASE(matmul_vm) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535}; std::vector b = { @@ -634,12 +680,13 @@ TEST_CASE(matmul_vm) -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; migraphx::shape a_shape{migraphx::shape::float_type, {6}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); - auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); + auto bual = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual); migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, bual, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), bual, bl); std::vector gold = {1.22914, -1.17896, 2.28596, @@ -652,8 +699,8 @@ TEST_CASE(matmul_vm) 1.38484, -2.45019, -1.35064}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -661,6 +708,8 @@ TEST_CASE(matmul_vm) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535}; std::vector b = { @@ -678,12 +727,14 @@ TEST_CASE(matmul_vm) -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; migraphx::shape a_shape{migraphx::shape::float_type, {6}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); - auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); + auto bual = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual); migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{0.21f}, bual, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + migraphx::add_apply_alpha_beta( + *mm, std::vector{bual, bl}, migraphx::make_op("dot"), 0.21f); std::vector gold = {0.25812, -0.247582, 0.480051, @@ -696,8 +747,8 @@ TEST_CASE(matmul_vm) 0.290817, -0.514539, -0.283635}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -708,6 +759,8 @@ TEST_CASE(matmul_mv) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {0.1612524, 0.61266466, -0.19212896, @@ -727,14 +780,14 @@ TEST_CASE(matmul_mv) std::vector b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); - p.add_instruction(migraphx::op::dot{}, al, ubl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); + mm->add_instruction(migraphx::make_op("dot"), al, ubl); std::vector gold = {1.31982, 1.19022, -1.96062}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -742,6 +795,8 @@ TEST_CASE(matmul_mv) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {0.1612524, 0.61266466, -0.19212896, @@ -761,15 +816,16 @@ TEST_CASE(matmul_mv) std::vector b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); float alpha = 0.3f; - p.add_instruction(migraphx::op::dot{alpha}, al, ubl); + migraphx::add_apply_alpha_beta( + *mm, std::vector{al, ubl}, migraphx::make_op("dot"), alpha); std::vector gold = {0.395946, 0.357067, -0.588187}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -777,6 +833,8 @@ TEST_CASE(matmul_mv) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { 1.24593227, -0.84351316, 0.27882229, -0.42518484, -1.11391528, 0.59141834, 1.34198714, 2.25884063, -1.32093452, 0.44766336, -0.09306479, 0.47526699, @@ -791,12 +849,13 @@ TEST_CASE(matmul_mv) std::vector b = {0.05013914, 1.39932885, 2.56616476, 1.02225623, -0.03977829}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); - auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl); - p.add_instruction(migraphx::op::dot{}, al, bubl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); + auto bubl = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 1}}}), ubl); + mm->add_instruction(migraphx::make_op("dot"), al, bubl); std::vector gold = {-0.792717, 6.33595, 2.61466, @@ -809,8 +868,8 @@ TEST_CASE(matmul_mv) 2.87146, 3.29447, 0.765651}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -821,6 +880,8 @@ TEST_CASE(matmul_mm1) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437, -0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043, @@ -849,19 +910,20 @@ TEST_CASE(matmul_mm1) -1.62587164}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); - p.add_instruction(migraphx::op::dot{}, al, bbl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto bbl = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl); + mm->add_instruction(migraphx::make_op("dot"), al, bbl); std::vector gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938, -0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232, 0.182745, -4.21937, 1.77551, 1.50775, -2.60888, -2.32484, -0.557691, 6.13527, -2.91743, 2.37836, -6.42584, 1.14979, 0.77227, 0.349659, 2.92759, 2.32384, -2.90664, 0.0527679, -0.547761, -0.155467, 0.964619, 2.09133, -4.44281, -1.3864}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -869,6 +931,8 @@ TEST_CASE(matmul_mm1) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {-0.0309568, -1.57294749, -0.00768606, @@ -897,11 +961,12 @@ TEST_CASE(matmul_mm1) -0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto bal = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), al); migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, bal, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), bal, bl); std::vector gold = { -1.61175, 3.11849, -0.703205, 0.331635, -0.00946922, 0.645626, 0.834069, 1.06409, 0.881037, 0.227628, -0.200308, -1.71836, 0.156255, 0.477222, 0.571363, -1.04543, @@ -910,8 +975,8 @@ TEST_CASE(matmul_mm1) -0.710558, 0.259424, -0.342345, -1.80522, -0.580476, 0.277368, -3.95582, 0.614823, -0.415107, 0.305138, 0.435993, -0.107089, -0.767885, -4.00837, 1.09921, -2.02129, 0.109717, 0.618422, 0.438342, 0.29602, 2.00928, 0.420871}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -922,6 +987,8 @@ TEST_CASE(matmul_mm2) { { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437, -0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043, @@ -940,10 +1007,11 @@ TEST_CASE(matmul_mm2) 1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427, 1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto bbl = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl); std::vector gold = { 0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259, -0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934, @@ -951,9 +1019,9 @@ TEST_CASE(matmul_mm2) 4.81988916, -3.63687142, -0.19101717, -4.92522092, -1.76377022, -3.58095615, 1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025, 0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821}; - p.add_instruction(migraphx::op::dot{}, al, bbl); - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + mm->add_instruction(migraphx::make_op("dot"), al, bbl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -961,6 +1029,8 @@ TEST_CASE(matmul_mm2) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = {-0.19276159, -1.2568421, -0.321242, 1.21471077, -0.4927751, 0.69446894, -0.1786371, -1.00763473, -0.10279314, 3.02931355, 1.08359235, -0.35190132, -0.00639111, 0.78989113, 1.23538029, @@ -975,12 +1045,14 @@ TEST_CASE(matmul_mm2) 1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427, 1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); - auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + auto bal = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 5}}}), al); migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); - p.add_instruction(migraphx::op::dot{}, bal, bbl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto bbl = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl); + mm->add_instruction(migraphx::make_op("dot"), bal, bbl); std::vector gold = { 1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00, 1.91883003e+00, -2.89962196e-01, 2.76404142e+00, 1.50048102e+00, -6.29650347e-01, @@ -990,8 +1062,8 @@ TEST_CASE(matmul_mm2) 1.02442564e-01, -1.87659303e+00, -4.67302454e-01, 9.16189968e-01, -1.33537175e-01, 8.27398578e-01, 1.94406914e+00, -2.39250915e-01, -1.77062701e+00, -6.46239534e-01, -7.95202750e-01}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -999,6 +1071,8 @@ TEST_CASE(matmul_mm2) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149, -0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771, @@ -1025,10 +1099,10 @@ TEST_CASE(matmul_mm2) -0.88370924, 0.95294025, -0.08208804, -0.95943892, 0.30280474, 1.1967013, -1.17700948, 0.29533973}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 4, 5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - p.add_instruction(migraphx::op::dot{}, al, bl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + mm->add_instruction(migraphx::make_op("dot"), al, bl); std::vector gold = { 1.22136035, 1.3765651, 2.0611395, 1.70445494, 1.8189619, 0.2509717, 0.88815736, 1.13837946, 1.37006127, -0.53617378, 0.45759693, -0.503786, @@ -1040,8 +1114,8 @@ TEST_CASE(matmul_mm2) -0.61459168, -0.52561056, 0.3309648, -0.46185697, -1.60586695, -0.98590829, 0.63012062, -0.25606052, -0.69419352, -1.78299913, -0.38572706, 1.92249442, 0.3884186, -0.48153048, 0.84932351, 0.67234919, -1.07821322, -0.01208216}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1049,6 +1123,8 @@ TEST_CASE(matmul_mm2) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector a = { -0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149, -0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771, @@ -1069,11 +1145,12 @@ TEST_CASE(matmul_mm2) -0.36340925, -1.76152377, -0.96642674, -0.79231929, 0.11517073}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}}; - auto al = p.add_literal(migraphx::literal{a_shape, a}); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}}; - auto bl = p.add_literal(migraphx::literal{b_shape, b}); - auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl); - p.add_instruction(migraphx::op::dot{}, al, bbl); + auto bl = mm->add_literal(migraphx::literal{b_shape, b}); + auto bbl = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 4, 5}}}), bl); + mm->add_instruction(migraphx::make_op("dot"), al, bbl); std::vector gold = { -1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156, -0.43180359, 0.19639674, -0.33742881, 0.98443538, -0.9021272, 1.25043704, @@ -1085,8 +1162,8 @@ TEST_CASE(matmul_mm2) 1.38307367, 0.42677257, 0.83759966, -0.34827442, -1.45067092, 2.09599671, 1.92882983, -0.30996324, 2.19736278, 2.32389426, 2.36741832, 1.62253915, 0.26698225, -0.00741609, -2.53680983, -0.0679954, 0.04499683, 0.85354276}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1097,6 +1174,8 @@ TEST_CASE(quant_dot_2args_multi4) { { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}}; std::vector data1(4 * 4); @@ -1104,16 +1183,16 @@ TEST_CASE(quant_dot_2args_multi4) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - p.add_instruction(migraphx::op::quant_dot{}, l1, l2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2); std::vector gold = {112, 118, 124, 130, 136, 142, 148, 154, 304, 326, 348, 370, 392, 414, 436, 458, 496, 534, 572, 610, 648, 686, 724, 762, 688, 742, 796, 850, 904, 958, 1012, 1066}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1121,6 +1200,8 @@ TEST_CASE(quant_dot_2args_multi4) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}}; std::vector data1(4 * 4); @@ -1128,17 +1209,18 @@ TEST_CASE(quant_dot_2args_multi4) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - p.add_instruction(migraphx::op::quant_dot{}, tl1, l2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); std::vector gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552, 580, 608, 636, 664, 692, 544, 576, 608, 640, 672, 704, 736, 768, 592, 628, 664, 700, 736, 772, 808, 844}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1146,6 +1228,8 @@ TEST_CASE(quant_dot_2args_multi4) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}}; std::vector data1(4 * 4); @@ -1153,17 +1237,18 @@ TEST_CASE(quant_dot_2args_multi4) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - p.add_instruction(migraphx::op::quant_dot{}, l1, tl2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2); std::vector gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214, 302, 390, 478, 566, 654, 62, 214, 366, 518, 670, 822, 974, 1126, 86, 302, 518, 734, 950, 1166, 1382, 1598}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1171,6 +1256,8 @@ TEST_CASE(quant_dot_2args_multi4) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}}; std::vector data1(4 * 4); @@ -1178,18 +1265,20 @@ TEST_CASE(quant_dot_2args_multi4) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - p.add_instruction(migraphx::op::quant_dot{}, tl1, tl2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2); std::vector gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286, 398, 510, 622, 734, 846, 68, 196, 324, 452, 580, 708, 836, 964, 74, 218, 362, 506, 650, 794, 938, 1082}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1200,6 +1289,8 @@ TEST_CASE(quant_dot_2args_general) { { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}}; std::vector data1(3 * 4); @@ -1207,15 +1298,15 @@ TEST_CASE(quant_dot_2args_general) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - p.add_instruction(migraphx::op::quant_dot{}, l1, l2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2); std::vector gold = { 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1223,6 +1314,8 @@ TEST_CASE(quant_dot_2args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}}; std::vector data1(4 * 3); @@ -1230,16 +1323,17 @@ TEST_CASE(quant_dot_2args_general) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - p.add_instruction(migraphx::op::quant_dot{}, tl1, l2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); std::vector gold = { 210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1247,6 +1341,8 @@ TEST_CASE(quant_dot_2args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}}; std::vector data1(3 * 4); @@ -1254,21 +1350,18 @@ TEST_CASE(quant_dot_2args_general) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - p.add_instruction( - migraphx::op::quant_dot{ - 2, - }, - l1, - tl2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + + migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("quant_dot"), 2); std::vector gold = { 28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1276,6 +1369,8 @@ TEST_CASE(quant_dot_2args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {5, 4}}; std::vector data1(4 * 3); @@ -1283,17 +1378,19 @@ TEST_CASE(quant_dot_2args_general) std::iota(data1.begin(), data1.end(), 0); std::iota(data2.begin(), data2.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); std::vector gold = { 126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1304,6 +1401,8 @@ TEST_CASE(quant_dot_3args_general) { { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; @@ -1314,16 +1413,16 @@ TEST_CASE(quant_dot_3args_general) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); std::vector gold = { 982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1331,6 +1430,8 @@ TEST_CASE(quant_dot_3args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 5}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 5}}; @@ -1341,16 +1442,15 @@ TEST_CASE(quant_dot_3args_general) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 0); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{1, 0}, l1, l2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2); std::vector gold = { 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1358,6 +1458,8 @@ TEST_CASE(quant_dot_3args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; @@ -1368,17 +1470,18 @@ TEST_CASE(quant_dot_3args_general) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); std::vector gold = { 1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1386,6 +1489,8 @@ TEST_CASE(quant_dot_3args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; @@ -1396,17 +1501,18 @@ TEST_CASE(quant_dot_3args_general) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); std::vector gold = { 286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1414,6 +1520,8 @@ TEST_CASE(quant_dot_3args_general) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; @@ -1424,18 +1532,20 @@ TEST_CASE(quant_dot_3args_general) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); std::vector gold = { 844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1446,6 +1556,8 @@ TEST_CASE(quant_dot_3args_batch) { { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 2, 4}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 4, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 2, 7}}; @@ -1456,10 +1568,10 @@ TEST_CASE(quant_dot_3args_batch) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{1, 2}, l1, l2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 2); std::vector gold = { 102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380, @@ -1468,8 +1580,8 @@ TEST_CASE(quant_dot_3args_batch) 5916, 6004, 6092, 6180, 6268, 6356, 9762, 9866, 9970, 10074, 10178, 10282, 10386, 11288, 11408, 11528, 11648, 11768, 11888, 12008}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); @@ -1477,6 +1589,8 @@ TEST_CASE(quant_dot_3args_batch) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 2, 4, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {2, 2, 6, 4}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 2, 3, 6}}; @@ -1487,12 +1601,14 @@ TEST_CASE(quant_dot_3args_batch) std::iota(data2.begin(), data2.end(), 0); std::iota(data3.begin(), data3.end(), 2); - auto l1 = p.add_literal(migraphx::literal{m1_shape, data1}); - auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1); - auto l2 = p.add_literal(migraphx::literal{m2_shape, data2}); - auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2); - auto l3 = p.add_literal(migraphx::literal{m3_shape, data3}); - p.add_instruction(migraphx::op::quant_dot{2, 3}, tl1, tl2, l3); + auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); + auto tl1 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); + auto tl2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); std::vector gold = { 90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015, @@ -1502,8 +1618,8 @@ TEST_CASE(quant_dot_3args_batch) 12258, 13237, 14216, 15195, 16174, 17153, 24012, 25311, 26610, 27909, 29208, 30507, 24618, 25949, 27280, 28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039}; - p.compile(migraphx::cpu::target{}); - auto result = p.eval({}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); std::vector m; result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(m, gold)); diff --git a/test/ref_loop_test.cpp b/test/ref_loop_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d76b547cb1ea304a363be1440eb0421417d33e99 --- /dev/null +++ b/test/ref_loop_test.cpp @@ -0,0 +1,106 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test.hpp" + +static auto run_prog(int64_t iter_num, bool cond, int64_t ini_val) +{ + migraphx::shape si{migraphx::shape::int64_type}; + migraphx::shape s{migraphx::shape::int64_type, {1}}; + migraphx::shape sc{migraphx::shape::bool_type}; + + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto in_iter = mm->add_parameter("iter_num", si); + auto in_cond = mm->add_parameter("ccond", sc); + auto in_val = mm->add_parameter("val", s); + + auto* body = p.create_module("loop_module"); + auto iter = body->add_parameter("#loop_module_in_0", si); + body->add_parameter("#loop_module_in_1", sc); + auto in_v = body->add_parameter("#loop_module_in_2", s); + std::vector vd = {3}; + auto l = body->add_literal(migraphx::literal(si, vd)); + auto ad = body->add_instruction(migraphx::make_op("add"), iter, l); + auto val = body->add_instruction(migraphx::make_op("add"), in_v, ad); + auto eq = body->add_instruction(migraphx::make_op("equal"), iter, l); + auto beq = body->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq); + auto neq = body->add_instruction(migraphx::make_op("not"), beq); + body->add_return({neq, val, val}); + + auto rl = mm->add_instruction(migraphx::make_op("loop", {{"max_iterations", 10}}), + {in_iter, in_cond, in_val}, + {body}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rl); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rl); + mm->add_return({r0, r1}); + + return p; + }; + + auto p = create_program(); + p.compile(migraphx::ref::target{}); + migraphx::parameter_map pp; + pp["iter_num"] = migraphx::argument(si, &iter_num); + pp["ccond"] = migraphx::argument(sc, &cond); + pp["val"] = migraphx::argument(s, &ini_val); + auto rets = p.eval(pp); + + std::vector> res; + for(auto& arg : rets) + { + std::vector vec; + arg.visit([&](auto v) { vec.assign(v.begin(), v.end()); }); + res.push_back(vec); + } + + return res; +} + +TEST_CASE(loop_test1) +{ + auto ress = run_prog(10, true, 1); + std::vector gold_last = {19}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test2) +{ + auto ress = run_prog(4, true, 1); + std::vector gold_last = {19}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test3) +{ + auto ress = run_prog(3, true, 1); + std::vector gold_last = {13}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test4) +{ + auto ress = run_prog(5, true, 2); + std::vector gold_last = {20}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {5, 9, 14, 20, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/ref_ops_nonstd_shape_test.cpp b/test/ref_ops_nonstd_shape_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11c272937e284d305ca9c3d0bc6cd164f4d63faa --- /dev/null +++ b/test/ref_ops_nonstd_shape_test.cpp @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +TEST_CASE(argmax_test_nonstd_shape) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}})); + auto dl_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); + mm->add_instruction(migraphx::make_op("argmax", {{"axis", -3}}), dl_trans); + auto p_uncompiled = p; + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto res_gold = p_uncompiled.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + std::vector res_gold_vec; + res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); +} + +TEST_CASE(argmin_test_nonstd_shape) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto dl = mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 3, 4}})); + auto dl_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), dl); + mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl_trans); + auto p_uncompiled = p; + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto res_gold = p_uncompiled.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + std::vector res_gold_vec; + res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); +} + +TEST_CASE(isnan_broadcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3}}; + migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; + auto nan_val = std::numeric_limits::quiet_NaN(); + std::vector data0 = {1.2, 5.2, nan_val}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + auto l1 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s1.lens()}}), l0); + mm->add_instruction(migraphx::make_op("isnan"), l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector correct = {0, 0, 0, 0, 1, 1}; + EXPECT(migraphx::verify_range(results_vector, correct)); +} + +TEST_CASE(squeeze_transpose_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 1, 3, 1, 3}})); + auto l0_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0, 4}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze"), l0_trans); + auto p_uncompiled = p; + // contiguous is required to read the values in standard shaped order + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 3}}); + EXPECT(result == expected_result); +} + +TEST_CASE(squeeze_multibroadcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 3, 1, 3}})); + auto l0_brcst = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 3, 4, 3}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze"), l0_brcst); + auto p_uncompiled = p; + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 3, 4, 3}}); + EXPECT(result == expected_result); +} + +TEST_CASE(squeeze_slice_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 3, 4, 3}})); + auto l0_slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze"), l0_slice); + auto p_uncompiled = p; + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 3}}); + EXPECT(result == expected_result); +} + +TEST_CASE(unsqueeze_transpose_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; + auto l0 = mm->add_literal(migraphx::generate_literal(s1)); + auto l0_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans); + auto p_uncompiled = p; + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}}); + EXPECT(result == expected_result); +} + +TEST_CASE(unsqueeze_multibroadcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}}; + auto l0 = mm->add_literal(migraphx::generate_literal(s1)); + auto l0_brcst = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst); + auto p_uncompiled = p; + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}}); + EXPECT(result == expected_result); +} + +TEST_CASE(unsqueeze_slice_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto l0 = mm->add_literal(migraphx::generate_literal(s1)); + auto l0_slice = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice); + auto p_uncompiled = p; + auto* mm_uncompiled = p_uncompiled.get_main_module(); + mm_uncompiled->add_instruction(migraphx::make_op("contiguous"), + std::prev(mm_uncompiled->end())); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto expected_result = p_uncompiled.eval({}).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}}); + EXPECT(result == expected_result); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/ref_ops_test.cpp b/test/ref_ops_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..547043ccb2361d56f63cc41d411ebcde633abc0f --- /dev/null +++ b/test/ref_ops_test.cpp @@ -0,0 +1,5517 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "test.hpp" +#include +#include + +float sigmoid(float x) { return 1 / (1 + expf(-x)); } + +float elu(float a, float x) { return x > 0 ? x : a * std::expm1(x); } + +TEST_CASE(abs_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + auto l = mm->add_literal(migraphx::literal{s, {-1, 2, -3, 4}}); + mm->add_instruction(migraphx::make_op("abs"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 2, 3, 4}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(acos_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {3}}; + std::vector data{-0.8f, 0.0f, 1.0f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("acos"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(acosh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {3}}; + std::vector data{1.1f, 1.2f, 2.0f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("acosh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(add_broadcast_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}}; + std::vector a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + migraphx::shape b_shape{migraphx::shape::float_type, {2, 2}}; + std::vector b_data{0, -1, -2, -3}; + uint64_t axis = 0; + auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); + auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); + auto l3 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}), + l2); + mm->add_instruction(migraphx::make_op("add"), l1, l3); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape().packed()); + std::vector results_vector(12); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3}}; + std::vector a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 1}}; + std::vector b_data{0, -1, -2, -3}; + auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); + auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); + auto l3 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l1); + auto l4 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3}}}), l2); + mm->add_instruction(migraphx::make_op("add"), l3, l4); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape().packed()); + std::vector results_vector(12); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(add_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); + mm->add_instruction(migraphx::make_op("add"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 2, 4}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(argmax_test_0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmax_test_1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {0, 0, 2, 1, 2, 0, 0, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmax", {{"axis", 1}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmax_test_2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {1, 3, 2, 2, 2, 3}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmax_test_neg_2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {0, 0, 2, 1, 2, 0, 0, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmax", {{"axis", -2}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmin_test_0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmin", {{"axis", 0}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmin_test_1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {2, 2, 0, 2, 0, 1, 2, 0}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmin_test_2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {2, 1, 0, 3, 3, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(argmin_test_neg_1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, + -1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706, + 0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497}; + std::vector res_gold = {2, 1, 0, 3, 3, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto dl = mm->add_literal(migraphx::literal{data_shape, data}); + mm->add_instruction(migraphx::make_op("argmin", {{"axis", -1}}), dl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vec, res_gold)); +} + +TEST_CASE(asin_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data{-0.5f, 0.0f, 0.9f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("asin"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(asinh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data{-0.5f, 0.0f, 0.9f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("asinh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(atan_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {3}}; + std::vector data{-1.0f, 0.0f, 1.0f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("atan"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(atanh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {3}}; + std::vector data{0.4435683f, 0.6223626f, 0.316958f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("atanh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(avgpool_test) +{ + // 1D case 1, input is 3D + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // 1D case 2, stride 2 + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + op.lengths = {2}; + op.padding = {1}; + op.stride = {2}; + + std::vector data{1.6321, + -2.4186, + 0.2239, + -1.4232, + 0.8158, + 0.4103, + -0.3149, + -0.1361, + -0.3442, + 2.007, + 0.4331, + 1.5295, + 0.9965, + 0.4766, + 1.0942, + -0.2915}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.6321, + -1.0974, + -1.4232, + 0.8158, + 0.0477, + -0.1361, + -0.3442, + 1.22005, + 1.5295, + 0.9965, + 0.7854, + -0.2915}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // 3D, input is 5D + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + op.lengths = {2, 2, 2}; + op.padding = {0, 0, 0}; + op.stride = {1, 1, 1}; + + std::vector data{ + -0.179, -1.756, 0.651, 1.955, 1.87, -0.604, 0.247, 0.449, -0.137, 1.187, 1.593, + 0.424, 2.698, -0.104, -0.069, -1.293, 0.538, 1.291, 0.974, 1.096, 0.74, -0.669, + -1.08, -1.041, -1.407, 1.43, -0.211, -0.017, 0.532, 1.276, 0.627, 0.236, -0.396, + -0.204, 0.501, -0.599, -1.414, -0.615, -0.274, 0.168, -0.144, 0.5, 1.42, 1.082, + -0.952, -0.846, -1.244, 1.475, 1.246, 1.344, -1.722, -1.24, -0.851, 0.06, 0.507, + 0.762, -0.007, -1.484, 1.028, 0.317, 1.077, -1.289, 0.875, -0.417, -0.673, 1.715, + -0.307, 0.264, -0.973, 1.412, 2.561, -0.515, -0.201, 0.827, -1.231, 1.958, -0.552, + 0.036, -0.993, -0.859, -1.458, -0.575, 0.048, -0.779, -1.025, -1.135, 1.166, -0.131, + 0.726, 0.52, 0.467, -0.494, 0.675, 0.203, -0.63, -0.918, -0.5, -1.395, 1.39, + 1.705, 0.444, -0.835, -0.506, 0.101, 0.602, 0.543, 0.357, 1.042}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{ + 0.908, 0.250625, 0.795, 0.40425, 0.711875, 0.194875, 0.014125, 0.09425, + -0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375, + -0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187, + -0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(batch_norm_1d_per_actv_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 4}}; + migraphx::shape c_shape(migraphx::shape::float_type, {2, 4}); + + std::vector x_data = {0.3547, + 0.477, + -1.8575, + 0.663, + -0.1881, + -0.5113, + -0.1803, + -0.5915, + -0.1552, + 0.9821, + 1.827, + 0.0558, + -0.0417, + -1.0693, + 1.9948, + -0.7448}; + std::vector scale_data = { + -0.3181, -0.3885, 1.655, 0.0704, -0.2565, -1.1761, -0.3751, 0.1057}; + std::vector bias_data = { + -1.2118, -2.1156, 0.0046, -0.1341, -0.2724, -1.0718, 0.5535, -0.889}; + std::vector mean_data = { + 0.0997, 0.7295, -0.0153, 0.3594, -0.1149, -0.7903, 0.9073, -0.6681}; + std::vector variance_data = { + 0.13, 0.1276, 6.7878, 0.1843, 0.0107, 0.1556, 2.3655, 0.0117}; + + auto x = mm->add_literal(migraphx::literal{x_shape, x_data}); + auto scale = mm->add_literal(migraphx::literal{c_shape, scale_data}); + auto bias = mm->add_literal(migraphx::literal{c_shape, bias_data}); + auto mean = mm->add_literal(migraphx::literal{c_shape, mean_data}); + auto variance = mm->add_literal(migraphx::literal{c_shape, variance_data}); + + mm->add_instruction( + migraphx::make_op( + "batch_norm_inference", + {{"epsilon", 1e-6}, + {"momentum", 0.9}, + {"bn_mode", migraphx::to_value(migraphx::op::batch_norm_inference::per_activation)}}), + x, + scale, + bias, + mean, + variance); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-1.43677, + -1.84098, + -1.16563, + -0.0843136, + -0.090896, + -1.90364, + 0.81875, + -0.81415, + -0.986915, + -2.39032, + 1.17489, + -0.183886, + -0.453904, + -0.239955, + 0.288275, + -0.963948}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(batch_norm_1d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape c_shape(migraphx::shape::float_type, {3}); + + std::vector x_data = {0.7253, -0.6356, 0.4606, -0.8689, -1.1932, 0.4538, + -1.0018, -0.365, -0.214, -0.9553, -0.7672, 0.2331, + -0.8416, -0.6142, 0.0814, 0.2498, -0.6706, 1.4872, + 0.5112, -1.5212, -0.9126, 0.0735, 1.085, -0.3417}; + std::vector scale_data = {1.1, 1.2, 1.3}; + std::vector bias_data = {0.1, 0.2, 0.3}; + std::vector mean_data = {-0.1804, -0.2875, -0.2249}; + std::vector variance_data = {2.7914, 7.3424, 3.3287}; + + auto x = mm->add_literal(migraphx::literal{x_shape, x_data}); + auto scale = mm->add_literal(migraphx::literal{c_shape, scale_data}); + auto bias = mm->add_literal(migraphx::literal{c_shape, bias_data}); + auto mean = mm->add_literal(migraphx::literal{c_shape, mean_data}); + auto variance = mm->add_literal(migraphx::literal{c_shape, variance_data}); + + mm->add_instruction(migraphx::make_op("batch_norm_inference", {{"epsilon", 1e-5}}), + x, + scale, + bias, + mean, + variance); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.696301, -0.199697, 0.522026, -0.353299, -0.201094, 0.528289, + -0.116332, 0.165679, 0.307767, -0.220435, -0.086407, 0.62634, + -0.335325, -0.185608, 0.272366, 0.383238, 0.0303421, 0.985936, + 0.553709, -0.346351, -0.190009, 0.51262, 1.23335, 0.216776}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(batch_norm_3d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2, 2, 2}}; + migraphx::shape c_shape(migraphx::shape::float_type, {2}); + + std::vector x_data = {-1.0833, 1.9681, 1.2075, -0.723, -0.4076, -0.8738, 0.5853, + -0.5357, 1.734, 0.7904, 0.6953, -0.468, -0.425, 0.6895, + 0.0096, 0.4205, -0.1749, 1.2821, 2.1453, -0.8538, 1.0687, + 0.0906, 0.0714, -1.3079, -0.6376, 1.3023, 0.945, 0.0927, + -0.7421, -1.4341, -1.0309, 1.5153}; + std::vector scale_data = {1.1, 1.3}; + std::vector bias_data = {0.1, 0.2}; + std::vector mean_data = {0.1537, 0.2161}; + std::vector variance_data = {18.0805, 13.3906}; + + auto x = mm->add_literal(migraphx::literal{x_shape, x_data}); + auto scale = mm->add_literal(migraphx::literal{c_shape, scale_data}); + auto bias = mm->add_literal(migraphx::literal{c_shape, bias_data}); + auto mean = mm->add_literal(migraphx::literal{c_shape, mean_data}); + auto variance = mm->add_literal(migraphx::literal{c_shape, variance_data}); + + mm->add_instruction(migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = { + -0.220005, 0.569376, 0.372612, -0.126798, -0.0452053, -0.165809, 0.211653, -0.0783441, + 0.739245, 0.404024, 0.370239, -0.0430317, -0.0277556, 0.368179, 0.126639, 0.272615, + 0.0149929, 0.391911, 0.615216, -0.160635, 0.336706, 0.0836764, 0.0787094, -0.278108, + -0.103283, 0.585881, 0.458947, 0.156161, -0.140408, -0.386246, -0.243006, 0.661551}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(batch_norm_inference_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + const size_t width = 2; + const size_t height = 2; + const size_t channels = 4; + const size_t batches = 2; + const float x_val = 8.0; + const float mean_val = 2.0; + const float variance_val = 4.0; + const float scale_val = 2.0f; + const float bias_val = 1.0f; + const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val; + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + std::vector x_data(width * height * channels * batches); + std::vector scale_data(channels); + std::vector bias_data(channels); + std::vector mean_data(channels); + std::vector variance_data(channels); + + std::fill(x_data.begin(), x_data.end(), x_val); + std::fill(mean_data.begin(), mean_data.end(), mean_val); + std::fill(variance_data.begin(), variance_data.end(), variance_val); + std::fill(scale_data.begin(), scale_data.end(), scale_val); + std::fill(bias_data.begin(), bias_data.end(), bias_val); + + auto x = mm->add_literal(migraphx::literal{s, x_data}); + auto scale = mm->add_literal(migraphx::literal{vars, scale_data}); + auto bias = mm->add_literal(migraphx::literal{vars, bias_data}); + auto mean = mm->add_literal(migraphx::literal{vars, mean_data}); + auto variance = mm->add_literal(migraphx::literal{vars, variance_data}); + + mm->add_instruction(migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector result_vector(width * height * channels * batches); + std::vector gold(width * height * channels * batches); + std::fill(gold.begin(), gold.end(), output_val); + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(broadcast_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int32_type, {2, 2}}; + std::vector a_data{0, 0, 0, 0}; + migraphx::shape b_shape{migraphx::shape::int32_type, {2}}; + std::vector b_data{-2, -3}; + uint64_t axis = 0; + auto l1 = mm->add_literal(migraphx::literal{a_shape, a_data}); + auto l2 = mm->add_literal(migraphx::literal{b_shape, b_data}); + mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l1->get_shape().lens()}}), l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + auto output = result.get(); + EXPECT(output(0, 0) == -2); + EXPECT(output(0, 1) == -2); + EXPECT(output(1, 0) == -3); + EXPECT(output(1, 1) == -3); +} + +TEST_CASE(ceil_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + std::vector data = {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("ceil"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(clip_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l = mm->add_literal(migraphx::literal{s, {-1.0, 0.0, 10.0}}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val); + mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.0, 0.0, 6.0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(concat_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 1; + std::vector data0 = {0, 1, 5, 6}; + std::vector data1 = {2, 3, 4, 7, 8, 9}; + std::vector data2 = {10, 20}; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + auto l1 = mm->add_literal(migraphx::literal{s1, data1}); + auto l2 = mm->add_literal(migraphx::literal{s2, data2}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; + std::vector results_vector(2 * 6); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({2, 6}))); + EXPECT( + migraphx::verify_range(result.get_shape().strides(), std::vector({6, 1}))); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = -1; + std::vector data0 = {0, 1, 5, 6}; + std::vector data1 = {2, 3, 4, 7, 8, 9}; + std::vector data2 = {10, 20}; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + auto l1 = mm->add_literal(migraphx::literal{s1, data1}); + auto l2 = mm->add_literal(migraphx::literal{s2, data2}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; + std::vector results_vector(2 * 6); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({2, 6}))); + EXPECT( + migraphx::verify_range(result.get_shape().strides(), std::vector({6, 1}))); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 0; + std::vector data0 = {0, 1, 2, 3}; + std::vector data1 = {4, 5, 6, 7, 8, 9}; + std::vector data2 = {10, 11}; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + auto l1 = mm->add_literal(migraphx::literal{s1, data1}); + auto l2 = mm->add_literal(migraphx::literal{s2, data2}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + std::vector results_vector(6 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({6, 2}))); + EXPECT( + migraphx::verify_range(result.get_shape().strides(), std::vector({2, 1}))); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = -2; + std::vector data0 = {0, 1, 2, 3}; + std::vector data1 = {4, 5, 6, 7, 8, 9}; + std::vector data2 = {10, 11}; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + auto l1 = mm->add_literal(migraphx::literal{s1, data1}); + auto l2 = mm->add_literal(migraphx::literal{s2, data2}); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + std::vector results_vector(6 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector({6, 2}))); + EXPECT( + migraphx::verify_range(result.get_shape().strides(), std::vector({2, 1}))); + } +} + +TEST_CASE(contiguous_test) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}}; + std::vector data(12); + std::iota(data.begin(), data.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + mm->add_instruction(migraphx::make_op("contiguous"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(12); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector new_lens = {1, 3, 2, 2}; + std::vector new_strides = {12, 1, 6, 3}; + EXPECT(migraphx::verify_range(results_vector, data)); +} + +TEST_CASE(conv2d_padding_stride_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, + -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, + 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, + 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, + -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, + 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, + 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, + 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, + 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, + 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, + -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, + -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector c = { + -0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512, + 0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832, + 0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622, + 0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754, + 0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272, + 0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968, + -0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193, + 0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292}; + + std::vector s = {-0.20817225, + 0.87965256, + 0.14958936, + -1.24887264, + -0.06540672, + 0.20778663, + 0.40456355, + -0.99900877, + 0.4917807, + 0.1994698, + 0.64205718, + 0.37798831, + -0.25315839, + 0.44276932, + -0.16138598, + 0.79344082}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + + migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + + mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(conv2d_padding_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, + -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, + 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, + 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, + -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, + 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, + 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, + 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, + 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, + 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, + -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, + -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector c = { + -0.16115488, -0.09800646, -0.05412646, 0.10475694, 0.00555485, -0.12667653, 0.0458357, + -0.02656217, -0.16338061, 0.15037455, 0.0102711, 0.01303349, 0.05242859, 0.02034754, + 0.04751867, -0.17038961, -0.1434752, -0.10770349, 0.05676742, -0.15838449, 0.10128359, + -0.18958683, 0.11954515, 0.10758857, -0.01058291, -0.12797487, 0.08971019, 0.18793164, + -0.00881396, -0.06588994, -0.13321903, -0.03300409, 0.01439607, 0.07618178, -0.11556662, + 0.00764295, 0.12956454, -0.08937147, -0.12763587, 0.04674943, 0.05765297, 0.11336918, + 0.14747436, -0.06199479, -0.01166052, -0.12432006, -0.04494537, -0.17581205, 0.09475745, + 0.1149437, -0.1014564, 0.0274073, -0.01323579, -0.11092556}; + + std::vector s = { + -0.0201216, 0.40407312, -0.39005592, -0.0631946, 0.37963012, -0.64611685, 0.1349397, + -0.54113752, 0.28533003, 0.27667275, -0.16442731, -0.181494, 0.30564839, 0.58744538, + 0.32015014, 0.24969585, -0.27367792, -0.53308117, 0.41236052, 0.26136363, -0.01489828, + 0.57652152, -0.38506854, 0.119615, 0.0437076, 0.04779706, 0.57887721, 0.23126155, + 0.05695833, -0.68200272, 0.02063358, -0.10267162, 0.8062973, -0.38149622, -0.40134856, + -0.03353126, 0.38991132, -0.3478111, 0.03661491, 0.25783631, 0.62772679, -0.1961118, + 0.76423508, -0.36241418, -0.20994355, -0.12368261, -0.9406727, 0.02340185, -0.08793129, + -0.02471633, -0.58163726, -0.02211772, -0.42014724, 0.77525634, 0.504951, -0.20537445, + -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194, 0.49579027, + 0.46527559}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + + migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + + mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(64); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(conv2d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, + -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, + 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, + 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, + -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, + 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, + 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, + 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, + 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, + 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, + -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, + -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector c = { + 2.82721668e-02, 6.44195229e-02, 1.53499246e-02, 1.72468081e-01, -6.33238107e-02, + 9.49496776e-02, 1.40258059e-01, -7.92879611e-02, -1.29301161e-01, 3.11307609e-03, + -1.90624535e-01, 1.13238767e-01, -2.80647576e-02, 3.12882811e-02, -3.52091640e-02, + 3.33581865e-02, 6.43158704e-02, 7.40238279e-02, -1.00106120e-01, -9.56912562e-02, + 1.44342467e-01, 9.40258950e-02, 6.36333972e-02, 1.66158378e-03, -8.91554281e-02, + 2.58734226e-02, 1.70919895e-02, 1.78214177e-01, 8.84564668e-02, 8.98126513e-02, + -1.63809001e-01, 1.37802169e-01, 1.66439757e-01, -1.45631135e-02, 1.88469887e-04, + 4.76950556e-02, -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01, + 1.76608220e-01, -1.50728196e-01, 1.99946314e-02, -5.88052124e-02, 1.31612435e-01, + 1.61106288e-02, -1.35080189e-01, 1.49512306e-01, 3.86456847e-02, 1.29330024e-01, + -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02}; + + std::vector s = {0.27039781, + 0.19105849, + -0.06339942, + -0.65087199, + 0.40867025, + 0.05063812, + -0.14907975, + 0.49018705, + -0.49197209, + 0.33236548, + -0.39374301, + 0.16012701, + 0.06574871, + 0.71606487, + -0.55201721, + -0.46427044}; + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + + migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + + mm->add_instruction(migraphx::make_op("convolution"), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(conv3d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712, + -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606, + 0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259, + 0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051, + -0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, + 0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297, + 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946, + 0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338, + 0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022, + 0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, + -2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027, + -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector c = { + 2.82721668e-02, 6.44195229e-02, 1.53499246e-02, 1.72468081e-01, -6.33238107e-02, + 9.49496776e-02, 1.40258059e-01, -7.92879611e-02, -1.29301161e-01, 3.11307609e-03, + -1.90624535e-01, 1.13238767e-01, -2.80647576e-02, 3.12882811e-02, -3.52091640e-02, + 3.33581865e-02, 6.43158704e-02, 7.40238279e-02, -1.00106120e-01, -9.56912562e-02, + 1.44342467e-01, 9.40258950e-02, 6.36333972e-02, 1.66158378e-03, -8.91554281e-02, + 2.58734226e-02, 1.70919895e-02, 1.78214177e-01, 8.84564668e-02, 8.98126513e-02, + -1.63809001e-01, 1.37802169e-01, 1.66439757e-01, -1.45631135e-02, 1.88469887e-04, + 4.76950556e-02, -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01, + 1.76608220e-01, -1.50728196e-01, 1.99946314e-02, -5.88052124e-02, 1.31612435e-01, + 1.61106288e-02, -1.35080189e-01, 1.49512306e-01, 3.86456847e-02, 1.29330024e-01, + -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02}; + + std::vector s = {0.27039781, + 0.19105849, + -0.06339942, + -0.65087199, + 0.40867025, + 0.05063812, + -0.14907975, + 0.49018705, + -0.49197209, + 0.33236548, + -0.39374301, + 0.16012701, + 0.06574871, + 0.71606487, + -0.55201721, + -0.46427044}; + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4, 1}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + + migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 3, 3, 1}}; + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + + mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + al, + cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(cos_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data{-1, 0, 1}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("cos"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return cosf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(cosh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector data = {-1.0, 2.0, -3.0, 4.0}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("cosh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return coshf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(deconv_1d_test) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3}}; + std::vector x_data{0, 0.5, 1}; + std::vector w_data{0.5, 0.5, 0.5}; + + std::vector gold{0, 0.25, 0.75, 0.75, 0.5}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::literal{s, x_data}); + auto w = mm->add_literal(migraphx::literal{s, w_data}); + + mm->add_instruction( + migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + x, + w); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(deconv_3d_test) +{ + migraphx::shape s_1{migraphx::shape::float_type, {1, 1, 1, 2, 3}}; + migraphx::shape s_2{migraphx::shape::float_type, {1, 1, 3, 2, 3}}; + std::vector x_data{0.8471, -0.4195, -2.2749, 1.2491, 0.1722, 0.3246}; + std::vector w_data{0.6478, + -0.1985, + 0.0633, + -0.3479, + 2.7056, + -0.1440, + -1.1229, + -0.7507, + -1.3151, + 0.8884, + -0.1859, + -0.3407, + -1.1544, + -1.5893, + 1.6265, + -1.4624, + 0.3812, + -1.5378}; + + std::vector gold{0.5488, -0.4399, -1.3369, 0.4251, -0.1439, 0.5145, 2.3015, -0.2104, + -6.1482, 0.3482, -0.4346, 3.3197, 0.1731, 0.8533, -0.0467, -0.9512, + -0.1649, 1.7553, 2.2594, 2.9917, -0.6500, -1.6612, -4.3680, 0.0957, + 0.3482, 1.1097, -0.0792, -0.1692, -0.1190, -0.1106, -0.9779, -0.8621, + 4.6707, 2.9332, -3.7001, -2.6808, -1.2476, 3.2475, -0.4578, 4.0263, + -1.8267, 0.2243, -2.3299, -0.1411, -0.4991}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::literal{s_1, x_data}); + auto w = mm->add_literal(migraphx::literal{s_2, w_data}); + + mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + x, + w); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(deconv_test) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 3}}; + std::vector x_data{0, 1, 2, 3, 4, 5, 6, 7, 8}; + std::vector w_data{1, 1, 1, 1, 1, 1, 1, 1, 1}; + + std::vector gold{0, 1, 3, 3, 2, 3, 8, 15, 12, 7, 9, 21, 36, + 27, 15, 9, 20, 33, 24, 13, 6, 13, 21, 15, 8}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::literal{s, x_data}); + auto w = mm->add_literal(migraphx::literal{s, w_data}); + + mm->add_instruction(migraphx::make_op("deconvolution"), x, w); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(dequantizelinear) +{ + { /*uint8*/ + migraphx::shape xs{migraphx::shape::uint8_type, {1, 3, 3}}; + std::vector xv = {0, 1, 2, 5, 10, 50, 100, 150, 250}; + migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; + std::vector sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; + migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}}; + std::vector zv = {0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + auto z = mm->add_literal(zs, zv); + mm->add_instruction(migraphx::make_op("dequantizelinear"), x, s, z); + return p; + }; + + migraphx::program p1 = create_program(); + p1.compile(migraphx::ref::target{}); + auto result = p1.eval({}).back(); + std::vector results_vector(9); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0, 2, 4, 10, 20, 100, 200, 300, 500}; + EXPECT(results_vector == gold); + } + + { /*int8*/ + migraphx::shape xs{migraphx::shape::int8_type, {1, 3, 3}}; + std::vector xv = {-128, -100, -50, -1, 0, 1, 50, 100, 127}; + migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; + std::vector sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + mm->add_instruction(migraphx::make_op("dequantizelinear"), x, s); + return p; + }; + + migraphx::program p1 = create_program(); + p1.compile(migraphx::ref::target{}); + auto result = p1.eval({}).back(); + std::vector results_vector(9); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{-256, -200, -100, -2, 0, 2, 100, 200, 254}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(div_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data1 = {-1.0f, 0.5f, 1.0f}; + std::vector data2 = {1.0f, 2.0f, 4.0f}; + auto l1 = mm->add_literal(migraphx::literal{s, data1}); + auto l2 = mm->add_literal(migraphx::literal{s, data2}); + mm->add_instruction(migraphx::make_op("div"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data1.size()); + std::transform(data1.begin(), data1.end(), data2.begin(), gold.begin(), std::divides()); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(elu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + auto l = mm->add_literal(migraphx::literal{s, {-1.0, 2.0, -3.0, 4.0}}); + float alpha = 0.5; + mm->add_instruction(migraphx::make_op("elu", {{"alpha", alpha}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(equal_brcst_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = + mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); + auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1); + auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + eq); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {true, false, false, false, true, false, true, false, false}; + EXPECT(results_vector == gold); +} + +TEST_CASE(equal_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + auto l0 = + mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); + auto l1 = + mm->add_literal(migraphx::literal{s, {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}}); + auto eq = mm->add_instruction(migraphx::make_op("equal"), l0, l1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + eq); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {true, false, false, false, true, false, true, false, false}; + EXPECT(results_vector == gold); +} + +TEST_CASE(erf_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4}}; + std::vector data = {0.73785057, 1.58165966, -0.43597795, -0.01677432}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("erf"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(exp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data{-1, 0, 1}; + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("exp"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(floor_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + std::vector data = {1.1, 1.5, 0.6, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("floor"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(fp16_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {1}}; + migraphx::half a{1.5}; + migraphx::half b{2.5}; + migraphx::half c{4.0}; + auto l0 = mm->add_literal(migraphx::literal{s, {a}}); + auto l1 = mm->add_literal(migraphx::literal{s, {b}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(1); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{c}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(fp32_fp16_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data(2 * 3); + std::iota(data.begin(), data.end(), 1.0f); + auto l1 = mm->add_literal(migraphx::literal(s, data)); + auto l2 = mm->add_literal(migraphx::literal(s, data)); + mm->add_instruction(migraphx::make_op("add"), l1, l2); + return p; + }; + + auto test_case = [&](std::vector&& op_names) { + std::vector gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0}; + auto p = create_program(); + migraphx::quantize_fp16(p, op_names); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res; + result.visit([&](auto output) { res.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res, gold_res)); + }; + + test_case({"all"}); + test_case({"add"}); +} + +TEST_CASE(gather_non_std_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data = {0.5f, 3.5f, 6.5f, 1.5f, 4.5f, 7.5f, 2.5f, 2.5f, 8.5f}; + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto d = mm->add_literal(migraphx::literal{s, data}); + migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; + std::vector indices{-3, -3, -1, -1}; + auto ind = mm->add_literal(migraphx::literal{s_indices, indices}); + auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); + auto tind = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind); + + mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), td, tind); + auto result = p.eval({}).back(); + std::vector golden = { + 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f, 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; + std::vector res_data; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } +} + +TEST_CASE(gather_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; + std::vector indices{0, 2}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = 0; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data(4 * 5); + std::vector golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; + std::vector indices{-3, -1}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = 0; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data(4 * 5); + std::vector golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; + std::vector indices{0, 2}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = 1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data(4 * 5); + std::vector golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; + std::vector indices{0, 2}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data(4 * 5); + std::vector golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + // scalar index + migraphx::shape s_indices{migraphx::shape::int32_type}; + std::vector indices{0}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector golden = {0.5f, 3.5f, 6.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3 * 3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + // scalar index + migraphx::shape s_indices{migraphx::shape::int32_type}; + std::vector indices{-3}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector golden = {0.5f, 3.5f, 6.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + std::vector data(3); + std::iota(data.begin(), data.end(), 0.5); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto a0 = mm->add_literal(migraphx::literal{s, data}); + // scalar index + migraphx::shape s_indices{migraphx::shape::int32_type}; + std::vector indices{0}; + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector golden = {0.5f}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(res_data, golden)); + } +} + +TEST_CASE(gathernd_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 1, 1}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 3}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{1, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{2, 3, 0, 1}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}}; + + std::vector data_vec(2 * 3 * 1); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{1, 0, 0, 1}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2, 2}}; + + std::vector data_vec(2 * 3 * 2 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 0, 1, 0, 0, 0, 1}; + const int batch_dims = 1; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + + std::vector data_vec(2 * 3 * 1 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0}; + const int batch_dims = 2; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{0, 4, 8, 11, 13, 15}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + // k > r - batch_dims + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 3}}; + + std::vector data_vec(2 * 3 * 1 * 3); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec(2 * 3 * 3, 0); + const int batch_dims = 2; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + EXPECT(test::throws([&] { + mm->add_instruction( + migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices); + })); + } +} + +TEST_CASE(gathernd_negative_index_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{-1, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector res_data{}; + std::vector gold{2, 3, 0, 1}; + result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); + + EXPECT(migraphx::verify_range(res_data, gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + + std::vector data_vec(2 * 2); + std::iota(data_vec.begin(), data_vec.end(), 0); + std::vector indices_vec{-3, 0}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, indices_vec}); + + mm->add_instruction(migraphx::make_op("gathernd"), data, indices); + p.compile(migraphx::ref::target{}); + + EXPECT(test::throws([&] { p.eval({}); })); + } +} + +TEST_CASE(globalavgpool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + auto lens = s.lens(); + op.lengths = {lens[2], lens[3]}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.25, 0.575, 0.375}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(globallppool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm}; + auto lens = s.lens(); + op.lengths = {lens[2], lens[3]}; + op.lp_order = 2; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(globalmaxpool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + auto lens = s.lens(); + op.lengths = {lens[2], lens[3]}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.4, 0.9, 0.7}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(greater_brcst_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = + mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); + auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1); + auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + gr); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {false, true, false, true, false, true, false, true, false}; + EXPECT(results_vector == gold); +} + +TEST_CASE(greater_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + auto l0 = + mm->add_literal(migraphx::literal{s, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); + auto l1 = + mm->add_literal(migraphx::literal{s, {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}}); + auto gr = mm->add_instruction(migraphx::make_op("greater"), l0, l1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + gr); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {false, false, true, true, false, true, false, false, true}; + EXPECT(results_vector == gold); +} + +TEST_CASE(identity_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector data{1, 2, 3, 4}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("identity"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(std::equal(data.begin(), data.end(), results_vector.begin())); +} + +TEST_CASE(if_literal_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_return({l1}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + else_mod->add_return({l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + return p; + }; + + auto run_prog = [&](bool cond) { + auto p = create_program(); + p.compile(migraphx::ref::target()); + std::vector c_data = {static_cast(cond)}; + migraphx::shape cs{migraphx::shape::bool_type}; + migraphx::parameter_map m; + m["cond"] = migraphx::argument(cs, c_data.data()); + + auto res = p.eval(m).back(); + std::vector ret; + res.visit([&](auto v) { ret.assign(v.begin(), v.end()); }); + + return ret; + }; + + // then branch + { + std::vector gold_ret = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + auto ret = run_prog(true); + EXPECT(gold_ret == ret); + } + + // else branch + { + std::vector gold_ret = {5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + auto ret = run_prog(false); + EXPECT(gold_ret == ret); + } +} + +TEST_CASE(if_param_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", ds); + auto y = mm->add_parameter("y", ds); + std::vector data2 = {-0.258047, 0.360394, 0.536804, -0.577762, 1.0217, 1.02442}; + auto l2 = mm->add_literal(migraphx::literal(ds, data2)); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, l2); + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {0.384804, -1.77948, -0.453775, 0.477438, -1.06333, -1.12893}; + auto l1 = then_mod->add_literal(migraphx::literal(ds, data1)); + auto tx = then_mod->add_parameter("x", ds); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), tx, l1); + then_mod->add_return({a1}); + + auto* else_mod = p.create_module("If_0_else"); + auto ey = else_mod->add_parameter("y", ds); + auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), ey, sum); + else_mod->add_return({a2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond, x, y}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + return p; + }; + + auto run_prog = [&](bool cond) { + auto p = create_program(); + p.compile(migraphx::ref::target()); + std::vector c_data = {static_cast(cond)}; + migraphx::shape cs{migraphx::shape::bool_type}; + migraphx::parameter_map m; + m["cond"] = migraphx::argument(cs, c_data.data()); + migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; + std::vector data_x(ds.elements(), 1); + m["x"] = migraphx::argument(ds, data_x.data()); + std::vector data_y(ds.elements(), 2); + m["y"] = migraphx::argument(ds, data_y.data()); + + auto res = p.eval(m).back(); + std::vector ret; + res.visit([&](auto v) { ret.assign(v.begin(), v.end()); }); + return ret; + }; + + // then branch + { + std::vector gold_ret = { + 1.384804, -0.77947998, 0.54622501, 1.477438, -0.063330054, -0.12892997}; + auto ret = run_prog(true); + EXPECT(gold_ret == ret); + } + + // else branch + { + std::vector gold_ret = { + 1.483906, 2.720788, 3.0736079, 0.84447598, 4.0433998, 4.04884}; + auto ret = run_prog(false); + EXPECT(gold_ret == ret); + } +} + +TEST_CASE(if_pl_test) +{ + auto create_program = [] { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto cond = mm->add_parameter("cond", cond_s); + auto x = mm->add_parameter("x", s); + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_return({l1, x}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2); + else_mod->add_return({s2, l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto outline = mm->add_outline(s); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({outline, r}); + + return p; + }; + + auto run_prog = [&](bool cond) { + auto p = create_program(); + p.compile(migraphx::ref::target()); + std::vector c_data = {static_cast(cond)}; + migraphx::shape cs{migraphx::shape::bool_type}; + migraphx::parameter_map m; + m["cond"] = migraphx::argument(cs, c_data.data()); + migraphx::shape ds{migraphx::shape::float_type, {5}}; + std::vector data(ds.elements(), 1); + m["x"] = migraphx::argument(ds, data.data()); + + auto res = p.eval(m).back(); + std::vector ret; + res.visit([&](auto v) { ret.assign(v.begin(), v.end()); }); + + return ret; + }; + + // then branch + { + std::vector gold_ret = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + auto ret = run_prog(true); + EXPECT(gold_ret == ret); + } + + // else branch + { + std::vector gold_ret = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f}; + auto ret = run_prog(false); + EXPECT(gold_ret == ret); + } +} + +TEST_CASE(isnan_test) +{ + // float test + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto nan_val = std::numeric_limits::quiet_NaN(); + std::vector data0 = {1.2, 5.2, nan_val, nan_val, 0., 100.}; + auto l1 = mm->add_literal(migraphx::literal{s, data0}); + mm->add_instruction(migraphx::make_op("isnan"), l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector correct = {0, 0, 1, 1, 0, 0}; + EXPECT(migraphx::verify_range(results_vector, correct)); + } + + // half test + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {2, 3}}; + auto nan_val = std::numeric_limits::quiet_NaN(); + migraphx::half a{1.2}; + migraphx::half b{5.2}; + std::vector data0 = {a, b, nan_val, nan_val, b, a}; + auto l1 = mm->add_literal(migraphx::literal{s, data0}); + mm->add_instruction(migraphx::make_op("isnan"), l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector correct = {0, 0, 1, 1, 0, 0}; + EXPECT(migraphx::verify_range(results_vector, correct)); + } +} + +TEST_CASE(im2col_3x3_no_pad_identity_test) +{ + std::size_t f[2] = {3, 3}; + std::size_t size[2] = {3, 3}; + std::vector padding{0, 0}; + std::vector stride{1, 1}; + std::vector dilation{1, 1}; + std::size_t channels = 1; + + std::vector weights(channels * f[0] * f[1]); + std::vector input(channels * size[0] * size[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_image = mm->add_literal(migraphx::literal{s_image, input}); + auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); + mm->add_instruction( + migraphx::make_op("im2col", + {{"padding", padding}, {"stride", stride}, {"dilation", dilation}}), + l_image, + l_weights); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; + std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; + std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, input)); +} + +TEST_CASE(im2col_3x3_no_pad_test) +{ + std::size_t f[2] = {3, 3}; + std::size_t size[2] = {4, 4}; + std::vector padding{0, 0}; + std::vector stride{1, 1}; + std::vector dilation{1, 1}; + std::size_t channels = 1; + + std::vector weights(channels * f[0] * f[1]); + std::vector input(channels * size[0] * size[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_image = mm->add_literal(migraphx::literal{s_image, input}); + auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); + mm->add_instruction( + migraphx::make_op("im2col", + {{"padding", padding}, {"stride", stride}, {"dilation", dilation}}), + l_image, + l_weights); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector correct = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11, + 4, 5, 6, 8, 9, 10, 12, 13, 14, 5, 6, 7, 9, 10, 11, 13, 14, 15}; + + std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; + std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; + std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, correct)); +} + +TEST_CASE(im2col_3x3_stride_2_no_pad_test) +{ + std::size_t f[2] = {3, 3}; + std::size_t size[2] = {6, 6}; + std::vector padding{0, 0}; + std::vector stride{2, 2}; + std::vector dilation{1, 1}; + std::size_t channels = 1; + + std::vector weights(channels * f[0] * f[1]); + std::vector input(channels * size[0] * size[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_image = mm->add_literal(migraphx::literal{s_image, input}); + auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); + mm->add_instruction( + migraphx::make_op("im2col", + {{"padding", padding}, {"stride", stride}, {"dilation", dilation}}), + l_image, + l_weights); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector correct = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4, + 8, 9, 10, 14, 15, 16, 12, 13, 14, 18, 19, 20, + 24, 25, 26, 14, 15, 16, 20, 21, 22, 26, 27, 28}; + + std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; + std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; + std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, correct)); +} + +TEST_CASE(im2col_3x3_with_channels_identity_test) +{ + std::size_t f[2] = {3, 3}; + std::size_t size[2] = {3, 3}; + std::vector padding{0, 0}; + std::vector stride{1, 1}; + std::vector dilation{1, 1}; + std::size_t channels = 2; + + std::vector weights(channels * f[0] * f[1]); + std::vector input(channels * size[0] * size[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_image = mm->add_literal(migraphx::literal{s_image, input}); + auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); + mm->add_instruction( + migraphx::make_op("im2col", + {{"padding", padding}, {"stride", stride}, {"dilation", dilation}}), + l_image, + l_weights); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; + std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; + std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, input)); +} + +TEST_CASE(im2col_3x3_with_padding_test) +{ + std::size_t f[2] = {3, 3}; + std::size_t size[2] = {2, 2}; + std::vector padding{1, 1}; + std::vector stride{1, 1}; + std::vector dilation{1, 1}; + std::size_t channels = 1; + + std::vector weights(channels * f[0] * f[1]); + std::vector input(channels * size[0] * size[1]); + std::iota(input.begin(), input.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; + migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; + auto l_image = mm->add_literal(migraphx::literal{s_image, input}); + auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); + mm->add_instruction( + migraphx::make_op("im2col", + {{"padding", padding}, {"stride", stride}, {"dilation", dilation}}), + l_image, + l_weights); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector correct = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, + 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0}; + + std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; + std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; + std::vector results_vector(channels * f[0] * f[1] * col_height * col_width); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, correct)); +} + +TEST_CASE(imagescaler_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 3, 2, 2}}; + auto img = mm->add_literal(migraphx::literal{s, + {0.2, + 0.3, + 0.5, + 0.4, + + 0.7, + 0.8, + 0.1, + 0.9, + + 0.15, + 0.25, + 0.35, + 0.45}}); + auto scale_val = mm->add_literal(2.f); + auto scaled_tensor = mm->add_instruction( + migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val); + auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), img, scaled_tensor); + auto bias_vals = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); + auto bias_bcast = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals); + mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(12); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.41, + 0.61, + 1.01, + 0.81, + + 1.42, + 1.62, + 0.22, + 1.82, + + 0.33, + 0.53, + 0.73, + 0.93}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(leaky_relu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l = mm->add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}}); + mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.01}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-0.01f, 0.f, 1.f}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(less_brcst_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = + mm->add_literal(migraphx::literal{s0, {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}}); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_literal(migraphx::literal{s1, {1.1, -1.5, 0.0}}); + auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), l1); + auto le = mm->add_instruction(migraphx::make_op("less"), l0, bl1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + le); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {false, false, true, false, false, false, false, false, true}; + EXPECT(results_vector == gold); +} + +TEST_CASE(less_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + std::vector data1 = {1.1, 1.5, 0.1, -1.1, -1.5, -0.6, 0.0, 2.0, -2.0}; + std::vector data2 = {1.1, 1.6, -0.1, -1.2, -1.5, -0.7, 0.0, 2.3, -2.1}; + auto l0 = mm->add_literal(migraphx::literal{s, data1}); + auto l1 = mm->add_literal(migraphx::literal{s, data2}); + auto le = mm->add_instruction(migraphx::make_op("less"), l0, l1); + auto r = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}), + le); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data1.size()); + std::transform( + data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> bool { + return n1 < n2; + }); + EXPECT(results_vector == gold); +} + +TEST_CASE(log_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data = {1, 2, 3}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("log"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(logical_and_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {4}}; + std::vector data1{true, false, true, false}; + std::vector data2{true, true, false, false}; + auto l1 = mm->add_literal(migraphx::literal{s, data1}); + auto l2 = mm->add_literal(migraphx::literal{s, data2}); + mm->add_instruction(migraphx::make_op("logical_and"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data2.size()); + std::transform( + data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { + return n1 and n2; + }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(logical_or_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {4}}; + std::vector data1{true, false, true, false}; + std::vector data2{true, true, false, false}; + auto l1 = mm->add_literal(migraphx::literal{s, data1}); + auto l2 = mm->add_literal(migraphx::literal{s, data2}); + mm->add_instruction(migraphx::make_op("logical_or"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data1.size()); + std::transform( + data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { + return n1 or n2; + }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(logical_xor_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {4}}; + std::vector data1{true, false, true, false}; + std::vector data2{true, true, false, false}; + auto l1 = mm->add_literal(migraphx::literal{s, data1}); + auto l2 = mm->add_literal(migraphx::literal{s, data2}); + mm->add_instruction(migraphx::make_op("logical_xor"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {false, true, true, false}; + std::transform( + data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { + return n1 ^ n2; + }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(logsoftmax_test_axis_0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, + -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, + -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, + -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, + 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, + -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, + -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, + -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; + + std::vector s = { + -0.135261, -2.843968, -0.659995, -0.488413, -1.051857, -2.812936, -0.250956, -0.353985, + -1.155980, -0.603651, -0.211969, -0.175371, -1.336552, -3.885010, -1.871544, -0.837083, + -0.887745, -0.433338, -1.158864, -4.911197, -1.147972, -0.666711, -0.996874, -0.981418, + -0.851145, -0.853988, -0.858112, -2.067420, -0.059956, -0.727436, -0.950881, -0.429689, + -0.061906, -1.505332, -1.210277, -0.377970, -0.791448, -1.655428, -1.827253, -0.304828, + -0.020762, -0.167101, -0.567346, -0.530319, -1.045094, -0.376648, -0.007391, -0.381670, + -0.720302, -0.460499, -0.469651, -0.556740, -0.554628, -0.551582}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + int axis = 0; + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(logsoftmax_test_axis_1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, + -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, + -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, + -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, + 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, + -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, + -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, + -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; + + std::vector s = { + -0.550468, -2.132973, -1.549746, -0.650533, -1.051529, -2.248570, -0.141017, -2.028357, + -1.947730, -1.511324, -0.166597, -0.379726, -1.965689, -1.172109, -1.475721, -2.700831, + -1.537011, -0.658754, -1.596017, -3.353137, -2.266743, -1.084197, -1.076214, -0.406712, + -2.743019, -0.425526, -1.079083, -2.139486, -1.270584, -1.024088, -1.154231, -3.201762, + -0.888957, -0.532855, -3.103583, -1.221339, -1.355980, -3.531678, -1.438510, -0.975194, + -0.080261, -1.162697, -1.568557, -1.398519, -1.322129, -0.470660, -0.370953, -0.907343, + -1.179017, -3.312239, -1.286363, -1.586076, -0.345100, -0.824173}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + int axis = 1; + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(logsoftmax_test_axis_2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, + -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, + -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, + -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, + 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, + -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, + -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, + -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; + + std::vector s = { + -0.495957, -1.031212, -0.245531, -2.013726, -1.339125, -2.465619, -1.356652, -0.964037, + -2.019250, -0.214522, -0.289569, -0.234392, -2.086591, -2.684439, -2.851651, -2.674176, + -1.697424, -1.889155, -0.401029, -3.064586, -1.173030, -1.306912, -2.177020, -0.834262, + -2.818177, -0.174415, -1.361105, -1.024571, -0.106766, -1.167645, -1.072650, -2.576522, + -0.569261, -1.207483, -3.679894, -2.095913, -0.504264, -3.039291, -1.290559, -1.156812, + -0.126453, -0.551493, -2.506384, -2.646261, -1.905195, -0.206994, -0.191369, -0.959754, + -1.948685, -3.671233, -0.875521, -3.111952, -1.905644, -1.6076011}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + int axis = 2; + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(logsoftmax_test_axis_3) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + 1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336, + -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592, + -0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611, + -0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996, + 1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, + -0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234, + -1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535, + -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; + + std::vector s = { + -0.336904, -3.475825, -1.366154, -0.279366, -2.208430, -2.010934, -0.225511, -2.436562, + -2.167785, -1.572415, -1.784104, -0.470789, -1.067459, -1.801948, -0.711023, -2.307197, + -1.467087, -0.400681, -0.426983, -3.740518, -1.127681, -1.078919, -2.599005, -0.534965, + -2.561400, -0.567617, -1.033025, -2.097713, -0.520463, -1.262245, -1.763230, -2.607658, + -0.281299, -0.814243, -2.627210, -0.724131, -0.655704, -2.123055, -1.018163, -2.480634, + -0.382599, -1.451479, -1.843102, -0.915303, -0.818078, -1.316929, -0.508875, -2.033541, + -1.487672, -2.417791, -0.378360, -2.568531, -0.569794, -1.028032}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + int axis = 3; + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(lppool_test) +{ + // L1 norm test + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + op.lp_order = 1; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // L2 norm test + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::lpnorm}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + op.lp_order = 2; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.36055512754639896, + 0.447213595499958, + 0.4123105625617661, + 0.9433981132056605, + 1.0295630140987, + 0.9055385138137417, + 0.7071067811865475, + 0.7071067811865475, + 0.6082762530298219}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(lrn_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}}; + auto l = mm->add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}}); + mm->add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1}, {"size", 5}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(5); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(max_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 4, 3}}); + auto l1 = mm->add_literal(migraphx::literal{s, {2, 8, 6}}); + auto l2 = mm->add_literal(migraphx::literal{s, {7, 5, 9}}); + auto curr_max = mm->add_instruction(migraphx::make_op("max"), l0, l1); + mm->add_instruction(migraphx::make_op("max"), curr_max, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{7, 8, 9}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(maxpool_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + -2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806, + -0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688, + 0.36873096, 1.18358743, -0.34640595, 1.22098756, 0.01946825, -0.20238149, 0.43348005, + -0.67991608, -0.83041084, 0.93537551, 0.70241445, -0.5654031, -1.30899191, -0.26735824, + -0.52444768, 1.99097753, 1.86504853, -0.26506025, 0.26236168, 0.43763575, 0.95300823, + -1.02733946, -0.74655169, -0.5374338, -0.28901565, -0.59789604, 0.5310151, 0.99125904, + 0.40609556, -1.57175648, 0.22031412, 1.45862222, 0.53217483, 1.39087725, 1.00170159, + -0.87175864, -1.7204628, -1.72008383, -0.38656762, -0.01443311, 1.46645272, -1.39995027, + 0.22505587, -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211, 1.18943918, + -0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603, 0.54316711, + 0.40899998, -0.27831686, -1.11900508, -0.0881724, 0.35483059, 2.36277103, -0.04765317, + -0.36865309, 0.73814237, 1.47151589, 1.36546791, -0.32649881, -1.0517807, 2.24768877, + 0.68883753, 0.58646208, -0.91017133, -0.50462508, -0.4013325, -0.72348958, -0.47368807, + 0.35285577, -1.01817429, -0.5152272, 0.60321307, 0.43521205, -0.23733577, 0.66427642, + 0.82949388, 0.82443929, 0.71550399, 0.34561086, 0.68570769, -0.40718508, -1.20350206, + 0.15793853, -2.31013632, -0.07934658, -0.09348056, 0.36576006, 2.46601582, 0.11090943, + 0.9144392, 0.56759721, -0.22112127, -0.21955389, 0.72474903, -1.28448462, 1.53285873, + 0.37437943, 0.31409341, 1.95433736, 0.91620457, 0.86205518, 1.24365854, 0.19248386, + 0.22526583, 0.13462132, -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345, + 1.31293464, -1.86041689, 1.06763375, -0.26541466, 1.4545635, 1.11430049, -0.66491818, + 0.87101674, 0.67768967, -1.02062869, -1.05031872, -2.2764678, -2.0200038, 0.37592548, + -0.26701379, -0.83388507, 0.19403623, 1.00968623, 0.11020003, 1.16736257, -1.1160326, + 0.47346735, 0.6126079, -0.19135755, 1.33624589, -0.29802522, -0.57873946, -1.06555879, + -0.20686582, 1.36892557, -0.19937795, 0.8649236, -1.40126073, 1.53441942, 0.34682792, + -1.31724346, -1.32898355, 2.40126371, 0.07845283, 1.35732043, -0.63678312, 0.39429256, + -1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341, + 1.20071781, -1.64647579, -0.7133292, 0.88494766, 0.52119428, -2.77387547, 2.07681108, + -0.90133125, 0.2847338, 0.6174528, -0.20616426, -0.64263535, -1.08496261, 0.54275119, + -0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746, + -0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223, + -0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682}; + std::vector c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753, + 1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311, + 1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399, + 1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635, + 1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942, + 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0}}, + {"stride", {2, 2}}, + {"lengths", {3, 2}}}), + al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(36); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, c)); +} + +TEST_CASE(maxpool_test_1D_3D) +{ + // 1D case 1, input is 3D + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // 1D case 2, input is 3D + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {2}; + + std::vector data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227, -0.6186, -0.9618, + 0.6022, -0.1912, 1.1925, 0.5493, 0.1692, -0.8039, -1.0281, + 0.9907, 0.477, 1.5001, -1.1603, -1.361, 1.2556}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // 1D case 2, input is 3D, ceil mode + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {2}; + op.ceil_mode = true; + + std::vector data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227, -0.6186, -0.9618, + 0.6022, -0.1912, 1.1925, 0.5493, 0.1692, -0.8039, -1.0281, + 0.9907, 0.477, 1.5001, -1.1603, -1.361, 1.2556}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.4975, + -0.0405, + -0.1227, + -0.6186, + 0.6022, + 1.1925, + 0.5493, + -0.8039, + 0.9907, + 1.5001, + -1.1603, + 1.2556}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // 3D, input is 5D + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2, 2, 2}; + op.padding = {0, 0, 0}; + op.stride = {2, 2, 2}; + + std::vector data{ + -2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, + 0.3405, -0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, + -0.042, -0.7146, 0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, + -0.5691, 0.907, -0.0569, -1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, + 1.3655, -2.1519, 0.6068, -0.2001, -0.4702, 0.3864, 1.7083, 0.9096, 0.4286, + -1.8866, 0.7034, 0.0293, 1.4587, 0.7672, -2.8614, 0.8124, -0.053, 1.0449, + 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337, -0.4644, 0.1938, 0.2889, + 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756, 0.0677, -0.0223, + -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777, -0.9427, + 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229, + -0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, + 1.1364, 0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.5064, 1.3655, 0.9035, 2.6859}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(min_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 4, 3}}); + auto l1 = mm->add_literal(migraphx::literal{s, {2, 8, 6}}); + auto l2 = mm->add_literal(migraphx::literal{s, {7, 5, 9}}); + auto curr_min = mm->add_instruction(migraphx::make_op("min"), l0, l1); + mm->add_instruction(migraphx::make_op("min"), curr_min, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 4, 3}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(mul_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data1{-1, 0, 1}; + std::vector data2{1, 2, 3}; + auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); + mm->add_instruction(migraphx::make_op("mul"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data1.size()); + std::transform( + data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float { + return n1 * n2; + }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(multinomial_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + size_t sample_size = 100000; + float seed = 0.0f; + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector rand_samples(sample_size); + std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); + migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}}; + auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); + + migraphx::shape s{migraphx::shape::float_type, {1, 5}}; + std::vector dist{15, 25, 15, 25, 20}; + std::vector data(5); + std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); }); + auto input = mm->add_literal(migraphx::literal(s, data)); + + auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); + auto mb_maxes = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), maxes); + auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes); + cdf = mm->add_instruction(migraphx::make_op("exp"), cdf); + cdf = mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); + + mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec(sample_size); + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + + std::vector res_dist(5, 0); + for(auto& r : result_vec) + res_dist[r]++; + auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0); + auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0); + std::vector norm(5); + std::vector res_norm(5); + std::transform(dist.begin(), dist.end(), norm.begin(), [&](auto n) { + return static_cast(n) / dist_sum; + }); + std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { + return static_cast(n) / res_dist_sum; + }); + EXPECT(migraphx::verify_range(norm, res_norm, 100000)); +} + +TEST_CASE(neg_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data = {1.0f, 1.3f, -1.2f, 0.0f, -100.f, 200.f}; + auto input = mm->add_literal(migraphx::literal(s, data)); + auto ret = mm->add_instruction(migraphx::make_op("neg"), input); + mm->add_return({ret}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform(gold.begin(), gold.end(), gold.begin(), std::negate()); + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(nms_not_center_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}}; + std::vector boxes_vec = {1.0, 1.0, 0.0, 0.0, 0.0, 0.1, 1.0, 1.1, + 0.0, 0.9, 1.0, -0.1, 0.0, 10.0, 1.0, 11.0, + 1.0, 10.1, 0.0, 11.1, 1.0, 101.0, 0.0, 100.0}; + + migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}}; + std::vector scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3}; + + auto boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec)); + auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec)); + auto max_out_l = mm->add_literal(int64_t{4}); + auto iou_threshold = mm->add_literal(0.5f); + auto score_threshold = mm->add_literal(0.0f); + + auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression"), + boxes_l, + scores_l, + max_out_l, + iou_threshold, + score_threshold); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto output = p.eval({}).back(); + std::vector result; + output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); + std::vector gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(result, gold)); +} + +TEST_CASE(nms_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}}; + std::vector boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}; + + migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}}; + std::vector scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3}; + + auto boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec)); + auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec)); + auto max_out_l = mm->add_literal(int64_t{4}); + auto iou_threshold = mm->add_literal(0.5f); + auto score_threshold = mm->add_literal(0.0f); + + auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), + boxes_l, + scores_l, + max_out_l, + iou_threshold, + score_threshold); + mm->add_return({r}); + + p.compile(migraphx::ref::target{}); + auto output = p.eval({}).back(); + std::vector result; + output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); + std::vector gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(result, gold)); +} + +TEST_CASE(nonzero_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}}; + std::vector data = { + 1.0f, 1.3f, 0.0f, -1.2f, 0.0f, -100.f, 200.f, 0.0f, 0.1f, 0.2f, 0.0f, 0.5f}; + auto input = mm->add_literal(migraphx::literal(s, data)); + auto ret = mm->add_instruction(migraphx::make_op("nonzero"), input); + mm->add_return({ret}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(not_test) +{ + // int32 + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {4}}; + std::vector data{0, 8, 1, -32}; + auto l1 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("not"), l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 0, 0, 0}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + // bool + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {4}}; + std::vector data{false, false, true, true}; + auto l1 = mm->add_literal(migraphx::literal{s, {0, 0, 1, 1}}); + mm->add_instruction(migraphx::make_op("not"), l1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(data.size()); + std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return !n; }); + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(pad_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}}); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(pad_test_highest_half) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}}); + mm->add_instruction( + migraphx::make_op("pad", + {{"pads", {1, 1, 1, 1}}, {"value", std::numeric_limits::max()}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + const float x = std::numeric_limits::max(); + std::vector gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(pad_test_lowest_half) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s, {1, 2, 3, 4}}); + mm->add_instruction( + migraphx::make_op( + "pad", {{"pads", {1, 1, 1, 1}}, {"value", std::numeric_limits::lowest()}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(16); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + const float x = std::numeric_limits::lowest(); + std::vector gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(pointwise_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); + auto* pm = p.create_module("pointwise"); + auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type}); + auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type}); + pm->add_instruction(migraphx::make_op("add"), x1, x2); + mm->add_instruction(migraphx::make_op("pointwise"), {l1, l2}, {pm}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 2, 4}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(pow_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data = {1, 2, 3}; + auto b = mm->add_literal(migraphx::literal{s, data}); + auto e = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("pow"), b, e); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(prefix_scan_sum_1d) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, 3.0, 6.0, 10.0, 15.0, 21.0}; + EXPECT(results_vector == gold); +} + +TEST_CASE(prefix_scan_sum_2d) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, 3.0, 6.0, 1.0, 3.0, 6.0, 1.0, 3.0, 6.0}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(prefix_scan_sum_3d) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 2.0, + 4.0, + 6.0, + 2.0, + 4.0, + 6.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 3.0, + 6.0, + 9.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 3.0, + 6.0, + 9.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 2}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(prefix_scan_sum_exclusive) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {8}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.0, 1.0, 3.0, 6.0, 10.0, 11.0, 13.0, 16.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", true}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.0, + 0.0, + 0.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 0.0, + 0.0, + 0.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(prefix_scan_sum_exclusive_reverse) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{20.0, 18.0, 15.0, 11.0, 6.0, 0.0}; + EXPECT(results_vector == gold); +} + +TEST_CASE(prefix_scan_sum_negative_axis) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", -3}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 2.0, + 4.0, + 6.0, + 2.0, + 4.0, + 6.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", -2}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 3.0, + 6.0, + 9.0, + 1.0, + 2.0, + 3.0, + 2.0, + 4.0, + 6.0, + 3.0, + 6.0, + 9.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 3}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", -1}, {"exclusive", false}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0, + 1.0, + 3.0, + 6.0}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(prefix_scan_sum_reverse) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {8}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", + {{"axis", 0}, {"exclusive", false}, {"reverse", true}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{20.0, 19.0, 17.0, 14.0, 10.0, 9.0, 7.0, 4.0}; + EXPECT(results_vector == gold); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 1, 2, 3, 4}}; + auto l0 = mm->add_literal(input); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", + {{"axis", 0}, {"exclusive", false}, {"reverse", true}}), + l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{2.0, 4.0, 6.0, 8.0, 1.0, 2.0, 3.0, 4.0}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(prelu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_literal(migraphx::literal{s, {-1, 0, 2}}); + auto slope = mm->add_literal(migraphx::literal{s, {2, 1, 2}}); + mm->add_instruction(migraphx::make_op("prelu"), x, slope); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2.0f, 0.0f, 2.0f}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(quant_conv2d_padding_stride_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + std::vector a(2 * 3 * 4 * 4); + std::iota(a.begin(), a.end(), 0); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + std::vector c(2 * 3 * 3 * 3); + std::iota(c.begin(), c.end(), 0); + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + mm->add_instruction( + migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector s = {4521, + 7014, + 7830, + 11952, + 10515, + 16734, + 19737, + 30906, + 13161, + 19542, + 19494, + 28800, + 34707, + 52590, + 54729, + 82746}; + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(quant_conv2d_padding_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + std::vector a(2 * 3 * 4 * 4); + std::iota(a.begin(), a.end(), 0); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + std::vector c(2 * 3 * 3 * 3); + std::iota(c.begin(), c.end(), 0); + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + mm->add_instruction( + migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector s = { + 4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007, + 7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826, + 30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396, + 17739, 19494, 28449, 28800, 18639, 11919, 17319, 17526, 11289, 34707, 51843, 52590, 34893, + 51813, 77346, 78426, 52002, 54729, 81666, 82746, 54846, 36057, 53769, 54462, 36075}; + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(quant_conv2d_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + std::vector a(2 * 3 * 4 * 4); + std::iota(a.begin(), a.end(), 0); + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + std::vector c(2 * 3 * 3 * 3); + std::iota(c.begin(), c.end(), 0); + auto cl = mm->add_literal(migraphx::literal{c_shape, c}); + + mm->add_instruction(migraphx::make_op("quant_convolution"), al, cl); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + std::vector s = {10197, + 10548, + 11601, + 11952, + 25506, + 26586, + 29826, + 30906, + 27045, + 27396, + 28449, + 28800, + 77346, + 78426, + 81666, + 82746}; + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(quantizelinear) +{ + { + migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}}; + std::vector xv = { + -300, 600, 129, -1000, 4, 3, -6, 600, 550, -300, 600, 129, -1000, 4, 3, -6, 600, 550}; + migraphx::shape ss{migraphx::shape::float_type, {2, 3, 3}}; + std::vector sv = {2, 2, 2, 4, 4, 4, 6, 6, 6, 2, 2, 2, 4, 4, 4, 6, 6, 6}; + migraphx::shape zs{migraphx::shape::int8_type, {2, 3, 3}}; + std::vector zv = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + auto z = mm->add_literal(zs, zv); + mm->add_instruction(migraphx::make_op("quantizelinear"), x, s, z); + return p; + }; + + migraphx::program p1 = create_program(); + p1.compile(migraphx::ref::target{}); + auto result = p1.eval({}).back(); + std::vector results_vector(18); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{ + -128, 127, 65, -128, 1, 1, -1, 100, 92, -128, 127, 65, -128, 1, 1, -1, 100, 92}; + EXPECT(results_vector == gold); + } + + { + migraphx::shape xs{migraphx::shape::float_type, {2, 3, 3}}; + std::vector xv = { + -300, 600, 129, -1000, 4, 3, -6, 600, 550, -300, 600, 129, -1000, 4, 3, -6, 600, 550}; + migraphx::shape ss{migraphx::shape::float_type, {2, 3, 3}}; + std::vector sv = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + mm->add_instruction(migraphx::make_op("quantizelinear"), x, s); + return p; + }; + + migraphx::program p1 = create_program(); + p1.compile(migraphx::ref::target{}); + auto result = p1.eval({}).back(); + std::vector results_vector(18); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0, 255, 65, 0, 2, 2, 0, 255, 255, 0, 255, 65, 0, 2, 2, 0, 255, 255}; + EXPECT(results_vector == gold); + } +} + +TEST_CASE(recip_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {3}}; + std::vector data{-0.5f, 0.1f, 0.5f}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("recip"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2.0f, 10.0f, 2.0f}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(reduce_max_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{9, 10, 11, 12}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_max_axis01) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0, 1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{11, 12}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_max_axis02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{10, 12}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_mean_axis02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{5.5, 7.5}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_mean_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{2, 3, 6, 7, 10, 11}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_mean_axis12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{2.5f, 6.5f, 10.5f}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_mean_axis2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_mean_int) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{2, 6, 10}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_min_axis02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 3}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_min_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 2, 5, 6, 9, 10}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_min_axis12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {1, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 5, 9}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_prod_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 3, 2, 3}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_prod", {{"axes", {0}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{6, 18, 12, 18}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{15, 18, 21, 24}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_axis02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{33, 45}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{4, 6, 12, 14, 20, 22}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_axis12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{10, 26, 42}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_axis2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; + auto l0 = mm->add_literal(input); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{3, 7, 11, 15, 19, 23}; + EXPECT(results_vector == gold); +} + +TEST_CASE(relu_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l = mm->add_literal(migraphx::literal{s, {-1.f, 0.f, 1.f}}); + mm->add_instruction(migraphx::make_op("relu"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.f, 0.f, 1.f}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(reshape_test) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; + std::vector data(24); + std::iota(data.begin(), data.end(), -3); + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + std::vector new_shape = {8, 3, 1, 1}; + mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, data)); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + std::vector new_shape = {1, 3, 4, 2}; + mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, data)); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + std::vector new_shape = {1, 3, 4, 2}; + mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, data)); + } +} + +TEST_CASE(reverse_test_axis0) +{ + migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}}; + std::vector data(32); + std::iota(data.begin(), data.end(), 1); + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{in_shape, data}); + std::vector axes = {0}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector target_data = data; + std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16); + EXPECT(migraphx::verify_range(results_vector, target_data)); +} + +TEST_CASE(reverse_test_axis1) +{ + migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}}; + std::vector data(32); + std::iota(data.begin(), data.end(), 1); + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{in_shape, data}); + std::vector axes = {1}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector target_data = data; + std::reverse(target_data.begin(), target_data.begin() + 16); + std::reverse(target_data.end() - 16, target_data.end()); + EXPECT(migraphx::verify_range(results_vector, target_data)); +} + +TEST_CASE(reverse_test_axis10) +{ + migraphx::shape in_shape{migraphx::shape::float_type, {2, 16}}; + std::vector data(32); + std::iota(data.begin(), data.end(), 1); + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{in_shape, data}); + std::vector axes = {1, 0}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector target_data = data; + std::reverse(target_data.begin(), target_data.begin() + 16); + std::reverse(target_data.end() - 16, target_data.end()); + std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16); + EXPECT(migraphx::verify_range(results_vector, target_data)); +} + +TEST_CASE(roialign_out_of_bound_test) +{ + auto create_program = [](const std::string& trans_mode = "half_pixel") { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}}; + std::vector x_vec = { + 0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250, + 0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467, + 0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162, + 0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799, + 0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119, + 0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119, + 0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689, + 0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928, + 0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514, + 0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502}; + + migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}}; + std::vector roi_vec = {0, 0, 9.99, 9.99, 0, 5, 4, 9, 5, 5, 9.9, 9.9}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {3}}; + std::vector ind_vec = {0, 0, 0}; + + auto x = mm->add_literal(migraphx::literal(x_s, x_vec)); + auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec)); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = + mm->add_instruction(migraphx::make_op("roialign", + {{"coordinate_transformation_mode", trans_mode}, + {"spatial_scale", 5.0}, + {"output_height", 1}, + {"output_width", 1}, + {"sampling_ratio", 1}}), + x, + roi, + ind); + mm->add_return({r}); + return p; + }; + + { + auto p = create_program("output_half_pixel"); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.0f, 0.0f, 0.0f}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(roialign_test) +{ + auto create_program = [](const std::string& trans_mode = "half_pixel", + const migraphx::op::pooling_mode pooling_mode = + migraphx::op::pooling_mode::average, + int64_t sampling_ratio = 2) { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}}; + std::vector x_vec = { + 0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250, + 0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467, + 0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162, + 0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799, + 0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119, + 0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119, + 0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689, + 0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928, + 0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514, + 0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502}; + + migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}}; + std::vector roi_vec = {0, 0, 9, 9, 0, 5, 4, 9, 5, 5, 9, 9}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {3}}; + std::vector ind_vec = {0, 0, 0}; + + auto x = mm->add_literal(migraphx::literal(x_s, x_vec)); + auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec)); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = + mm->add_instruction(migraphx::make_op("roialign", + {{"coordinate_transformation_mode", trans_mode}, + {"spatial_scale", 1.0}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", sampling_ratio}, + {"mode", pooling_mode}}), + x, + roi, + ind); + mm->add_return({r}); + return p; + }; + + { + auto p = create_program(); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = { + 0.466421425, 0.446552634, 0.340521216, 0.568848491, 0.606780827, 0.371379346, + 0.429571986, 0.383519977, 0.556241512, 0.351050019, 0.27680251, 0.488286227, + 0.522200167, 0.552770197, 0.417057365, 0.471240699, 0.4844096, 0.690457463, + 0.492039412, 0.877398551, 0.623889625, 0.712461948, 0.628926516, 0.335504025, + 0.349469036, 0.302179992, 0.43046391, 0.469585985, 0.39774403, 0.542259991, + 0.365552008, 0.704923987, 0.516481996, 0.317131996, 0.701444089, 0.291239977, + 0.505897999, 0.647610962, 0.623489916, 0.829879999, 0.591567993, 0.738860011, + 0.704825997, 0.837148011, 0.889315963, 0.622680008, 0.615276039, 0.709713995, + 0.615356028, 0.458524048, 0.238451958, 0.337952018, 0.371693879, 0.609999895, + 0.760059953, 0.376724035, 0.378532052, 0.71468991, 0.924308002, 0.972783983, + 0.574903965, 0.582623959, 0.570936024, 0.761904061, 0.876998067, 0.535508037, + 0.256580025, 0.214098021, 0.279604018, 0.360000014, 0.436488032, 0.350427985, + 0.288755983, 0.366139978, 0.234920025}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + auto p = create_program("output_half_pixel"); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = { + 0.517783, 0.343411, 0.322905, 0.447362, 0.634375, 0.40308, 0.536647, 0.442791, + 0.486144, 0.402313, 0.251194, 0.400154, 0.515524, 0.695369, 0.346537, 0.33504, + 0.460099, 0.588069, 0.343863, 0.684932, 0.49319, 0.714058, 0.821744, 0.471935, + 0.403946, 0.306955, 0.218678, 0.33369, 0.488001, 0.486962, 0.18709, 0.49142, + 0.55611, 0.419167, 0.368608, 0.143278, 0.460835, 0.597125, 0.53096, 0.498207, + 0.278818, 0.438569, 0.6022, 0.700038, 0.752436, 0.577385, 0.702383, 0.725097, + 0.733754, 0.816304, 0.23933, 0.407514, 0.337893, 0.252521, 0.474335, 0.367075, + 0.270168, 0.41051, 0.64189, 0.830777, 0.55564, 0.454295, 0.55645, 0.75015, + 0.929997, 0.66257, 0.561664, 0.481275, 0.495449, 0.666306, 0.663573, 0.372107, + 0.205603, 0.192776, 0.247849}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + auto p = create_program("output_half_pixel", migraphx::op::pooling_mode::max, 0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = { + 0.819145, 0.373103, 0.258302, 0.515419, 0.726104, 0.540536, 0.545512, 0.38511, + 0.376545, 0.274635, 0.22341, 0.184511, 0.230843, 0.404869, 0.29546, 0.540409, + 0.265838, 0.409324, 0.213915, 0.708654, 0.687264, 0.580821, 0.461283, 0.462879, + 0.709632, 0.27873, 0.083619, 0.22428, 0.313992, 0.410508, 0.0929099, 0.415373, + 0.296695, 0.231574, 0.136836, 0.0683, 0.296695, 0.211925, 0.245385, 0.28053, + 0.17091, 0.179879, 0.245385, 0.343539, 0.392742, 0.51273, 0.536193, 0.382995, + 0.422793, 0.761886, 0.0839429, 0.276444, 0.19746, 0.126117, 0.378351, 0.254646, + 0.092148, 0.272825, 0.381955, 0.626599, 0.251325, 0.244475, 0.194875, 0.272825, + 0.44757, 0.351855, 0.342265, 0.244475, 0.274841, 0.553644, 0.607176, 0.202392, + 0.07425, 0.066087, 0.126279}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(round_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {9}}; + auto l = + mm->add_literal(migraphx::literal{s, {1.1, 1.5, 1.6, -1.1, -1.5, -1.6, 0.0, 2.0, -2.0}}); + mm->add_instruction(migraphx::make_op("round"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(rsqrt_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l = mm->add_literal(migraphx::literal{s, {4.0, 16.0, 64.0}}); + mm->add_instruction(migraphx::make_op("rsqrt"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.5, 0.25, 0.125}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +// reduction_mode: "scatter_none", "scatter_add", "scatter_mul" +migraphx::program create_scatter_program(const std::string& reduction_mode, int axis) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + std::vector vd(sd.elements(), 0.0f); + + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector vi = {1, 0, 2, 0, 2, 1}; + + migraphx::shape su{migraphx::shape::float_type, {2, 3}}; + std::vector vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2}; + + auto ld = mm->add_literal(migraphx::literal{sd, vd}); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto lu = mm->add_literal(migraphx::literal{su, vu}); + // scatter_none, formerly the scatter op + auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu); + mm->add_return({r}); + return p; +} + +TEST_CASE(scatter_ax0_test) +{ + // this tests what used to be the only scatter op, now changed to 3 sub-ops + // which have their own test case + { + migraphx::program p = create_scatter_program("scatter_none", 0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(scatter_ax_neg_test) +{ + { + migraphx::program p = create_scatter_program("scatter_none", -2); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(scatter_ax1_test) +{ + { + migraphx::program p = create_scatter_program("scatter_none", 1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +// similar to create_scatter_program but with different tensor values +// reduction_mode: "scatter_none", "scatter_add", "scatter_mul" +migraphx::program create_scatter_program2(const std::string& reduction_mode, int axis) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {1, 5}}; + std::vector vd({1., 2., 3., 4., 5.}); + + migraphx::shape si{migraphx::shape::int32_type, {1, 2}}; + std::vector vi = {1, 3}; + + migraphx::shape su{migraphx::shape::float_type, {1, 2}}; + std::vector vu = {1.1, 2.1}; + + auto ld = mm->add_literal(migraphx::literal{sd, vd}); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto lu = mm->add_literal(migraphx::literal{su, vu}); + auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu); + mm->add_return({r}); + return p; +} +TEST_CASE(scatter_reduction1_test) +{ + { + // Test sub-ops for the three reduction values scatter_none, scatter_add, scatter_mul + migraphx::program p = create_scatter_program2("scatter_none", 1); + + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_none = {1.0, 1.1, 3.0, 2.1, 5.0}; + EXPECT(migraphx::verify_range(results_vector, gold_none)); + } +} + +TEST_CASE(scatter_reduction2_test) +{ + { + migraphx::program p = create_scatter_program2("scatter_mul", 1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0}; + + EXPECT(migraphx::verify_range(results_vector, gold_mul)); + } +} +TEST_CASE(scatter_reduction3_test) +{ + { + migraphx::program p = create_scatter_program2("scatter_add", 1); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_add = {1.0, 3.1, 3.0, 6.1, 5.0}; + + EXPECT(migraphx::verify_range(results_vector, gold_add)); + } +} + +TEST_CASE(scatter_reduction_3x3_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + std::vector vd(sd.elements(), 3.0f); + + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector vi = {1, 0, 2, 0, 2, 1}; + + migraphx::shape su{migraphx::shape::float_type, {2, 3}}; + std::vector vu = {1.0, 1.1, 1.2, 7.0, 7.1, 7.2}; + + auto ld = mm->add_literal(migraphx::literal{sd, vd}); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto lu = mm->add_literal(migraphx::literal{su, vu}); + auto r = mm->add_instruction(migraphx::make_op("scatter_add", {{"axis", 1}}), ld, li, lu); + mm->add_return({r}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0}; + + EXPECT(migraphx::verify_range(results_vector, gold_a2)); + } +} + +// create a test scatter program with a 3x3 tensor; +// su and si are transposed from previous case +migraphx::program create_scatter_program_3x3(const std::string& reduction_mode, int axis) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + std::vector vd(sd.elements(), 3.0f); + + migraphx::shape si{migraphx::shape::int32_type, {3, 2}}; + std::vector vi = {1, 0, 0, 2, 2, 1}; + + migraphx::shape su{migraphx::shape::float_type, {3, 2}}; + std::vector vu = {1.0, 7.0, 1.1, 7.1, 1.2, 7.2}; + + auto ld = mm->add_literal(migraphx::literal{sd, vd}); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto lu = mm->add_literal(migraphx::literal{su, vu}); + auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu); + mm->add_return({r}); + return p; +} + +TEST_CASE(scatter_reduction_3x3_xpose1_test) +{ + // test on vertical (0) axis. su and si are transposed from previous case + { + migraphx::program p = create_scatter_program_3x3("scatter_none", 0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0}; + EXPECT(migraphx::verify_range(results_vector, gold_none2)); + } +} + +TEST_CASE(scatter_reduction_3x3_xpose2_test) +{ + // test on vertical (0) axis. + { + migraphx::program p = create_scatter_program_3x3("scatter_add", 0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0}; + + EXPECT(migraphx::verify_range(results_vector, gold_a3)); + } +} + +TEST_CASE(scatter_reduction_3x3_xpose3_test) +{ + { + migraphx::program p = create_scatter_program_3x3("scatter_mul", 0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0}; + + EXPECT(migraphx::verify_range(results_vector, gold_mul2)); + } +} + +TEST_CASE(scatternd_shapes_test) +{ + { + // broadcasted input + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {4}}; + + std::vector ind_vec{4, 3, 1, 7}; + std::vector upd_vec{9, 10, 11, 12}; + + auto data = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8}}}), + mm->add_literal(migraphx::literal{0.0f})); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0, 11, 0, 10, 9, 0, 0, 12}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // non-standard shape input + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {2, 2}}; + migraphx::shape is{itype, {2, 2}}; + migraphx::shape us{dtype, {2}}; + + std::vector data_vec{1, 2, 3, 4}; + std::vector ind_vec{0, 0, 0, 1}; + std::vector upd_vec{5, 6}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto td = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), data); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), td, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{5, 6, 2, 4}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // non-standard updates shape + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {2, 2, 2}}; + migraphx::shape is{itype, {2, 1, 3}}; + migraphx::shape us{dtype, {1, 2}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{0, 0, 0, 1, 1, 1}; + std::vector upd_vec{9, 10}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto tu = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), updates); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, tu); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{9, 2, 3, 4, 5, 6, 7, 10}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(scatternd_test) +{ + { + // r=1, q=2, k=1 + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {4}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{4, 3, 1, 7}; + std::vector upd_vec{9, 10, 11, 12}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 11, 3, 10, 9, 6, 7, 12}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // r=2, q=2, k=2 + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {2, 2}}; + migraphx::shape is{itype, {2, 2}}; + migraphx::shape us{dtype, {2}}; + + std::vector data_vec{1, 2, 3, 4}; + std::vector ind_vec{0, 0, 0, 1}; + std::vector upd_vec{5, 6}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{5, 6, 3, 4}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // r=3, q=3, k=3 + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {2, 2, 2}}; + migraphx::shape is{itype, {2, 1, 3}}; + migraphx::shape us{dtype, {2, 1}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{0, 0, 0, 1, 1, 1}; + std::vector upd_vec{9, 10}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{9, 2, 3, 4, 5, 6, 7, 10}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // r=3, q=2, k=1 + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {4, 4, 4}}; + migraphx::shape is{itype, {2, 1}}; + migraphx::shape us{dtype, {2, 4, 4}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, + 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8, + 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{0, 2}; + std::vector upd_vec{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 1, 2, 3, 4, 5, 6, + 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // r=5, q=1, k=1 + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {2, 2, 2, 2, 2}}; + migraphx::shape is{itype, {1}}; + migraphx::shape us{dtype, {2, 2, 2, 2}}; + + std::vector data_vec(32, 1); + std::vector ind_vec{1}; + std::vector upd_vec(16, 0); + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold(32, 0); + std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin()); + + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(scatternd_reduction_test) +{ + { + // reduction = add + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {8, 1}}; + migraphx::shape us{dtype, {8}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{4, 3, 1, 7, 4, 3, 1, 7}; + std::vector upd_vec{9, 10, 11, 12, -8, -9, -10, -11}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_add"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 3, 3, 5, 6, 6, 7, 9}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } + + { + // reduction = mul + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {4}}; + + std::vector data_vec{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector ind_vec{4, 3, 1, 7}; + std::vector upd_vec{9, 10, 11, 12}; + + auto data = mm->add_literal(migraphx::literal{ds, data_vec}); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_literal(migraphx::literal{us, upd_vec}); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates); + mm->add_return({scatternd}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{1, 22, 3, 40, 45, 6, 7, 96}; + + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(sigmoid_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + auto l = mm->add_literal(migraphx::literal{s, {-1, 2, -3, 4}}); + mm->add_instruction(migraphx::make_op("sigmoid"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(sign_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto l = mm->add_literal( + migraphx::literal{s, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}}); + mm->add_instruction(migraphx::make_op("sign"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {1.0, 1.0, -1.0, -1.0, 0.0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(sin_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data = {-1, 0, 1}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("sin"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(sinh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector data{-1.0, 2.0, -3.0, 4.0}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("sinh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(slice_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(2 * 2 * 3); + std::iota(data.begin(), data.end(), 0); + migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), l0); + migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; + EXPECT(p.get_output_shapes().back() == s2); + p.compile(migraphx::ref::target{}); + migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; + auto result = p.eval({}).back(); + std::vector gold = {1, 2, 4, 5, 7, 8, 10, 11}; + std::vector results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(result.get_shape() == sresult); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(2 * 2 * 3); + std::iota(data.begin(), data.end(), 0); + migraphx::shape s{migraphx::shape::int32_type, {2, 2, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("slice", + {{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}), + l0); + migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}; + EXPECT(p.get_output_shapes().back() == s2); + p.compile(migraphx::ref::target{}); + migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {4, 2, 1}}; + auto result = p.eval({}).back(); + std::vector gold = {0, 1, 3, 4, 6, 7, 9, 10}; + std::vector results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, gold)); + EXPECT(result.get_shape() == sresult); + } +} + +TEST_CASE(softmax_simple_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = {0.25, 0.75}; + std::vector s = {0.377541, 0.622459}; + migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(softmax_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector a = { + -5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03, + -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01, + -9.20038462e-01, 8.47388089e-01, 2.51734018e-01, 1.50563884e+00, 2.23056650e+00, + -6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00, -2.51560897e-01, + -8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01, + 3.23284641e-02, -1.54700470e+00, 1.38096774e+00, 5.39869189e-01, -7.56884992e-01, + 1.81503093e+00, -2.11269641e+00, 1.92466557e+00, 1.77230799e+00, 2.21660900e+00, + 1.56777036e+00, -2.08995026e-03, 3.50566894e-01, -1.15042710e+00, -1.18577778e+00, + 8.90633047e-01, -6.63949102e-02, 1.44661188e+00, 1.59215283e+00, -2.56262213e-01, + 9.39079225e-01, 4.07298543e-02, 3.86590779e-01, 6.09607756e-01, 8.22331488e-01, + -2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00, + 3.27092171e-01, -1.33315325e+00, 3.62459183e-01, 3.74710828e-01, -1.30302286e+00, + 1.79680198e-01, -4.51832324e-01, 4.34282750e-01, -7.09520102e-01, 6.20333970e-01, + -1.28712380e+00, 2.04130828e-01, -7.70607769e-01, 1.61889160e+00, -1.50951004e+00, + -4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01, + -8.28408226e-02, 2.73412596e-02, 5.79780899e-03, 9.87900198e-02, -7.95276761e-01, + -1.38536084e+00, -6.63573861e-01, 3.89783204e-01, -1.30670881e+00, -7.62425125e-01, + -4.04883057e-01, 6.24344349e-01, 3.68128955e-01, -1.01577950e+00, -3.06715906e-01, + 5.67961395e-01, 2.98198581e-01, -1.63613629e+00, -3.75131965e-01, -6.75393403e-01, + 2.59172034e+00, 6.75538957e-01, 9.07939598e-02, 1.92257717e-01, -1.21592450e+00, + -2.73682117e-01, 1.25232983e+00, -1.39969170e+00, -1.91483587e-01, 2.57732719e-01, + 3.10056299e-01, 1.41833842e+00, -1.81386679e-01, 3.92868072e-01, -8.14771175e-01, + 2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01, + -6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01}; + + std::vector s = { + 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, + 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, + 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, + 0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287, + 0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055, + 0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915, + 0.3115395, 0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762, 0.4642328, + 0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609, + 0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865, + 0.17296699, 0.46923906, 0.06921105, 0.3570261, 0.4125829, 0.73165393, 0.15302512, + 0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355, + 0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659, 0.10663581, + 0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666, + 0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339, 0.49818122, 0.10656087, + 0.1813329, 0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728, + 0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739, + 0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723, + 0.42914796}; + + migraphx::shape a_shape{migraphx::shape::float_type, {5, 3, 4, 2}}; + auto al = mm->add_literal(migraphx::literal{a_shape, a}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(120); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(results_vector, s)); +} + +TEST_CASE(sqdiff_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); + mm->add_instruction(migraphx::make_op("sqdiff"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {4, 4, 4}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(sqrt_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {5}}; + std::vector data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("sqrt"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(squeeze_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(4 * 3 * 3); + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape() == s2); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(4 * 3 * 3); + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape() == s2); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(4 * 3 * 3); + migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 3, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + mm->add_instruction(migraphx::make_op("squeeze"), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape() == s2); + } +} + +TEST_CASE(step_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(2 * 4 * 6); + std::iota(data.begin(), data.end(), 2); + migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + auto r = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 2, 3}}, {"steps", {2, 2, 3}}}), l0); + mm->add_return({r}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 2, 2}}; + EXPECT(result.get_shape() == s2); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(2 * 4 * 6); + std::iota(data.begin(), data.end(), 2); + migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + auto tl = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto r = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl); + mm->add_return({r}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + migraphx::shape s2{migraphx::shape::float_type, {1, 2, 2, 1}}; + EXPECT(result.get_shape() == s2); + } +} + +TEST_CASE(sub_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}}); + auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}}); + mm->add_instruction(migraphx::make_op("sub"), l1, l2); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {-2, -2, -2}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(tan_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + std::vector data{-1, 0, 1}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("tan"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(tanh_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 2}}; + std::vector data{-1.0, 2.0, -3.0, 4.0}; + auto l = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(migraphx::make_op("tanh"), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = data; + std::transform( + gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); }); + EXPECT(migraphx::verify_range(results_vector, gold)); +} + +TEST_CASE(topk_test) +{ + auto create_program = [](int64_t k, int64_t axis, int largest) { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", axis}, {"k", k}, {"largest", largest}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + return p; + }; + + auto run_program = [&](int64_t k, int64_t axis, int largest) { + auto p = create_program(k, axis, largest); + p.compile(migraphx::ref::target{}); + std::vector data = { + 2.1, 2.3, 2.0, 2.5, 1.9, 3.3, 0.2, 4.5, 0.1, 0.8, 1.0, 4.5, 2.1, 0.8, 1.5}; + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(s, data.data()); + auto rets = p.eval(pp); + std::vector ret_val; + rets.front().visit([&](auto v) { ret_val.assign(v.begin(), v.end()); }); + std::vector ret_ind; + rets.back().visit([&](auto v) { ret_ind.assign(v.begin(), v.end()); }); + + return std::make_pair(ret_val, ret_ind); + }; + + // case 1 + { + auto results = run_program(4, 1, 1); + std::vector gold_val = {2.5, 2.3, 2.1, 2, 4.5, 3.3, 0.8, 0.2, 4.5, 2.1, 1.5, 1}; + EXPECT(results.first == gold_val); + std::vector gold_ind = {3, 1, 0, 2, 2, 0, 4, 1, 1, 2, 4, 0}; + EXPECT(results.second == gold_ind); + } + + // case 2 + { + auto results = run_program(4, 1, 0); + std::vector gold_val = {1.9, 2, 2.1, 2.3, 0.1, 0.2, 0.8, 3.3, 0.8, 1, 1.5, 2.1}; + EXPECT(results.first == gold_val); + std::vector gold_ind = {4, 2, 0, 1, 3, 1, 4, 0, 3, 0, 4, 2}; + EXPECT(results.second == gold_ind); + } +} + +TEST_CASE(transpose_test) +{ + migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 2, 3}}; + std::vector data(12); + std::iota(data.begin(), data.end(), 0); + + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + std::vector perm = {0, 3, 1, 2}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + + result.visit([&](auto output) { + std::vector new_lens = {1, 3, 2, 2}; + EXPECT(bool{output.get_shape().lens() == new_lens}); + }); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_literal(migraphx::literal{a_shape, data}); + std::vector perm = {0, 3, 1, 2}; + auto result = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l); + mm->add_instruction(migraphx::make_op("contiguous"), result); + p.compile(migraphx::ref::target{}); + auto result2 = p.eval({}).back(); + + std::vector results_vector(12); + result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; + EXPECT(migraphx::verify_range(results_vector, gold)); + } +} + +TEST_CASE(unsqueeze_test) +{ + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(4 * 3 * 3); + migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape() == s2); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data(4 * 3 * 3); + migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data}); + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + EXPECT(result.get_shape() == s2); + } +} + +TEST_CASE(where_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sb{migraphx::shape::bool_type, {3, 3}}; + migraphx::shape sx{migraphx::shape::float_type, {3, 3}}; + + std::vector b{true, true, true, false, false, false, true, false, true}; + std::vector x(9, 1.0); + std::vector y(9, 2.0); + + auto lb = mm->add_literal(migraphx::literal{sb, b}); + auto lx = mm->add_literal(migraphx::literal{sx, x}); + auto ly = mm->add_literal(migraphx::literal{sx, y}); + auto w = mm->add_instruction(migraphx::make_op("where"), lb, lx, ly); + mm->add_return({w}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + std::vector gold(9); + for(int i = 0; i < gold.size(); ++i) + gold[i] = b[i] ? x[i] : y[i]; + + EXPECT(migraphx::verify_range(result_vec, gold)); +} + +TEST_CASE(where_broadcasted_inputs_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sb{migraphx::shape::bool_type, {3, 3}}; + + std::vector b{true, true, true, false, false, false, true, false, true}; + + auto lb = mm->add_literal(migraphx::literal{sb, b}); + auto lx = mm->add_literal(migraphx::literal(1.0f)); + auto ly = mm->add_literal(migraphx::literal(2.0f)); + auto mbx = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx); + auto mby = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly); + auto w = mm->add_instruction(migraphx::make_op("where"), lb, mbx, mby); + mm->add_return({w}); + p.compile(migraphx::ref::target{}); + auto result = p.eval({}).back(); + std::vector result_vec; + result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); + std::vector gold(9); + std::vector x(9, 1.0); + std::vector y(9, 2.0); + for(int i = 0; i < gold.size(); ++i) + gold[i] = b[i] ? x[i] : y[i]; + + EXPECT(migraphx::verify_range(result_vec, gold)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/ref_rnn_ops_test.cpp b/test/ref_rnn_ops_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a234ec697038562cccec9e7cd2f0dd13f3600e70 --- /dev/null +++ b/test/ref_rnn_ops_test.cpp @@ -0,0 +1,4631 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "test.hpp" + +TEST_CASE(rnn_forward) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{0.4691, + 0.3185, + -0.2227, + 0.4423, + -0.0609, + -0.2803, + 0.1744, + 0.3146, + 0.4049, + -0.3973, + -0.0890, + -0.1636}; + + std::vector r_data{-0.0456, + 0.1061, + 0.1574, + -0.4928, + -0.4300, + -0.1909, + -0.0225, + -0.2668, + 0.1840, + -0.4453, + -0.4896, + 0.1302, + -0.0929, + 0.3545, + -0.4981, + 0.0616}; + + std::vector bias_data{ + -0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990}; + std::vector ih_data(num_dirct * batch_size * hidden_size, 0); + std::vector input(seq_len * batch_size * input_size, 0); + input[0] = input[1] = 1.0; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + float clip = 0.0f; + // concatenation of hidden states as program output + { + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + auto res_hs = outputs.front(); + auto res_lho = outputs.back(); + + std::vector hs_data; + std::vector lho_data; + res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.37780784, + 0.61055139, + 0.55168478, + -0.5888475, + -0.37144644, + 0.31708236, + 0.13104209, + -0.18736027, + 0.03445704, + 0.19167931, + -0.3946827, + -0.30889652, + -0.22276389, + 0.44193283, + -0.16477929, + -0.11893477}; + + std::vector lho_data_gold{0.03445704, + 0.19167931, + -0.3946827, + -0.30889652, + -0.22276389, + 0.44193283, + -0.16477929, + -0.11893477}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto last_out = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, last_out}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + + auto arg_hs = outputs.front(); + auto arg_last_output = outputs.back(); + std::vector last_output_data; + std::vector hs_data; + arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); }); + arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector hs_data_gold{ + 0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, + -0.18736027, 0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, + -0.16477929, -0.11893477, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0}; + + std::vector last_output_data_gold{0.03445704, + 0.19167931, + -0.3946827, + -0.30889652, + -0.22276389, + 0.44193283, + -0.16477929, + -0.11893477}; + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{2, 1}; + auto sql = mm->add_literal(seq_len_s, len_data); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto last_out = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, last_out}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + + auto arg_hs = outputs.front(); + auto arg_last_output = outputs.back(); + std::vector last_output_data; + std::vector hs_data; + arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); }); + arg_last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector hs_data_gold{0.377808, + 0.610551, + 0.551685, + -0.588848, + -0.371446, + 0.317082, + 0.131042, + -0.18736, + 0.034457, + 0.191679, + -0.394683, + -0.308897, + 0, + 0, + 0, + 0}; + std::vector last_output_data_gold{ + 0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736}; + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + p.compile(migraphx::ref::target{}); + + auto last_output = p.eval({}).back(); + std::vector last_output_data; + last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector last_output_data_gold{ + 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // seq_len = 1 + { + seq_len = 1; + std::vector input_1(seq_len * batch_size * input_size, 0); + input_1[0] = input_1[1] = 1.0; + migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape_1, input_1}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.37780784, + 0.61055139, + 0.55168478, + -0.5888475, + -0.37144644, + 0.31708236, + 0.13104209, + -0.18736027}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(rnn_reverse) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{-0.0296, + -0.1341, + 0.1761, + -0.2325, + -0.0717, + 0.1852, + 0.2720, + 0.1471, + -0.1097, + 0.3363, + -0.0587, + -0.2302}; + std::vector r_data{0.2528, + -0.2333, + 0.3973, + 0.1593, + -0.0388, + 0.1702, + 0.3829, + -0.0712, + -0.1668, + 0.3074, + -0.2854, + 0.4049, + -0.3737, + -0.1051, + 0.4482, + -0.2841}; + std::vector bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552}; + std::vector input(seq_len * batch_size * input_size, 0); + input[0] = input[1] = 1.0; + std::vector ih_data(num_dirct * batch_size * hidden_size, 0); + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + // concatenation of hidden states as program output + { + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.29385301, + 0.16796815, + 0.51075965, + 0.40258689, + -0.13818839, + 0.44124447, + 0.14365635, + 0.14803654, + -0.0070999, + 0.46251031, + -0.20639211, + 0.37488942, + -0.0070999, + 0.46251031, + -0.20639211, + 0.37488942}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // rnn last output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + p.compile(migraphx::ref::target{}); + + auto last_output = p.eval({}).back(); + std::vector last_output_data; + last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector last_output_data_gold{-0.29385301, + 0.16796815, + 0.51075965, + 0.40258689, + -0.13818839, + 0.44124447, + 0.14365635, + 0.14803654}; + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // rnn hidden states and last hidden state output as program outputs + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, lho}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + std::vector hs_data; + std::vector last_output_data; + auto arg_hs = outputs.front(); + arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); }); + + auto arg_lho = outputs.back(); + arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector hs_data_gold{ + -0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, + 0.14803654, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, + -0.20639211, 0.37488942, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0}; + + std::vector last_output_data_gold{-0.29385301, + 0.16796815, + 0.51075965, + 0.40258689, + -0.13818839, + 0.44124447, + 0.14365635, + 0.14803654}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // rnn hidden states and last hidden state output as program outputs + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{2, 1}; + auto sql = mm->add_literal(seq_len_s, len_data); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, lho}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + std::vector hs_data; + std::vector last_output_data; + auto arg_hs = outputs.front(); + arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); }); + + auto arg_lho = outputs.back(); + arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + std::vector hs_data_gold{-0.293853, + 0.167968, + 0.51076, + 0.402587, + -0.0070999, + 0.46251, + -0.206392, + 0.374889, + -0.0070999, + 0.46251, + -0.206392, + 0.374889, + 0, + 0, + 0, + 0}; + std::vector last_output_data_gold{ + -0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } +} + +TEST_CASE(rnn_bidirectional) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803, + 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636, + -0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, + 0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302}; + + std::vector r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, + -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, + -0.4981, 0.0616, 0.2528, -0.2333, 0.3973, 0.1593, -0.0388, + 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, + -0.3737, -0.1051, 0.4482, -0.2841}; + + std::vector bias_data{-0.4938, + 0.4355, + -0.3186, + 0.2094, + 0.1037, + -0.1071, + 0.4504, + -0.3990, + -0.3188, + 0.1341, + -0.4446, + 0.1389, + 0.3117, + 0.3664, + 0.2352, + 0.2552}; + + std::vector input(seq_len * batch_size * input_size, 0); + input[0] = input[1] = 1.0; + std::vector ih_data(num_dirct * batch_size * hidden_size, 0); + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + float clip = 0.0f; + // concatenation of hidden state and last hs output for program outputs + { + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, lho}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + auto arg_hs = outputs.front(); + auto arg_lho = outputs.back(); + + std::vector hs_data; + arg_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + std::vector last_output_data; + arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector hs_data_gold{ + 0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, + 0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689, + -0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931, + -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, + -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, + -0.20639211, 0.37488942}; + + std::vector last_output_data_gold{0.03445704, + 0.19167931, + -0.3946827, + -0.30889652, + -0.22276389, + 0.44193283, + -0.16477929, + -0.11893477, + -0.29385301, + 0.16796815, + 0.51075965, + 0.40258689, + -0.13818839, + 0.44124447, + 0.14365635, + 0.14803654}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // last rnn output for program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{1, 2}; + auto sql = mm->add_literal(seq_len_s, len_data); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + mm->add_return({out_hs, lho}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + auto arg_hs = outputs.front(); + auto arg_lho = outputs.back(); + + std::vector hs_data; + std::vector last_output_data; + arg_hs.visit([&](auto out) { hs_data.assign(out.begin(), out.end()); }); + arg_lho.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector hs_data_gold{ + 0.377808, 0.610551, 0.551685, -0.588848, -0.371446, 0.317082, 0.131042, -0.18736, + -0.169158, 0.193817, 0.206679, 0.586097, -0.138188, 0.441244, 0.143656, 0.148037, + 0, 0, 0, 0, -0.222764, 0.441933, -0.164779, -0.118935, + 0, 0, 0, 0, -0.0070999, 0.46251, -0.206392, 0.374889}; + std::vector last_output_data_gold{0.377808, + 0.610551, + 0.551685, + -0.588848, + -0.222764, + 0.441933, + -0.164779, + -0.118935, + -0.169158, + 0.193817, + 0.206679, + 0.586097, + -0.138188, + 0.441244, + 0.143656, + 0.148037}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // 4 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias); + + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + p.compile(migraphx::ref::target{}); + + auto last_output = p.eval({}).back(); + std::vector last_output_data; + last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector last_output_data_gold{0.03445704, + 0.19167931, + -0.3946827, + -0.30889652, + -0.22276389, + 0.44193283, + -0.16477929, + -0.11893477, + -0.29385301, + 0.16796815, + 0.51075965, + 0.40258689, + -0.13818839, + 0.44124447, + 0.14365635, + 0.14803654}; + + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + + auto last_output = p.eval({}).back(); + std::vector last_output_data; + last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); + + std::vector last_output_data_gold{ + 0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0., + -0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0., + 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0.}; + + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + } + + // concatenation of hidden state for program output + { + seq_len = 1; + std::vector input_1(seq_len * batch_size * input_size, 0); + input_1[0] = input_1[1] = 1.0; + migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape_1, input_1}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.37780784, + 0.61055139, + 0.55168478, + -0.5888475, + -0.37144644, + 0.31708236, + 0.13104209, + -0.18736027, + -0.16915828, + 0.1938169, + 0.20667936, + 0.58609703, + -0.0070999, + 0.46251031, + -0.20639211, + 0.37488942}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_forward) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, + 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, + -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, + 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, + 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, + -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, + 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, + -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, + -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, + -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, + 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, + 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, + -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, + 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{ + -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + float clip = 0.0f; + // concatenation of hidden states for output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + -0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238, -0.6873896, + 0.16144305, 0.31932795, 0.6104771, 0.79759157, -0.31791314, 0.5249062, + 0.08800987, 0.46404213, -0.11872687, -0.26210734, 0.34448293, -0.0176422, + 0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787, + -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // last output for output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.3969709, + 0.43360898, + 0.35775262, + 0.23280787, + -0.52179873, + -0.21944991, + 0.4535257, + -0.13735442, + 0.51757574, + 0.50380427}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // two rnn_last_hs_output operators after gru + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.3969709, + 0.43360898, + 0.35775262, + 0.23280787, + -0.52179873, + -0.21944991, + 0.4535257, + -0.13735442, + 0.51757574, + 0.50380427}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // last output for output, linear_before_reset = 0 + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 0}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.53291196, + 0.50160867, + 0.39010462, + 0.39292926, + -0.5960838, + -0.38451535, + 0.454239, + -0.10620412, + 0.6014447, + 0.43445644}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_forward_args) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, + 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, + -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, + 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, + 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, + -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, + 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, + -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, + -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, + -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, + 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, + 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + float clip = 0.0f; + + // 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.114674, -0.129581, -0.218156, -0.140788, -0.114242, + -0.346569, 0.321367, -0.0838253, 0.102097, 0.00232137, + -0.149055, 0.0590743, -0.0533094, -0.0446122, -0.112588, + 0.0153261, 0.168883, -0.326836, 0.0843562, 0.160872, + -0.232523, 0.00214573, 0.231693, -0.160475, -0.518952, + 0.0467166, 0.12327, -0.374162, 0.137778, 0.251976}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 4 args (bias is used) + { + std::vector bias_data{ + 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, + -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, + 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.273619, 0.0931375, -0.104717, 0.0203752, -0.0797887, + -0.493948, 0.472118, -0.0336318, 0.332706, 0.0182268, + -0.341684, 0.38063, 0.0589275, 0.2644, -0.115737, + -0.152324, 0.442277, -0.201626, 0.408909, 0.12905, + -0.416866, 0.377186, 0.32922, 0.162214, -0.519973, + -0.140072, 0.465076, -0.229563, 0.500164, 0.195166}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 4 args (ih is used) + { + std::vector ih_data{ + -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + und, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.0801064, 0.27025, -0.20704, 0.333579, -0.0452438, + -0.56265, 0.061061, 0.262172, 0.405193, 0.775226, + -0.100683, 0.258729, -0.0187297, 0.215815, -0.108936, + -0.0941018, 0.129665, -0.159421, 0.190636, 0.597412, + -0.197, 0.0885705, 0.269396, -0.0414511, -0.515137, + -0.03075, 0.158326, -0.296488, 0.177983, 0.519498}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_forward_actv_funcs) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, + 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, + -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, + 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, + 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, + -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, + 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, + -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, + -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, + -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, + 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, + 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, + -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, + 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{ + -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + float clip = 0.0f; + + // no activation function specified, so default is used. + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.3969709, + 0.43360898, + 0.35775262, + 0.23280787, + -0.52179873, + -0.21944991, + 0.4535257, + -0.13735442, + 0.51757574, + 0.50380427}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 1 activation function (sigmoid) specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.26905832, 0.5669211, 0.20464146, 0.67195725, 0.24752215, + 0.11411376, 0.12353572, 0.4245067, 0.73908687, 0.8644615, + 0.34754312, 0.61424744, 0.36769435, 0.6499579, 0.3168031, + 0.3296533, 0.3055136, 0.42514813, 0.6851256, 0.7967266, + 0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663, + 0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 1 activation function (tanh) specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.49333298, + -0.06104589, + 0.5629142, + -0.97955984, + -0.9314696, + -0.03033514, + 0.5280315, + -0.27354342, + 0.65615714, + 0.53612584}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // seq length of 1 + { + migraphx::program p; + auto* mm = p.get_main_module(); + seq_len = 1; + migraphx::shape in_shape_one{migraphx::shape::float_type, + {seq_len, batch_size, input_size}}; + std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; + auto seq = mm->add_literal(migraphx::literal{in_shape_one, input_one}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.27298412, + 0.42363745, + -0.09368783, + 0.4823072, + -0.02183238, + -0.6873896, + 0.16144305, + 0.31932795, + 0.6104771, + 0.79759157}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_reverse) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, + 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, + -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, + 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, + 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, + -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, + 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, + -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, + -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, + -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, + 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, + 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, + -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, + 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{ + -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + float clip = 0.0f; + + // concatenation of hidden states and last hs output for outputs + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({lho, hs}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + + auto res_hs = outputs.back(); + auto res_lho = outputs.front(); + std::vector hs_data; + std::vector lho_data; + res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, + -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, + -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, + -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, + -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, + -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; + std::vector lho_data_gold{-0.263403, + 0.317655, + -0.00634162, + 0.200443, + -0.349125, + -0.600874, + 0.542386, + -0.0856531, + 0.55703, + 0.54711}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + // variable input sequence length + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{1, 2}; + auto sql = mm->add_literal(seq_len_s, len_data); + + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({lho, hs}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + + auto res_hs = outputs.back(); + auto res_lho = outputs.front(); + std::vector hs_data; + std::vector lho_data; + res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + -0.272984, 0.423637, -0.0936878, 0.482307, -0.0218324, -0.630874, 0.401448, 0.0488417, + 0.558397, 0.664423, 0, 0, 0, 0, 0, -0.238202, + -0.0752721, 0.0919409, 0.669654, 0.782363, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + std::vector lho_data_gold{-0.272984, + 0.423637, + -0.0936878, + 0.482307, + -0.0218324, + -0.630874, + 0.401448, + 0.0488417, + 0.558397, + 0.664423}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + // last output for output, linear_before_reset = 0 + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"linear_before_reset", 0}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.388654, + 0.384975, + 0.0179455, + 0.350101, + -0.456872, + -0.690085, + 0.534512, + -0.0558191, + 0.646604, + 0.463943}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // no activation function specified, so default is used. + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, + -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, + -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, + -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, + -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, + -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // seq length of 1 + { + migraphx::program p; + auto* mm = p.get_main_module(); + seq_len = 1; + migraphx::shape in_shape_one{migraphx::shape::float_type, + {seq_len, batch_size, input_size}}; + std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; + auto seq = mm->add_literal(migraphx::literal{in_shape_one, input_one}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.272984, + 0.423637, + -0.0936878, + 0.482307, + -0.0218324, + -0.68739, + 0.161443, + 0.319328, + 0.610477, + 0.797592}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_bidirectional) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, + 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, + 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, + 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, + 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, + + -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, + 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, + -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, + -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, + -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, + 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, + -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, + 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, + 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, + -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, + -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, + 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, + 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, + 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, + -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, + -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, + 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, + 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, + -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, + -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, + -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, + 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, + 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, + 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, + 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, + 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + + float clip = 0.0f; + + // concatenation of hidden states and last hs output for outputs + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + auto hs_concat = outputs.front(); + auto res_lho = outputs.back(); + std::vector hs_data; + std::vector lho_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, + -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, + -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, + 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, + -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, + 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, + 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, + 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, + 0.079043, 0.322652, 0.752701, 0.243775}; + std::vector lho_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, + -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, + 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, + 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + // same input sequence length, but shorter than max squence length + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + mm->add_return({concat_hs, lho}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + auto hs_concat = outputs.front(); + auto res_lho = outputs.back(); + std::vector hs_data; + std::vector lho_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683, 0.214342, -0.0454273, + -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, + -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, + 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, + -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, + 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, + 0.241526, 0.321104, 0.00693531, -0.311839, -0.12802, -0.16643, -0.393849, + 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, + 0.079043, 0.322652, 0.752701, 0.243775, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0}; + std::vector lho_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693531, + -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, + 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, + 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + // variable input sequence lengths + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{1, 2}; + auto sql = mm->add_literal(seq_len_s, len_data); + + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + mm->add_return({concat_hs, lho}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + auto hs_concat = outputs.front(); + auto res_lho = outputs.back(); + std::vector hs_data; + std::vector lho_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683, 0.214342, -0.0454273, + -0.135177, -0.0800739, 0.903659, -0.0271321, 0.624762, -0.117084, 0.509115, + -0.0175078, 0.182457, 0.304506, 0.313825, 0.397697, 0.300873, 0, + 0, 0, 0, 0, -0.221983, -0.139656, -0.0938906, + -0.247681, 0.69647, 0, 0, 0, 0, 0, + -0.059911, 0.0552807, 0.306764, 0.794409, 0.194492, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0}; + std::vector lho_data_gold{0.0352244, 0.0146756, 0.00570924, 0.152446, 0.208683, + -0.221983, -0.139656, -0.0938906, -0.247681, 0.69647, + -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, + 0.182457, 0.304506, 0.313825, 0.397697, 0.300873}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + } + + // last output for output, linear_before_reset = 0 + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 0}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + -0.09280921, 0.18506107, 0.32247013, 0.17034212, -0.00115255, -0.29865006, -0.04513004, + -0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289, + -0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_bidirectional_args) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, + 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, + 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, + 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, + 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, + + -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, + 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, + -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, + -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, + -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, + 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, + -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, + 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, + 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, + -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, + -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, + 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, + 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, + 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, + -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, + -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, + 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, + 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, + -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + float clip = 0.0f; + + // 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 0}}), + seq, + w, + r); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 0.190748, + -0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298, + -0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838, + 0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119, + -0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983, + 0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435, + 0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234, + 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, + 0.198708, 0.0695644, 0.211621, 0.00246037}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 4 args (bias is used) + { + std::vector bias_data{ + -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, + 0.2363, -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, + 0.1517, 0.1817, -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, + 0.2187, -0.2990, -0.0416, 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, + -0.2990, 0.2167, 0.1395, 0.2317, 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, + -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, 0.1754, 0.0800, -0.3964, -0.3247, + 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + -0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, 0.192707, 0.253025, + -0.515283, -0.414174, 0.227127, 0.124773, 0.284532, -0.203929, -0.120517, + -0.2794, 0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, -0.219516, + 0.0927382, -0.0760062, -0.0906231, 0.237615, -0.215638, 0.0128074, -0.425813, + -0.433378, 0.375383, -0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649, + 0.419469, 0.393515, -0.076395, 0.427436, -0.264071, -0.185829, 0.0483585, + 0.242955, 0.25233, 0.0148512, -0.304127, -0.0616653, -0.411568, -0.491748, + 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008, + 0.248674, -0.0295413, 0.291437, -0.165005}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 4 args (ih is used) + { + std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, + 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, + 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + und, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.434692, -0.141696, + -0.164271, -0.121157, 0.863222, -0.0718357, 0.137711, 0.109221, -0.00207995, + 0.0331223, 0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.222779, + 0.179757, -0.0845316, 0.0690347, 0.10204, 0.100155, -0.190286, -0.122062, + -0.274379, 0.547281, -0.226753, -0.0397069, 0.120404, 0.171299, 0.259989, + 0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.237426, 0.0911999, + 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354, + 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917, + -0.0339407, 0.413089, 0.721238, 0.431879}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_bidirectional_actv_funcs) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, + 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, + 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, + 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, + 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, + + -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, + 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, + -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, + -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, + -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, + 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, + -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, + 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, + 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, + -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, + -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, + 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, + 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, + 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, + -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, + -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, + 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, + 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, + -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, + -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, + -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, + 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, + 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, + 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, + 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, + 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + + float clip = 0.0f; + + // no activation function specified, so default is used. + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, + -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, + 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, + 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 1 activation function (sigmoid) specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 0}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.457928, 0.065011, 0.35986, + 0.545029, 0.859425, 0.427923, 0.667133, 0.41591, 0.540971, 0.365475, 0.482058, + 0.565495, 0.556993, 0.607649, 0.543627, 0.428915, 0.537405, 0.306046, 0.518399, + 0.403561, 0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.309389, 0.612072, + 0.360619, 0.590861, 0.366545, 0.367001, 0.433829, 0.501275, 0.72481, 0.512745, + 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275, + 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646, + 0.132732, 0.477083, 0.802206, 0.626802}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 1 activation function (tanh) specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.949841, 0.557779, + -0.745259, -1.52726, 0.946066, 0.330446, 0.301982, -0.443763, -0.0655817, + -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956, -0.329758, + 0.458253, -0.339208, 0.289109, 0.36728, -1.09574, -0.181394, -0.575781, + -0.823083, 0.804262, -0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857, + 0.844222, 0.516472, -0.191571, 0.596968, -0.545405, -0.336693, -0.0280516, + 0.339058, 1.00367, 0.12655, -0.0984504, -0.174945, -0.5365, 0.183188, + 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419, + 0.759629, 0.000258222, 0.350835, -0.682684}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 3 activation functions specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto concat_hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), concat_hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.351019, 0.474363, 0.570719, 0.717703, 0.468843, + 1.15142, 0.457633, 0.300962, 0.361245, 0.666199, + 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, + 0.861394, 0.560799, -0.101768, 0.145142, 0.128956}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // 4 activation functions all specified + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, + -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, + -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, + 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, + -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, + 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, + 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, + 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, + 0.079043, 0.322652, 0.752701, 0.243775}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(gru_bidirectional_seq_1) +{ + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; + std::vector w_data{ + 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, + 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, + 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, + 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, + 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, + + -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, + 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, + -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, + -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, + -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; + + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; + std::vector r_data{ + -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, + 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, + -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, + 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, + 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, + -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, + -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, + 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, + 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, + 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, + -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, + -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, + 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, + 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, + -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; + + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + std::vector bias_data{ + -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, + -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, + -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, + 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, + 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, + 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input{-0.8432, + -0.9887, + 1.3041, + -2.6430, + -0.3306, + -0.8504, + -0.3933, + 0.5151, + -0.2951, + 0.0093, + -1.1948, + -0.1239, + 0.0373, + 1.3211, + 0.7854, + -0.4838, + -1.0536, + -0.2529}; + + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, + 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, + 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; + + float clip = 0.0f; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape_one{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; + auto seq = mm->add_literal(migraphx::literal{in_shape_one, input_one}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"linear_before_reset", 1}}), + seq, + w, + r, + bias, + und, + ih); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, + 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, + -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, + -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755}; + + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); +} + +TEST_CASE(lstm_forward) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260}; + + std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, + -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, + 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, + 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, + -0.3025, 0.3637, -0.3181, -0.4655}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, + -1.9004, + 0.3337, + 0.5741, + 0.5671, + 0.0458, + 0.4514, + -0.8968, + -0.9201, + 0.1962, + 0.5771, + -0.5332}; + + std::vector ic_data{0.9569, + -0.5981, + 1.1312, + 1.0945, + 1.1055, + -0.1212, + -0.9097, + 0.7831, + -1.6991, + -1.9498, + -1.2567, + -0.4114}; + + std::vector pph_data{1.84369764, + 0.68413646, + -0.44892886, + -1.50904413, + 0.3860796, + -0.52186625, + 1.08474445, + -1.80867321, + 1.32594529, + 0.4336262, + -0.83699064, + 0.49162736}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // forward, hidden state concatenation as output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + p.compile(migraphx::ref::target{}); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815, + 0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085, + 0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755, + -0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156, + 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, + 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, + 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // forward, last_output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.0847427, + 0.0874114, + 0.304256, + -0.0585745, + -0.0223018, + 0.131113, + 0.135643, + -0.0566208, + 0.142701, + 0.0342236, + -0.198664, + 0.0702607}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // forward, last_cell_output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + p.compile(migraphx::ref::target{}); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.111454, + 0.247794, + 0.471087, + -0.220574, + -0.048196, + 0.263184, + 0.283258, + -0.14882, + 0.605585, + 0.078598, + -0.64457, + 0.119811}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } +} + +TEST_CASE(lstm_forward_more) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260}; + + std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, + -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, + 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, + 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, + -0.3025, 0.3637, -0.3181, -0.4655}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, + -1.9004, + 0.3337, + 0.5741, + 0.5671, + 0.0458, + 0.4514, + -0.8968, + -0.9201, + 0.1962, + 0.5771, + -0.5332}; + + std::vector ic_data{0.9569, + -0.5981, + 1.1312, + 1.0945, + 1.1055, + -0.1212, + -0.9097, + 0.7831, + -1.6991, + -1.9498, + -1.2567, + -0.4114}; + + std::vector pph_data{1.84369764, + 0.68413646, + -0.44892886, + -1.50904413, + 0.3860796, + -0.52186625, + 1.08474445, + -1.80867321, + 1.32594529, + 0.4336262, + -0.83699064, + 0.49162736}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // forward, 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, + 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048, + 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, + -0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544, + 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, + 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, + 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // forward, 8 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + p.compile(migraphx::ref::target{}); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168, + 0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, + -0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149, + 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, + 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, + 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } + + // forward, last_output as program output, sequence length shorter + // than max_seq_len + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.0847427, + 0.0874114, + 0.304256, + -0.0585745, + -0.0223018, + 0.131113, + 0.135643, + -0.0566208, + 0.142701, + 0.0342236, + -0.198664, + 0.0702607}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // seq_len = 1 + { + seq_len = 1; + migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.079753, + -0.289854, + 0.160043, + 0.115056, + 0.294074, + -0.0319677, + -0.0955337, + 0.104168, + 0.022618, + -0.121195, + -0.4065, + -0.252054}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(lstm_reverse) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, + -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, + -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, + 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, + -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, + 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, + -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, + 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, + 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, + -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, + 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, + -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, + 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, + -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.5289, + 1.0986, + 0.6091, + 1.6462, + 0.8720, + 0.5349, + -0.1962, + -1.7416, + -0.9912, + 1.2831, + 1.0896, + -0.6959}; + + std::vector ic_data{-0.8323, + 0.3998, + 0.1831, + 0.5938, + 2.7096, + -0.1790, + 0.0022, + -0.8040, + 0.1578, + 0.0567, + 0.8069, + -0.5141}; + + std::vector pph_data{-0.8271, + -0.5683, + 0.4562, + -1.2545, + 1.2729, + -0.4082, + -0.4392, + -0.9406, + 0.7794, + 1.8194, + -0.5811, + 0.2166}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + float clip = 0.0f; + // reverse, concatenation of hidden states as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, + 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, + 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, + 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, + 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, + 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // reverse, sequence lengths are the same, but less than max_seq_lens + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, + 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, + 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, + 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, + 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, + 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // variable sequence lengths + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{3, 2, 1}; + auto sql = mm->add_literal(seq_len_s, len_data); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449, + 0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761, + 0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0, + 0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // reverse, 3 args, last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.443077, + -0.325425, + -0.249367, + -0.270812, + 0.122913, + 0.118537, + 0.0370199, + -0.0164687, + -0.00754759, + 0.141613, + 0.348002, + 0.667298}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // reverse, 3 args, 0 actv function + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.443077, + -0.325425, + -0.249367, + -0.270812, + 0.122913, + 0.118537, + 0.0370199, + -0.0164687, + -0.00754759, + 0.141613, + 0.348002, + 0.667298}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } +} + +// lstm activation function test +TEST_CASE(lstm_reverse_actv) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, + -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, + -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, + 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, + -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, + 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, + -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, + 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, + 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, + -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, + 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, + -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, + 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, + -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.5289, + 1.0986, + 0.6091, + 1.6462, + 0.8720, + 0.5349, + -0.1962, + -1.7416, + -0.9912, + 1.2831, + 1.0896, + -0.6959}; + + std::vector ic_data{-0.8323, + 0.3998, + 0.1831, + 0.5938, + 2.7096, + -0.1790, + 0.0022, + -0.8040, + 0.1578, + 0.0567, + 0.8069, + -0.5141}; + + std::vector pph_data{-0.8271, + -0.5683, + 0.4562, + -1.2545, + 1.2729, + -0.4082, + -0.4392, + -0.9406, + 0.7794, + 1.8194, + -0.5811, + 0.2166}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + float clip = 0.0f; + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, + 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231, + 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, + 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, + 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, + 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // reverse, 3 args, 2 actv functions + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.132123, + -0.37531, + -0.12943, + -0.00798307, + -0.133882, + -0.0251383, + 0.0486486, + -0.0220606, + 0.292495, + 0.233866, + 0.48646, + 0.481844}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output + { + seq_len = 1; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; + migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.104351, + -0.0471426, + -0.0905753, + 0.01506, + 0.059797, + 0.104239, + -0.0266768, + 0.0727547, + -0.146298, + 0.070535, + 0.327809, + 0.407388}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } +} + +TEST_CASE(lstm_bidirectional) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, + -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, + 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, + -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, + 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, + 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, + 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, + 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, + 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, + 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, + -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, + 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{ + 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, + -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, + 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, + -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, + 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, + -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, + 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, + 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, + -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; + + std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, + -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, + -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, + 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; + + std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, + -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, + -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, + -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, + 0.7794, 1.8194, -0.5811, 0.2166}; + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // concatenation of hidden states as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, + 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, + 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, + 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, + -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, + 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, + 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, + -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, + 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, + 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, + -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, + -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // last hidden state as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, + 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, + 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, + 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, + 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, concatenation of hidden states as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, + 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, + -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, + 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, + 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, + -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, + 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, + -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, + 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, + -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, + 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, + -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, + -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // sequence length is 1, contenation of hidden state as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + seq_len = 1; + migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; + auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, + -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, + -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, + -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } +} + +TEST_CASE(lstm_bidirectional_var_seq_lens) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, + -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, + 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, + -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, + 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, + 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, + 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, + 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, + 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, + 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, + -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, + 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{ + 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, + -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, + 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, + -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, + 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, + -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, + 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, + 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, + -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; + + std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, + -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, + -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, + 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; + + std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, + -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, + -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, + -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, + 0.7794, 1.8194, -0.5811, 0.2166}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // concatenation of hidden states as program output + { + std::vector sl_data{1, 2, 3}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + mm->add_return({out_hs, lho, lco}); + p.compile(migraphx::ref::target{}); + + auto outputs = p.eval({}); + auto arg_hs = outputs.front(); + auto arg_lho = outputs.at(1); + auto arg_lco = outputs.at(2); + + std::vector output_data; + std::vector last_output_data; + std::vector last_cell_data; + + arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); }); + arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.141643, 0.0451978, + 0.140804, 0.0745128, 0.911307, 0.11468, 0.114449, 0.0196755, -0.262807, + 0.275286, 0.358395, 0.266267, 0, 0, 0, 0, + 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, + -0.0413375, 0, 0, 0, 0, 0.96657, 0.0755112, + 0.0620917, -0.264845, -0.128254, 0.125398, 0.0665142, -0.163651, 0, + 0, 0, 0, 0, 0, 0, 0, + 0.103489, 0.0142918, -0.123408, 0.0401075, 0, 0, 0, + 0, 0, 0, 0, 0, -0.0644683, 0.371512, + 0.212431, -0.116131, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0}; + std::vector last_output_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.421857, 0.0459771, -0.144955, 0.0720673, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.141643, 0.0451978, 0.140804, 0.0745128, + 0.911307, 0.11468, 0.114449, 0.0196755, -0.262807, 0.275286, 0.358395, 0.266267}; + std::vector last_cell_data_gold{ + 0.600582, -0.601197, 0.353558, 0.789097, 0.737121, 0.134902, -0.303595, 0.241948, + 0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242, + 2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436}; + + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); + EXPECT(migraphx::verify_range(last_cell_data, last_cell_data_gold)); + } + + // last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_return({hs, lho, lco}); + p.compile(migraphx::ref::target{}); + auto outputs = p.eval({}); + auto res_hs = outputs.at(0); + auto res_lho = outputs.at(1); + auto res_lco = outputs.at(2); + std::vector hs_data; + std::vector lho_data; + std::vector lco_data; + res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); }); + std::vector hs_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, + 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, + 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, + 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, + -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, + 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459033, + 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, + -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, + 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, + 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, + -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, + -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0}; + std::vector lho_data_gold{ + -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, + 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, + 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; + std::vector lco_data_gold{ + -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, + 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, + 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; + EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify_range(lho_data, lho_data_gold)); + EXPECT(migraphx::verify_range(lco_data, lco_data_gold)); + } +} + +TEST_CASE(lstm_bidirectional_actv_func) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, + -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, + 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, + -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, + 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, + 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, + 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, + 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, + 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, + 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, + -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, + 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + // 3 args, 0 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, + 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, + -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, + 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, + 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, + -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, + 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, + -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, + 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, + -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, + 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, + -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, + -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, 1 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + 0.227861, 0.328562, 0.277867, 0.272945, 0.204389, 0.296123, 0.223834, 0.311113, + 0.424666, 0.173974, 0.40628, 0.286631, 0.246078, 0.199709, 0.303753, 0.301178, + 0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, + 0.339438, 0.29655, 0.331832, 0.242338, 0.409384, 0.236272, 0.306045, 0.26269, + 0.261246, 0.334357, 0.23622, 0.245288, 0.301937, 0.264893, 0.254353, 0.269231, + 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, + 0.374123, 0.283167, 0.377129, 0.245726, 0.444712, 0.203168, 0.411446, 0.269965, + 0.172792, 0.296224, 0.17319, 0.352547, 0.310306, 0.262902, 0.276964, 0.295002, + 0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262, + 0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563, + 0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634, + 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, 2 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, + 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, + 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, 4 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, + 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, + 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, 5 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, + 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, + 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } + + // 3 args, 6 actv func + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::ref::target{}); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, + 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, + 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; + EXPECT(migraphx::verify_range(output_data, output_data_gold)); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/replace_allocate.cpp b/test/replace_allocate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3fd8ffe0e32883a39580867f01e084c99ced21d --- /dev/null +++ b/test/replace_allocate.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct allocate_no_out : migraphx::auto_register_op +{ + migraphx::shape s{}; + + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.s, "shape")); + } + + std::string name() const { return "allocate_no_out"; } + migraphx::shape compute_shape(const std::vector& inputs) const + { + migraphx::check_shapes{inputs, *this}.has(0); + return s; + } + migraphx::argument compute(migraphx::context&, + const migraphx::shape& output_shape, + const std::vector&) const + { + return {output_shape}; + } +}; + +struct allocate_with_out : migraphx::auto_register_op +{ + migraphx::shape s{}; + + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.s, "shape")); + } + + std::string name() const { return "allocate_with_out"; } + migraphx::shape compute_shape(const std::vector& inputs) const + { + migraphx::check_shapes{inputs, *this}.has(0); + return s; + } + migraphx::argument compute(migraphx::context&, + const migraphx::shape& output_shape, + const std::vector&) const + { + return {output_shape}; + } +}; + +// allocation model that has no out params +struct allocation_no_out_model +{ + std::string name() const { return "allocate_no_out"; } + migraphx::operation allocate(const migraphx::shape& s) const + { + return migraphx::make_op(name(), {{"shape", to_value(s)}}); + } + migraphx::operation preallocate(const migraphx::shape&, const std::string&) const { return {}; } + std::string copy() const { return {}; } + bool needs_out_params() const { return false; } +}; + +// allocation model with out params +struct allocation_with_out_model +{ + std::string name() const { return "allocate_with_out"; } + migraphx::operation allocate(const migraphx::shape& s) const + { + return migraphx::make_op(name(), {{"shape", to_value(s)}}); + } + migraphx::operation preallocate(const migraphx::shape&, const std::string&) const { return {}; } + std::string copy() const { return {}; } + bool needs_out_params() const { return true; } +}; + +void run_pass(migraphx::module& m, migraphx::allocation_model model, bool offload_copy = false) +{ + migraphx::run_passes(m, + {migraphx::replace_allocate{std::move(model), offload_copy}, + migraphx::dead_code_elimination{}}); +} + +void run_pass(migraphx::program& p, migraphx::allocation_model model, bool offload_copy = false) +{ + migraphx::run_passes(p, + {migraphx::replace_allocate{std::move(model), offload_copy}, + migraphx::dead_code_elimination{}}); +} + +migraphx::module create_simple_program() +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto alloc = + m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + m.add_instruction(pass_op{}, alloc, x, y); + return m; +} + +TEST_CASE(allocate_no_out) +{ + migraphx::module m = create_simple_program(); + run_pass(m, allocation_no_out_model{}); + + EXPECT(std::any_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { + return migraphx::contains(ins.name(), "allocate_no_out"); + })); +} + +TEST_CASE(allocate_with_out_param) +{ + migraphx::module m = create_simple_program(); + run_pass(m, allocation_with_out_model{}); + + EXPECT(std::none_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { + return migraphx::contains(ins.name(), "allocate"); + })); +} + +TEST_CASE(allocate_with_out_return) +{ + migraphx::module m = create_simple_program(); + m.add_return({std::prev(m.end())}); + run_pass(m, allocation_with_out_model{}); + + EXPECT(std::none_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { + return migraphx::contains(ins.name(), "allocate"); + })); +} + +TEST_CASE(allocate_with_out_no_params) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto z = m.add_parameter("z", s); + auto alloc = + m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto pass1 = m.add_instruction(pass_op{}, alloc, x, y); + auto alloc2 = + m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + m.add_instruction(pass_op{}, alloc2, z, pass1); + run_pass(m, allocation_with_out_model{}); + + EXPECT(std::any_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { + return migraphx::contains(ins.name(), "allocate_with_out"); + })); +} + +TEST_CASE(if_allocate) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + + auto* then_mod = p.create_module("If_0_if"); + auto alloc = then_mod->add_instruction( + migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto a1 = then_mod->add_instruction(pass_op{}, alloc, x); + then_mod->add_return({a1}); + + auto* else_mod = p.create_module("If_0_else"); + auto alloc1 = else_mod->add_instruction( + migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); + auto a2 = else_mod->add_instruction(pass_op{}, alloc1, y); + else_mod->add_return({a2}); + + mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + + run_pass(p, allocation_with_out_model{}); + EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { + return migraphx::contains(ins.name(), "allocate_with_out"); + })); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/rewrite_batchnorm_test.cpp b/test/rewrite_batchnorm_test.cpp index 8c2f1502cd2b1d38a28cc5326029403b2afd71b0..f3aedab7c66da32a832f7d7f80f71cb922533d73 100644 --- a/test/rewrite_batchnorm_test.cpp +++ b/test/rewrite_batchnorm_test.cpp @@ -1,13 +1,17 @@ #include #include -#include +#include #include #include -#include +#include #include #include #include #include +#include + +#include + #include bool is_batch_norm(migraphx::instruction& ins) { return ins.name() == "batch_norm_inference"; } @@ -42,27 +46,34 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) auto create_program = [&]() { migraphx::program p; - auto x = p.add_literal(xs, xdata); - auto w = p.add_literal(ws, wdata); - auto conv = - p.add_instruction(migraphx::op::convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, x, w); - auto scale = p.add_literal(migraphx::literal{vars, {3.0f}}); - auto bias = p.add_literal(migraphx::literal{vars, {8.1f}}); - auto mean = p.add_literal(migraphx::literal{vars, {4.0f}}); - auto variance = p.add_literal(migraphx::literal{vars, {37.11f}}); - p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); + + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xdata); + auto w = mm->add_literal(ws, wdata); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + x, + w); + auto scale = mm->add_literal(migraphx::literal{vars, {3.0f}}); + auto bias = mm->add_literal(migraphx::literal{vars, {8.1f}}); + auto mean = mm->add_literal(migraphx::literal{vars, {4.0f}}); + auto variance = mm->add_literal(migraphx::literal{vars, {37.11f}}); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); return p; }; migraphx::program p1 = create_program(); migraphx::program p2 = create_program(); + migraphx::rewrite_batchnorm opt; - opt.apply(p2); - p1.compile(migraphx::cpu::target{}); - p2.compile(migraphx::cpu::target{}); + opt.apply(*p2.get_main_module()); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); - auto result1 = p1.eval({}); - auto result2 = p2.eval({}); + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); std::vector results_vector1; std::vector results_vector2; @@ -79,24 +90,26 @@ TEST_CASE(non_literal) migraphx::shape vars{migraphx::shape::float_type, {4}}; auto create_program = [&]() { migraphx::program p; - - auto x = p.add_parameter("x", xs); - auto w = p.add_parameter("w", ws); - auto conv = p.add_instruction(migraphx::op::convolution{}, x, w); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", xs); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); return p; }; migraphx::program p1 = create_program(); migraphx::program p2 = create_program(); + migraphx::rewrite_batchnorm opt; - opt.apply(p2); - EXPECT(any_of(p1, &is_batch_norm)); - EXPECT(none_of(p2, &is_batch_norm)); + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); } TEST_CASE(as_literal) @@ -107,30 +120,110 @@ TEST_CASE(as_literal) migraphx::shape vars{migraphx::shape::float_type, {4}}; auto create_program = [&]() { migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::generate_literal(xs, 1)); + auto w = mm->add_literal(migraphx::generate_literal(ws, 1)); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + migraphx::rewrite_batchnorm opt; + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); + + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); - auto x = p.add_literal(migraphx::generate_literal(xs, 1)); - auto w = p.add_literal(migraphx::generate_literal(ws, 1)); - auto conv = p.add_instruction(migraphx::op::convolution{}, x, w); - auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); - auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); - auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); - auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); + visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); +} + +TEST_CASE(as_literal_1d) +{ + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8}}; + migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1}}; + migraphx::shape vars{migraphx::shape::float_type, {4}}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::generate_literal(xs, 1)); + auto w = mm->add_literal(migraphx::generate_literal(ws, 1)); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + x, + w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); return p; }; migraphx::program p1 = create_program(); migraphx::program p2 = create_program(); migraphx::rewrite_batchnorm opt; - opt.apply(p2); - EXPECT(any_of(p1, &is_batch_norm)); - EXPECT(none_of(p2, &is_batch_norm)); + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); - p1.compile(migraphx::cpu::target{}); - p2.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); - auto result1 = p1.eval({}); - auto result2 = p2.eval({}); + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); + visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); +} + +TEST_CASE(as_literal_3d) +{ + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 2, 4, 8}}; + migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1, 1}}; + migraphx::shape vars{migraphx::shape::float_type, {4}}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::op::convolution conv_op; + conv_op.padding = {0, 0, 0}; + conv_op.stride = {1, 1, 1}; + conv_op.dilation = {1, 1, 1}; + + auto x = mm->add_literal(migraphx::generate_literal(xs, 1)); + auto w = mm->add_literal(migraphx::generate_literal(ws, 1)); + auto conv = mm->add_instruction(conv_op, x, w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + migraphx::rewrite_batchnorm opt; + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); + + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); + + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); } @@ -142,33 +235,82 @@ TEST_CASE(literal_reshape) auto create_program = [&]() { migraphx::program p; - auto reshape = [&](auto ins) { - return p.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins); - }; - - auto x = p.add_literal(migraphx::generate_literal(xs, 1)); - auto w = p.add_literal(migraphx::generate_literal(ws, 1)); - auto conv = p.add_instruction(migraphx::op::convolution{}, x, w); - auto scale = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)))); - auto bias = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)))); - auto mean = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)))); - auto variance = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)))); - p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance); + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::generate_literal(xs, 1)); + auto w = mm->add_literal(migraphx::generate_literal(ws, 1)); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + migraphx::rewrite_batchnorm opt; + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); + + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); + + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); + visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); +} + +TEST_CASE(literal_reshape_per_actv) +{ + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 7, 4}}; + migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1, 1}}; + migraphx::shape vars{migraphx::shape::float_type, {4, 8, 7, 4}}; + + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(migraphx::generate_literal(xs, 1)); + auto w = mm->add_literal(migraphx::generate_literal(ws, 1)); + auto conv = mm->add_instruction( + migraphx::make_op( + "convolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + x, + w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op( + "batch_norm_inference", + {{"epsilon", 1.0e-5}, + {"momentum", 0.88}, + {"bn_mode", + migraphx::to_value(migraphx::op::batch_norm_inference::per_activation)}}), + conv, + scale, + bias, + mean, + variance); return p; }; migraphx::program p1 = create_program(); migraphx::program p2 = create_program(); migraphx::rewrite_batchnorm opt; - opt.apply(p2); - EXPECT(any_of(p1, &is_batch_norm)); - EXPECT(none_of(p2, &is_batch_norm)); + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_batch_norm)); + EXPECT(none_of(*p2.get_main_module(), &is_batch_norm)); - p1.compile(migraphx::cpu::target{}); - p2.compile(migraphx::cpu::target{}); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); - auto result1 = p1.eval({}); - auto result2 = p2.eval({}); + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); } diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..856a2ecbe40be75e80082f7c51359989c19349d6 --- /dev/null +++ b/test/rewrite_pooling_test.cpp @@ -0,0 +1,188 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; } +static void opt_pooling(migraphx::module& m) +{ + migraphx::rewrite_pooling rp; + migraphx::dead_code_elimination dce; + rp.apply(m); + dce.apply(m); +} + +TEST_CASE(rewrite_pooling_test) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; + auto pooling_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m; + auto input = m.add_parameter("x", s); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {3, 4, 5}}}), + input); + m.add_return({ret}); + return m; + }; + + auto opt_program = [&](const migraphx::operation& reduce_op) { + migraphx::module m; + auto input = m.add_parameter("x", s); + auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input); + auto rdm = m.add_instruction(reduce_op, rsp); + auto ret = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm); + m.add_return({ret}); + return m; + }; + + auto test_rewrite = [&](const migraphx::op::pooling_mode mode, const migraphx::operation& op) { + migraphx::module m1 = pooling_program(mode); + migraphx::module m2 = opt_program(op); + opt_pooling(m1); + EXPECT(m1 == m2); + }; + + test_rewrite(migraphx::op::pooling_mode::average, + migraphx::make_op("reduce_mean", {{"axes", {1}}})); + test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}})); +} + +TEST_CASE(rewrite_avepooling_na1_test) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; + auto pooling_program = [&]() { + migraphx::module m; + + auto input = m.add_parameter("x", s); + auto ret = + m.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 1, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {3, 4, 5}}}), + input); + m.add_return({ret}); + return m; + }; + + migraphx::module m1 = pooling_program(); + migraphx::module m2 = m1; + + opt_pooling(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(rewrite_avepooling_na2_test) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; + auto pooling_program = [&]() { + migraphx::module m; + + auto input = m.add_parameter("x", s); + auto ret = + m.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0}}, + {"stride", {1, 2, 1}}, + {"lengths", {3, 4, 5}}}), + input); + m.add_return({ret}); + return m; + }; + + migraphx::module m1 = pooling_program(); + migraphx::module m2 = m1; + + opt_pooling(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(rewrite_avepooling_na3_test) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; + auto pooling_program = [&]() { + migraphx::module m; + + auto input = m.add_parameter("x", s); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::max}, + {"padding", {0, 0, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {3, 3, 5}}}), + input); + m.add_return({ret}); + return m; + }; + + migraphx::module m1 = pooling_program(); + migraphx::module m2 = m1; + + opt_pooling(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(literal_rewrite_pooling_test) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 1.0f); + + auto pooling_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::program p; + + auto* mm = p.get_main_module(); + auto input = mm->add_literal(migraphx::literal(s, data)); + auto ret = mm->add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {3, 4, 5}}}), + input); + mm->add_return({ret}); + return p; + }; + + auto opt_program = [&](const migraphx::operation& op) { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_literal(migraphx::literal(s, data)); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input); + auto rdm = mm->add_instruction(op, rsp); + auto ret = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm); + mm->add_return({ret}); + + return p; + }; + + auto test_rewrite_pooling = [&](const migraphx::op::pooling_mode mode, + const migraphx::operation& op) { + migraphx::program p1 = pooling_program(mode); + migraphx::program p2 = opt_program(op); + p1.compile(migraphx::ref::target{}); + p2.compile(migraphx::ref::target{}); + auto result1 = p1.eval({}).back(); + auto result2 = p2.eval({}).back(); + visit_all(result1, + result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); + }; + + test_rewrite_pooling(migraphx::op::pooling_mode::max, + migraphx::make_op("reduce_max", {{"axes", {1}}})); + test_rewrite_pooling(migraphx::op::pooling_mode::average, + migraphx::make_op("reduce_mean", {{"axes", {1}}})); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/rewrite_quantization_test.cpp b/test/rewrite_quantization_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c39bce5f7ee153a7fcaddb67b885235dbc2c1428 --- /dev/null +++ b/test/rewrite_quantization_test.cpp @@ -0,0 +1,72 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } +bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } + +TEST_CASE(quantizelinear) +{ + + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 3}}; + std::vector xv = {-300, 200, 129, 1, 2, 3, 500, 1000, 50}; + migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; + std::vector sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + mm->add_instruction(migraphx::make_op("quantizelinear"), x, s); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + migraphx::rewrite_quantization opt; + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); + EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); +} + +TEST_CASE(dequantizelinear) +{ + + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 3}}; + std::vector xv = {0, 1, 2, 5, 10, 50, 100, 150, 250}; + migraphx::shape ss{migraphx::shape::float_type, {1, 3, 3}}; + std::vector sv = {2, 2, 2, 2, 2, 2, 2, 2, 2}; + migraphx::shape zs{migraphx::shape::uint8_type, {1, 3, 3}}; + std::vector zv = {0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto create_program = [&]() { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_literal(xs, xv); + auto s = mm->add_literal(ss, sv); + auto z = mm->add_literal(zs, zv); + mm->add_instruction(migraphx::make_op("dequantizelinear"), x, s, z); + return p; + }; + + migraphx::program p1 = create_program(); + migraphx::program p2 = create_program(); + + migraphx::rewrite_quantization opt; + opt.apply(*p2.get_main_module()); + EXPECT(any_of(*p1.get_main_module(), &is_dequantizelinear)); + EXPECT(none_of(*p2.get_main_module(), &is_dequantizelinear)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/run_loop_test.cpp b/test/run_loop_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aca460c29fe2c7121ca684a16373958509be7c0f --- /dev/null +++ b/test/run_loop_test.cpp @@ -0,0 +1,246 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +struct copy_op +{ + std::string name() const { return "copy"; } + + migraphx::shape compute_shape(std::vector inputs) const + { + return inputs.front(); + } + + migraphx::argument + compute(migraphx::context&, const migraphx::shape&, std::vector args) const + { + visit_all(args[0], args[1])([&](auto input, auto output) { + std::copy(input.begin(), input.end(), output.begin()); + }); + + return args[1]; + } + + int output_alias(const std::vector&) const { return 0; } +}; + +struct test_loop_op +{ + int64_t max_iterations = 10; + + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.max_iterations, "max_iterations")); + } + + std::string name() const { return "test_loop_op"; } + + migraphx::shape compute_shape(const std::vector& inputs, + std::vector mods) const + { + migraphx::check_shapes{inputs, *this}.standard(); + if(mods.size() != 1) + { + MIGRAPHX_THROW("LOOP: operator should have one submodule."); + } + + const auto& mod = mods.front(); + auto mod_out_shapes = mod->get_output_shapes(); + auto dep_param_num = inputs.size() - 2; + + // first item of the mod output shapes is condition used in loop, + // which is not needed to compute output shape + mod_out_shapes.erase(mod_out_shapes.begin()); + std::vector ins_out_shapes(mod_out_shapes.begin(), + mod_out_shapes.begin() + dep_param_num); + mod_out_shapes.erase(mod_out_shapes.begin(), mod_out_shapes.begin() + dep_param_num); + for(const auto& out_s : mod_out_shapes) + { + auto lens = out_s.lens(); + lens.insert(lens.begin(), max_iterations); + ins_out_shapes.push_back({out_s.type(), lens}); + } + + return {ins_out_shapes}; + } + + struct test_loop : public migraphx::op::loop::ref_loop + { + test_loop(int64_t iter_num) { max_iterations = iter_num; } + + std::unordered_map get_output_params(const migraphx::module& m) const + { + auto get_output_index = [](const std::string& name) { + std::string out_prefix = "#output_"; + auto loc = name.find(out_prefix); + if(loc != std::string::npos) + { + int index = std::stoi(name.substr(loc + out_prefix.size())); + return index; + } + + return -1; + }; + + const auto& param_names = m.get_parameter_names(); + std::unordered_map result; + for(const auto& name : param_names) + { + auto index = get_output_index(name); + if(index == -1) + continue; + result[name] = index; + } + + return result; + } + }; + + migraphx::argument + compute(migraphx::context& ctx, + const migraphx::shape& out_shape, + const std::vector& args, + const std::vector& mods, + const std::function( + migraphx::module_ref&, const std::unordered_map&)>& + run) const + { + // wrap up the arguments vector, so ref and gpu impl are the same + auto cpy_args = args; + bool in_cond = args.at(1).at(); + bool cond = in_cond; + int64_t iter = 0; + // insert iter and cond used in the loop + auto s_cond = args.at(1).get_shape(); + auto s_iter = args.at(0).get_shape(); + cpy_args.push_back({s_iter, &iter}); + cpy_args.push_back({s_cond, &cond}); + cpy_args.insert(cpy_args.end(), args.begin() + 2, args.end()); + // add cond and mod outputs to the argument list + cpy_args.push_back(migraphx::argument(s_cond)); + cpy_args.push_back(migraphx::argument(out_shape)); + // run loop + return run_loop(test_loop{max_iterations}, ctx, cpy_args, mods, run); + } +}; + +static auto create_program(int64_t max_loop_iterations = 10) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape si{migraphx::shape::int64_type}; + migraphx::shape s{migraphx::shape::int64_type, {1}}; + migraphx::shape sc{migraphx::shape::bool_type}; + + auto in_iter = mm->add_parameter("iter_num", si); + auto in_cond = mm->add_parameter("ccond", sc); + auto in_val = mm->add_parameter("val", s); + + auto* body = p.create_module("loop_module"); + auto iter = body->add_parameter("#loop_module_in_0", si); + body->add_parameter("#loop_module_in_1", sc); + auto in_v = body->add_parameter("#loop_module_in_2", s); + std::vector vd = {3}; + auto l = body->add_literal(migraphx::literal(si, vd)); + auto ad = body->add_instruction(migraphx::make_op("add"), iter, l); + auto val = body->add_instruction(migraphx::make_op("add"), in_v, ad); + auto eq = body->add_instruction(migraphx::make_op("equal"), iter, l); + auto beq = body->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq); + auto neq = body->add_instruction(migraphx::make_op("not"), beq); + std::string out_param_prefix = "loop_module:#output_"; + auto out0 = body->add_parameter(out_param_prefix + std::to_string(0), neq->get_shape()); + auto r_neq = body->add_instruction(copy_op{}, neq, out0); + auto out2 = body->add_parameter(out_param_prefix + std::to_string(2), val->get_shape()); + auto r_val = body->add_instruction(copy_op{}, val, out2); + body->add_return({r_neq, r_val, r_val}); + + auto rl = + mm->add_instruction(test_loop_op{max_loop_iterations}, {in_iter, in_cond, in_val}, {body}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rl); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rl); + mm->add_return({r0, r1}); + + return p; +}; + +static auto run_prog(migraphx::program p, int64_t iter_num, bool cond, int64_t ini_val) +{ + migraphx::shape si{migraphx::shape::int64_type}; + migraphx::shape s{migraphx::shape::int64_type, {1}}; + migraphx::shape sc{migraphx::shape::bool_type}; + + p.compile(migraphx::ref::target{}); + migraphx::parameter_map pp; + pp["iter_num"] = migraphx::argument(si, &iter_num); + pp["ccond"] = migraphx::argument(sc, &cond); + pp["val"] = migraphx::argument(s, &ini_val); + auto rets = p.eval(pp); + + std::vector> res; + for(auto& arg : rets) + { + std::vector vec; + arg.visit([&](auto v) { vec.assign(v.begin(), v.end()); }); + res.push_back(vec); + } + + return res; +} + +TEST_CASE(loop_test1) +{ + auto p = create_program(); + auto ress = run_prog(p, 10, true, 1); + std::vector gold_last = {19}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test2) +{ + auto p = create_program(12); + auto ress = run_prog(p, 4, true, 1); + std::vector gold_last = {19}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13, 19, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test3) +{ + auto p = create_program(3); + auto ress = run_prog(p, 3, true, 1); + std::vector gold_last = {13}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {4, 8, 13}; + EXPECT(ress.back() == gold_concat); +} + +TEST_CASE(loop_test4) +{ + auto p = create_program(20); + auto ress = run_prog(p, 5, true, 2); + std::vector gold_last = {20}; + EXPECT(ress.front() == gold_last); + std::vector gold_concat = {5, 9, 14, 20, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + EXPECT(ress.back() == gold_concat); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/schedule_test.cpp b/test/schedule_test.cpp old mode 100644 new mode 100755 index 904a80d6d752f08264147891a4257efc2a56271b..cb5731d6748429c756e9c609ed67a577b8822680 --- a/test/schedule_test.cpp +++ b/test/schedule_test.cpp @@ -1,12 +1,13 @@ #include #include -#include #include #include #include #include #include #include +#include + #include struct unary_op @@ -31,7 +32,7 @@ struct unary_op struct nary_op { - std::string comment = ""; + std::string comment; template static auto reflect(Self& self, F f) { @@ -56,7 +57,7 @@ struct nary_op struct stream_free_op { - std::string comment = ""; + std::string comment; template static auto reflect(Self& self, F f) { @@ -112,21 +113,21 @@ struct schedule_model_test std::shared_ptr wait2stream = std::make_shared(); std::shared_ptr ins2wait_for = std::make_shared(); std::size_t concurrency() const { return 4; } - void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const + void sched(migraphx::module&, migraphx::instruction_ref ins, std::size_t n) const { (*ins2stream)[ins] = n; } - void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const + void wait(migraphx::module& m, migraphx::instruction_ref ins, std::size_t wait_id) const { if(ins2wait_for->count(ins) == 0) { auto event = wait_event{}; - p.insert_instruction(ins, event); + m.insert_instruction(ins, event); (*ins2wait_for)[ins] = event.wait_for; } (*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id)); } - void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const + void record(migraphx::module&, migraphx::instruction_ref ins, std::size_t wait_id) const { (*wait2stream)[wait_id] = ins2stream->at(ins); } @@ -141,19 +142,17 @@ struct schedule_model_test } }; -bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y) +bool check_conflicts(migraphx::module& m, migraphx::instruction_ref x, migraphx::instruction_ref y) { - for(auto ins : migraphx::iterator_for(p)) - { + return migraphx::any_of(migraphx::iterator_for(m), [&](auto ins) { if(ins->name() != "identity") - continue; + return false; if(not migraphx::contains(ins->inputs(), x)) - continue; + return false; if(not migraphx::contains(ins->inputs(), y)) - continue; + return false; return true; - } - return false; + }); } struct scheduler @@ -171,11 +170,11 @@ struct scheduler return result; } - void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::schedule{model}}); } + void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::schedule{model}}); } bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; } - void check_conflicts(migraphx::program& p, + void check_conflicts(migraphx::module& m, std::vector> conflicts, bool result = true) { @@ -190,7 +189,7 @@ struct scheduler if(this->has_stream(ins1) and this->has_stream(ins2) and this->get_stream(ins1) == this->get_stream(ins2)) continue; - CHECK(::check_conflicts(p, ins1, ins2) == result); + CHECK(::check_conflicts(m, ins1, ins2) == result); } } }); @@ -238,12 +237,12 @@ std::vector get_wait_for(migraphx::instruction_ref ins) template std::vector -chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input) +chain(migraphx::module& m, std::size_t n, T x, migraphx::instruction_ref input) { std::vector result; for(std::size_t i = 0; i < n; i++) { - result.push_back(p.add_instruction(x, input)); + result.push_back(m.add_instruction(x, input)); input = result.back(); } return result; @@ -251,105 +250,111 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input) TEST_CASE(single_entry) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(nary_op{}, onep1, onep2); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(nary_op{}, onem1, onem2); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); EXPECT(t.get_stream(binary) == 0); EXPECT(get_wait_for(binary) == - get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)})); - EXPECT(check_conflicts(p, onep1, onep2)); + get_wait_for(t.get_stream(binary), {t.get_stream(onem1), t.get_stream(onem2)})); + EXPECT(check_conflicts(m, onem1, onem2)); } TEST_CASE(stream_free) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(stream_free_op{}, one); - auto onep2 = p.add_instruction(stream_free_op{}, one); - auto binary = p.add_instruction(nary_op{}, onep1, onep2); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(stream_free_op{}, one); + auto onem2 = m.add_instruction(stream_free_op{}, one); + auto binary = m.add_instruction(nary_op{}, onem1, onem2); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(not t.has_stream(onep1)); - EXPECT(not t.has_stream(onep2)); + EXPECT(not t.has_stream(onem1)); + EXPECT(not t.has_stream(onem2)); EXPECT(not t.has_stream(binary)); } TEST_CASE(zero_record) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1); - auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2); - auto binary = p.add_instruction(nary_op{}, onei1, onei2); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto onei1 = m.add_instruction(migraphx::make_op("identity"), onem1); + auto onei2 = m.add_instruction(migraphx::make_op("identity"), onem2); + auto binary = m.add_instruction(nary_op{}, onei1, onei2); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); EXPECT(t.has_stream(binary)); EXPECT(get_wait_for(binary) == - get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)})); - EXPECT(check_conflicts(p, onep1, onep2)); - t.check_conflicts(p, {{onep1, onei1}, {onep2, onei2}}); + get_wait_for(t.get_stream(binary), {t.get_stream(onem1), t.get_stream(onem2)})); + EXPECT(check_conflicts(m, onem1, onem2)); + t.check_conflicts(m, {{onem1, onei1}, {onem2, onei2}}); } TEST_CASE(zero_merge1) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(migraphx::make_op("identity"), onem1, onem2); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); // No stream assignment EXPECT(not t.has_stream(binary)); // There is no wait EXPECT(get_wait_for(binary).empty()); - EXPECT(check_conflicts(p, onep1, onep2)); + EXPECT(check_conflicts(m, onem1, onem2)); } TEST_CASE(zero_merge2) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(migraphx::op::identity{}, - p.add_instruction(migraphx::op::identity{}, onep1), - p.add_instruction(migraphx::op::identity{}, onep2)); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(migraphx::make_op("identity"), + m.add_instruction(migraphx::make_op("identity"), onem1), + m.add_instruction(migraphx::make_op("identity"), onem2)); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); // No stream assignment EXPECT(not t.has_stream(binary)); // There is no wait EXPECT(get_wait_for(binary).empty()); - EXPECT(check_conflicts(p, onep1, onep2)); + EXPECT(check_conflicts(m, onem1, onem2)); } TEST_CASE(zero_merge3) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2); - auto final = p.add_instruction(unary_op{}, id); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto id = m.add_instruction(migraphx::make_op("identity"), onem1, onem2); + auto final = m.add_instruction(unary_op{}, id); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); // No stream assignment EXPECT(not t.has_stream(id)); // There is no wait @@ -357,24 +362,25 @@ TEST_CASE(zero_merge3) // Stream assignment for final op EXPECT(t.get_stream(final) == 0); EXPECT(get_wait_for(final) == - get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)})); - EXPECT(check_conflicts(p, onep1, onep2)); + get_wait_for(t.get_stream(final), {t.get_stream(onem1), t.get_stream(onem2)})); + EXPECT(check_conflicts(m, onem1, onem2)); } TEST_CASE(zero_merge4) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto id = p.add_instruction(migraphx::op::identity{}, - p.add_instruction(migraphx::op::identity{}, onep1), - p.add_instruction(migraphx::op::identity{}, onep2)); - auto final = p.add_instruction(unary_op{}, id); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto id = m.add_instruction(migraphx::make_op("identity"), + m.add_instruction(migraphx::make_op("identity"), onem1), + m.add_instruction(migraphx::make_op("identity"), onem2)); + auto final = m.add_instruction(unary_op{}, id); + t.run_pass(m); EXPECT(not t.has_stream(one)); - EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onem1) != t.get_stream(onem2)); // No stream assignment EXPECT(not t.has_stream(id)); // There is no wait @@ -382,38 +388,40 @@ TEST_CASE(zero_merge4) // Stream assignment for final op EXPECT(t.get_stream(final) == 0); EXPECT(get_wait_for(final) == - get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)})); - EXPECT(check_conflicts(p, onep1, onep2)); + get_wait_for(t.get_stream(final), {t.get_stream(onem1), t.get_stream(onem2)})); + EXPECT(check_conflicts(m, onem1, onem2)); } TEST_CASE(double_entry) { scheduler t{}; - migraphx::program p; - auto one = p.add_instruction(stream_free_op{}, p.add_literal(1)); - auto two = p.add_instruction(stream_free_op{}, p.add_literal(2)); - auto onep = p.add_instruction(unary_op{}, one); - auto twop = p.add_instruction(unary_op{}, two); - auto binary = p.add_instruction(nary_op{}, onep, twop); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_instruction(stream_free_op{}, m.add_literal(1)); + auto two = m.add_instruction(stream_free_op{}, m.add_literal(2)); + auto onep = m.add_instruction(unary_op{}, one); + auto twop = m.add_instruction(unary_op{}, two); + auto binary = m.add_instruction(nary_op{}, onep, twop); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(two)); EXPECT(t.get_stream(onep) != t.get_stream(twop)); EXPECT(t.get_stream(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)})); - t.check_conflicts(p, {{onep, one}, {twop, two}}); + t.check_conflicts(m, {{onep, one}, {twop, two}}); } TEST_CASE(two_branches) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(nary_op{}, i1, c1.back()); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(nary_op{}, i1, c1.back()); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) == 1); for(auto ins : c1) @@ -421,20 +429,21 @@ TEST_CASE(two_branches) EXPECT(t.get_stream(binary) == 0); EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); } TEST_CASE(four_branches) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 4, unary_op{}, one); - auto c2 = chain(p, 3, unary_op{}, one); - auto c3 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back()); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 4, unary_op{}, one); + auto c2 = chain(m, 3, unary_op{}, one); + auto c3 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back()); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) == 3); for(auto ins : c1) @@ -449,21 +458,22 @@ TEST_CASE(four_branches) t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, c2, c3, {i1}}); + t.check_conflicts(m, {c1, c2, c3, {i1}}); } TEST_CASE(five_branches) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 5, unary_op{}, one); - auto c2 = chain(p, 4, unary_op{}, one); - auto c3 = chain(p, 3, unary_op{}, one); - auto c4 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back()); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 5, unary_op{}, one); + auto c2 = chain(m, 4, unary_op{}, one); + auto c3 = chain(m, 3, unary_op{}, one); + auto c4 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back()); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) == 3); for(auto ins : c1) @@ -480,50 +490,52 @@ TEST_CASE(five_branches) t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, c2, c3, c4}); - t.check_conflicts(p, {c1, c2, c3, {i1}}); + t.check_conflicts(m, {c1, c2, c3, c4}); + t.check_conflicts(m, {c1, c2, c3, {i1}}); } TEST_CASE(four_branches_eq) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto onep1 = p.add_instruction(unary_op{}, one); - auto onep2 = p.add_instruction(unary_op{}, one); - auto onep3 = p.add_instruction(unary_op{}, one); - auto onep4 = p.add_instruction(unary_op{}, one); - auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto onem1 = m.add_instruction(unary_op{}, one); + auto onem2 = m.add_instruction(unary_op{}, one); + auto onep3 = m.add_instruction(unary_op{}, one); + auto onep4 = m.add_instruction(unary_op{}, one); + auto binary = m.add_instruction(nary_op{}, onem1, onem2, onep3, onep4); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT( sorted( - {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) == + {t.get_stream(onem1), t.get_stream(onem2), t.get_stream(onep3), t.get_stream(onep4)}) == unique( - {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)})); + {t.get_stream(onem1), t.get_stream(onem2), t.get_stream(onep3), t.get_stream(onep4)})); EXPECT(t.get_stream(binary) == 0); EXPECT( get_wait_for(binary) == get_wait_for( t.get_stream(binary), - {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)})); - t.check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}}); + {t.get_stream(onem1), t.get_stream(onem2), t.get_stream(onep3), t.get_stream(onep4)})); + t.check_conflicts(m, {{onem1}, {onem2}, {onep3}, {onep4}}); } TEST_CASE(seq_merge) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto binary1 = p.add_instruction(nary_op{}, i1, c1.back()); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto binary1 = m.add_instruction(nary_op{}, i1, c1.back()); - auto c2 = chain(p, 2, unary_op{}, binary1); - auto i2 = p.add_instruction(unary_op{}, binary1); - auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); + auto c2 = chain(m, 2, unary_op{}, binary1); + auto i2 = m.add_instruction(unary_op{}, binary1); + auto binary2 = m.add_instruction(nary_op{}, i2, c2.back()); - t.run_pass(p); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) != t.get_stream(c1.back())); @@ -532,7 +544,7 @@ TEST_CASE(seq_merge) EXPECT(t.get_stream(binary1) == t.get_stream(c1.back())); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); EXPECT(t.get_stream(i2) != t.get_stream(c2.back())); for(auto ins : c2) @@ -540,27 +552,28 @@ TEST_CASE(seq_merge) EXPECT(t.get_stream(binary2) == 0); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); - t.check_conflicts(p, {c2, {i2}}); + t.check_conflicts(m, {c2, {i2}}); } TEST_CASE(par_merge) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto start1 = p.add_instruction(unary_op{}, one); - auto c1 = chain(p, 3, unary_op{}, start1); - auto i1 = p.add_instruction(unary_op{}, start1); - auto binary1 = p.add_instruction(nary_op{}, i1, c1.back()); + migraphx::module m; + + auto one = m.add_literal(1); + auto start1 = m.add_instruction(unary_op{}, one); + auto c1 = chain(m, 3, unary_op{}, start1); + auto i1 = m.add_instruction(unary_op{}, start1); + auto binary1 = m.add_instruction(nary_op{}, i1, c1.back()); - auto start2 = p.add_instruction(unary_op{}, one); - auto c2 = chain(p, 2, unary_op{}, start2); - auto i2 = p.add_instruction(unary_op{}, start2); - auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); + auto start2 = m.add_instruction(unary_op{}, one); + auto c2 = chain(m, 2, unary_op{}, start2); + auto i2 = m.add_instruction(unary_op{}, start2); + auto binary2 = m.add_instruction(nary_op{}, i2, c2.back()); - auto binary3 = p.add_instruction(nary_op{}, binary1, binary2); + auto binary3 = m.add_instruction(nary_op{}, binary1, binary2); - t.run_pass(p); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(binary3) == 0); @@ -570,7 +583,7 @@ TEST_CASE(par_merge) EXPECT(t.get_stream(binary1) == 0); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); for(auto ins : c2) EXPECT(t.get_stream(ins) == t.get_stream(binary2)); @@ -578,33 +591,34 @@ TEST_CASE(par_merge) EXPECT(t.get_stream(binary2) != t.get_stream(i2)); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); - t.check_conflicts(p, {c2, {i2}}); + t.check_conflicts(m, {c2, {i2}}); - EXPECT(check_conflicts(p, binary1, binary2)); - t.check_conflicts(p, {c1, {i1}, c2, {i2}}); + EXPECT(check_conflicts(m, binary1, binary2)); + t.check_conflicts(m, {c1, {i1}, c2, {i2}}); } TEST_CASE(inner_par_merge) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto start1 = p.add_instruction(unary_op{}, one); - auto c1 = chain(p, 3, unary_op{}, start1); - auto i1 = p.add_instruction(unary_op{}, start1); - auto binary1 = p.add_instruction(nary_op{}, i1, c1.back()); + migraphx::module m; + + auto one = m.add_literal(1); + auto start1 = m.add_instruction(unary_op{}, one); + auto c1 = chain(m, 3, unary_op{}, start1); + auto i1 = m.add_instruction(unary_op{}, start1); + auto binary1 = m.add_instruction(nary_op{}, i1, c1.back()); - auto start2 = p.add_instruction(unary_op{}, one); - auto c2 = chain(p, 2, unary_op{}, start2); - auto i2 = p.add_instruction(unary_op{}, start2); - auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); + auto start2 = m.add_instruction(unary_op{}, one); + auto c2 = chain(m, 2, unary_op{}, start2); + auto i2 = m.add_instruction(unary_op{}, start2); + auto binary2 = m.add_instruction(nary_op{}, i2, c2.back()); - auto outer1 = p.add_instruction(unary_op{}, one); - auto outer2 = p.add_instruction(unary_op{}, one); + auto outer1 = m.add_instruction(unary_op{}, one); + auto outer2 = m.add_instruction(unary_op{}, one); - auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2); + auto output = m.add_instruction(nary_op{}, binary1, binary2, outer1, outer2); - t.run_pass(p); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(output) == 0); EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output), @@ -623,7 +637,7 @@ TEST_CASE(inner_par_merge) EXPECT(t.get_stream(binary1) == 0); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); for(auto ins : c2) EXPECT(t.get_stream(ins) == t.get_stream(binary2)); @@ -631,31 +645,32 @@ TEST_CASE(inner_par_merge) EXPECT(t.get_stream(binary2) != t.get_stream(i2)); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); - t.check_conflicts(p, {c2, {i2}}); + t.check_conflicts(m, {c2, {i2}}); - EXPECT(check_conflicts(p, binary1, binary2)); - t.check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}}); + EXPECT(check_conflicts(m, binary1, binary2)); + t.check_conflicts(m, {c1, {i1}, c2, {i2}, {outer1}, {outer2}}); } TEST_CASE(par_merge_multi_entry) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto start1 = p.add_instruction(unary_op{}, one); - auto c1 = chain(p, 3, unary_op{}, start1); - auto i1 = p.add_instruction(unary_op{}, start1); - auto binary1 = p.add_instruction(nary_op{}, i1, c1.back()); + migraphx::module m; + + auto one = m.add_literal(1); + auto start1 = m.add_instruction(unary_op{}, one); + auto c1 = chain(m, 3, unary_op{}, start1); + auto i1 = m.add_instruction(unary_op{}, start1); + auto binary1 = m.add_instruction(nary_op{}, i1, c1.back()); - auto two = p.add_literal(1); - auto start2 = p.add_instruction(unary_op{}, two); - auto c2 = chain(p, 2, unary_op{}, start2); - auto i2 = p.add_instruction(unary_op{}, start2); - auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); + auto two = m.add_literal(1); + auto start2 = m.add_instruction(unary_op{}, two); + auto c2 = chain(m, 2, unary_op{}, start2); + auto i2 = m.add_instruction(unary_op{}, start2); + auto binary2 = m.add_instruction(nary_op{}, i2, c2.back()); - auto binary3 = p.add_instruction(nary_op{}, binary1, binary2); + auto binary3 = m.add_instruction(nary_op{}, binary1, binary2); - t.run_pass(p); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(two)); EXPECT(t.get_stream(binary3) == 0); @@ -666,7 +681,7 @@ TEST_CASE(par_merge_multi_entry) EXPECT(t.get_stream(binary1) == 0); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); for(auto ins : c2) EXPECT(t.get_stream(ins) == t.get_stream(binary2)); @@ -674,23 +689,24 @@ TEST_CASE(par_merge_multi_entry) EXPECT(t.get_stream(binary2) != t.get_stream(i2)); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); - t.check_conflicts(p, {c2, {i2}}); + t.check_conflicts(m, {c2, {i2}}); - EXPECT(check_conflicts(p, binary1, binary2)); - t.check_conflicts(p, {c1, {i1}, c2, {i2}}); + EXPECT(check_conflicts(m, binary1, binary2)); + t.check_conflicts(m, {c1, {i1}, c2, {i2}}); } TEST_CASE(inner_split1) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto s1 = p.add_instruction(unary_op{}, c1); - auto s2 = p.add_instruction(unary_op{}, c1); - auto output = p.add_instruction(nary_op{}, i1, s1, s2); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto s1 = m.add_instruction(unary_op{}, c1); + auto s2 = m.add_instruction(unary_op{}, c1); + auto output = m.add_instruction(nary_op{}, i1, s1, s2); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) != t.get_stream(s1)); EXPECT(t.get_stream(i1) != t.get_stream(s2)); @@ -704,20 +720,21 @@ TEST_CASE(inner_split1) get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)})); // Either s1 or s2 has a wait depending on the sort order but not both EXPECT(get_wait_for(s1).empty() xor get_wait_for(s2).empty()); - t.check_conflicts(p, {c1, {i1}, {s1}, {s2}}); + t.check_conflicts(m, {c1, {i1}, {s1}, {s2}}); } TEST_CASE(inner_split2) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto c1 = chain(p, 2, unary_op{}, one); - auto i1 = p.add_instruction(unary_op{}, one); - auto s1 = chain(p, 3, unary_op{}, c1.back()); - auto s2 = chain(p, 4, unary_op{}, c1.back()); - auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back()); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto c1 = chain(m, 2, unary_op{}, one); + auto i1 = m.add_instruction(unary_op{}, one); + auto s1 = chain(m, 3, unary_op{}, c1.back()); + auto s2 = chain(m, 4, unary_op{}, c1.back()); + auto output = m.add_instruction(nary_op{}, i1, s1.back(), s2.back()); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) != t.get_stream(s1.back())); EXPECT(t.get_stream(i1) != t.get_stream(s2.back())); @@ -730,20 +747,21 @@ TEST_CASE(inner_split2) get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())})); EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())})); - t.check_conflicts(p, {c1, {i1}, s1, s2}); + t.check_conflicts(m, {c1, {i1}, s1, s2}); } TEST_CASE(inception_resnet) { scheduler t{}; - migraphx::program p; - auto one = p.add_literal(1); - auto input = p.add_instruction(unary_op{}, one); - auto c1 = chain(p, 2, unary_op{}, input); - auto i1 = p.add_instruction(unary_op{}, input); - auto binary = p.add_instruction(nary_op{}, i1, c1.back()); - auto output = p.add_instruction(nary_op{}, binary, input); - t.run_pass(p); + migraphx::module m; + + auto one = m.add_literal(1); + auto input = m.add_instruction(unary_op{}, one); + auto c1 = chain(m, 2, unary_op{}, input); + auto i1 = m.add_instruction(unary_op{}, input); + auto binary = m.add_instruction(nary_op{}, i1, c1.back()); + auto output = m.add_instruction(nary_op{}, binary, input); + t.run_pass(m); EXPECT(not t.has_stream(one)); EXPECT(t.get_stream(i1) != 0); for(auto ins : c1) @@ -753,105 +771,130 @@ TEST_CASE(inception_resnet) get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)})); EXPECT(t.get_stream(output) == 0); EXPECT(get_wait_for(output).empty()); - t.check_conflicts(p, {c1, {i1}}); + t.check_conflicts(m, {c1, {i1}}); } -TEST_CASE(inception1) +TEST_CASE(dominate_conflicts) { scheduler t{}; - migraphx::program p; + migraphx::module m; + + auto one = m.add_literal(1); + auto onep1 = m.add_instruction(unary_op{}, one); + auto onep2 = m.add_instruction(unary_op{}, one); + auto binary1 = m.add_instruction(nary_op{}, onep1, onep2); + auto onep3 = m.add_instruction(unary_op{}, binary1); + auto onep4 = m.add_instruction(unary_op{}, binary1); + auto binary2 = m.add_instruction(nary_op{}, onep3, onep4); + t.run_pass(m); + + EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); + EXPECT(t.get_stream(onep3) != t.get_stream(onep4)); + EXPECT(get_wait_for(binary1) == + get_wait_for(t.get_stream(binary1), {t.get_stream(onep1), t.get_stream(onep2)})); + t.check_conflicts(m, {{onep1}, {onep2}}); + t.check_conflicts(m, {{onep3}, {onep4}}); + + t.check_conflicts(m, {{onep1, onep2}, {onep3, onep4}}, false); + t.check_conflicts(m, {{binary1}, {binary2}}, false); +} - auto i1 = p.add_literal(0); - auto i2 = p.add_literal(1); - auto i3 = p.add_literal(1); - auto i4 = p.add_literal(2); - auto i7 = p.add_instruction(nary_op{"i7"}, i1, i4, i3, i2); - auto i8 = p.add_literal(2); - auto i9 = p.add_instruction(migraphx::op::identity{}, i8); - auto i10 = p.add_literal(1); - auto i11 = p.add_instruction(nary_op{"i11"}, i7, i9, i10); - auto i12 = p.add_literal(2); - auto i13 = p.add_instruction(migraphx::op::identity{}, i12); - auto i14 = p.add_literal(1); - auto i15 = p.add_literal(1); - auto i16 = p.add_literal(2); - auto i17 = p.add_instruction(nary_op{"i17"}, i11, i16, i15, i13, i14); - auto i18 = p.add_literal(2); - auto i19 = p.add_instruction(migraphx::op::identity{}, i18); - auto i20 = p.add_literal(1); - auto i21 = p.add_literal(1); - auto i22 = p.add_literal(2); - auto i23 = p.add_instruction(nary_op{"i23"}, i17, i22, i21, i19, i20); - auto i24 = p.add_literal(1); - auto i25 = p.add_instruction(nary_op{"i25"}, i23, i24); - auto i26 = p.add_literal(2); - auto i27 = p.add_instruction(migraphx::op::identity{}, i26); - auto i28 = p.add_literal(1); - auto i29 = p.add_literal(1); - auto i30 = p.add_literal(2); - auto i31 = p.add_instruction(nary_op{"i31"}, i25, i30, i29, i27, i28); - auto i32 = p.add_literal(2); - auto i33 = p.add_instruction(migraphx::op::identity{}, i32); - auto i34 = p.add_literal(1); - auto i35 = p.add_literal(1); - auto i36 = p.add_literal(2); - auto i37 = p.add_instruction(nary_op{"i37"}, i31, i36, i35, i33, i34); - auto i38 = p.add_literal(1); - auto i39 = p.add_instruction(nary_op{"i39"}, i37, i38); - auto i41 = p.add_literal(2); - auto i42 = p.add_instruction(migraphx::op::identity{}, i41); - auto i43 = p.add_literal(1); - auto i44 = p.add_literal(1); - auto i45 = p.add_literal(2); - auto i48 = p.add_instruction(nary_op{"i48"}, i39, i45, i44, i42, i43); - auto i49 = p.add_literal(2); - auto i50 = p.add_instruction(migraphx::op::identity{}, i49); - auto i51 = p.add_literal(1); - auto i52 = p.add_literal(1); - auto i53 = p.add_literal(2); - auto i54 = p.add_instruction(nary_op{"i54"}, i48, i53, i52, i50, i51); - auto i55 = p.add_literal(1); - auto i56 = p.add_instruction(migraphx::op::identity{}, i55); - auto i57 = p.add_literal(2); - auto i58 = p.add_instruction(migraphx::op::identity{}, i57); - auto i59 = p.add_literal(1); - auto i60 = p.add_literal(2); - auto i61 = p.add_instruction(nary_op{"i61"}, i54, i60, i59, i58, i56); - auto i62 = p.add_literal(2); - auto i63 = p.add_instruction(migraphx::op::identity{}, i62); - auto i64 = p.add_literal(1); - auto i65 = p.add_literal(1); - auto i66 = p.add_literal(2); - auto i69 = p.add_instruction(nary_op{"i69"}, i39, i66, i65, i63, i64); - auto i70 = p.add_instruction(migraphx::op::identity{}, i55); - auto i71 = p.add_literal(2); - auto i72 = p.add_instruction(migraphx::op::identity{}, i71); - auto i73 = p.add_literal(1); - auto i74 = p.add_literal(2); - auto i75 = p.add_instruction(nary_op{"i75"}, i69, i74, i73, i72, i70); - auto i77 = p.add_literal(1); - auto i80 = p.add_instruction(nary_op{"i80"}, i39, i77); - auto i81 = p.add_instruction(migraphx::op::identity{}, i55); - auto i82 = p.add_literal(2); - auto i83 = p.add_instruction(migraphx::op::identity{}, i82); - auto i84 = p.add_literal(1); - auto i85 = p.add_literal(2); - auto i86 = p.add_instruction(nary_op{"i86"}, i80, i85, i84, i83, i81); - auto i88 = p.add_instruction(migraphx::op::identity{}, i55); - auto i89 = p.add_literal(2); - auto i90 = p.add_instruction(migraphx::op::identity{}, i89); - auto i91 = p.add_literal(1); - auto i92 = p.add_literal(2); - auto i94 = p.add_instruction(nary_op{"i94"}, i39, i92, i91, i90, i88); - auto i96 = p.add_instruction(migraphx::op::identity{}, i55, i94, i75, i61, i86); - auto i97 = p.add_literal(2); - auto i98 = p.add_instruction(migraphx::op::identity{}, i97); - auto i99 = p.add_literal(3); - auto i100 = p.add_literal(1); - auto i101 = p.add_literal(2); - auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99); - - t.run_pass(p); +TEST_CASE(inception1) +{ + scheduler t{}; + migraphx::module m; + + auto i1 = m.add_literal(0); + auto i2 = m.add_literal(1); + auto i3 = m.add_literal(1); + auto i4 = m.add_literal(2); + auto i7 = m.add_instruction(nary_op{"i7"}, i1, i4, i3, i2); + auto i8 = m.add_literal(2); + auto i9 = m.add_instruction(migraphx::make_op("identity"), i8); + auto i10 = m.add_literal(1); + auto i11 = m.add_instruction(nary_op{"i11"}, i7, i9, i10); + auto i12 = m.add_literal(2); + auto i13 = m.add_instruction(migraphx::make_op("identity"), i12); + auto i14 = m.add_literal(1); + auto i15 = m.add_literal(1); + auto i16 = m.add_literal(2); + auto i17 = m.add_instruction(nary_op{"i17"}, i11, i16, i15, i13, i14); + auto i18 = m.add_literal(2); + auto i19 = m.add_instruction(migraphx::make_op("identity"), i18); + auto i20 = m.add_literal(1); + auto i21 = m.add_literal(1); + auto i22 = m.add_literal(2); + auto i23 = m.add_instruction(nary_op{"i23"}, i17, i22, i21, i19, i20); + auto i24 = m.add_literal(1); + auto i25 = m.add_instruction(nary_op{"i25"}, i23, i24); + auto i26 = m.add_literal(2); + auto i27 = m.add_instruction(migraphx::make_op("identity"), i26); + auto i28 = m.add_literal(1); + auto i29 = m.add_literal(1); + auto i30 = m.add_literal(2); + auto i31 = m.add_instruction(nary_op{"i31"}, i25, i30, i29, i27, i28); + auto i32 = m.add_literal(2); + auto i33 = m.add_instruction(migraphx::make_op("identity"), i32); + auto i34 = m.add_literal(1); + auto i35 = m.add_literal(1); + auto i36 = m.add_literal(2); + auto i37 = m.add_instruction(nary_op{"i37"}, i31, i36, i35, i33, i34); + auto i38 = m.add_literal(1); + auto i39 = m.add_instruction(nary_op{"i39"}, i37, i38); + auto i41 = m.add_literal(2); + auto i42 = m.add_instruction(migraphx::make_op("identity"), i41); + auto i43 = m.add_literal(1); + auto i44 = m.add_literal(1); + auto i45 = m.add_literal(2); + auto i48 = m.add_instruction(nary_op{"i48"}, i39, i45, i44, i42, i43); + auto i49 = m.add_literal(2); + auto i50 = m.add_instruction(migraphx::make_op("identity"), i49); + auto i51 = m.add_literal(1); + auto i52 = m.add_literal(1); + auto i53 = m.add_literal(2); + auto i54 = m.add_instruction(nary_op{"i54"}, i48, i53, i52, i50, i51); + auto i55 = m.add_literal(1); + auto i56 = m.add_instruction(migraphx::make_op("identity"), i55); + auto i57 = m.add_literal(2); + auto i58 = m.add_instruction(migraphx::make_op("identity"), i57); + auto i59 = m.add_literal(1); + auto i60 = m.add_literal(2); + auto i61 = m.add_instruction(nary_op{"i61"}, i54, i60, i59, i58, i56); + auto i62 = m.add_literal(2); + auto i63 = m.add_instruction(migraphx::make_op("identity"), i62); + auto i64 = m.add_literal(1); + auto i65 = m.add_literal(1); + auto i66 = m.add_literal(2); + auto i69 = m.add_instruction(nary_op{"i69"}, i39, i66, i65, i63, i64); + auto i70 = m.add_instruction(migraphx::make_op("identity"), i55); + auto i71 = m.add_literal(2); + auto i72 = m.add_instruction(migraphx::make_op("identity"), i71); + auto i73 = m.add_literal(1); + auto i74 = m.add_literal(2); + auto i75 = m.add_instruction(nary_op{"i75"}, i69, i74, i73, i72, i70); + auto i77 = m.add_literal(1); + auto i80 = m.add_instruction(nary_op{"i80"}, i39, i77); + auto i81 = m.add_instruction(migraphx::make_op("identity"), i55); + auto i82 = m.add_literal(2); + auto i83 = m.add_instruction(migraphx::make_op("identity"), i82); + auto i84 = m.add_literal(1); + auto i85 = m.add_literal(2); + auto i86 = m.add_instruction(nary_op{"i86"}, i80, i85, i84, i83, i81); + auto i88 = m.add_instruction(migraphx::make_op("identity"), i55); + auto i89 = m.add_literal(2); + auto i90 = m.add_instruction(migraphx::make_op("identity"), i89); + auto i91 = m.add_literal(1); + auto i92 = m.add_literal(2); + auto i94 = m.add_instruction(nary_op{"i94"}, i39, i92, i91, i90, i88); + auto i96 = m.add_instruction(migraphx::make_op("identity"), i55, i94, i75, i61, i86); + auto i97 = m.add_literal(2); + auto i98 = m.add_instruction(migraphx::make_op("identity"), i97); + auto i99 = m.add_literal(3); + auto i100 = m.add_literal(1); + auto i101 = m.add_literal(2); + auto output = m.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99); + + t.run_pass(m); EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) == t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7})); @@ -874,7 +917,68 @@ TEST_CASE(inception1) get_wait_for(t.get_stream(output), {t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)})); - t.check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}}); + t.check_conflicts(m, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}}); +} + +TEST_CASE(if_pl_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape xs{migraphx::shape::float_type, {2, 3}}; + migraphx::shape ys{migraphx::shape::float_type, {3, 3}}; + std::vector datax = {1, 2, 3, 4, 5, 6}; + std::vector datay = {8, 7, 6, 5, 4, 3, 2, 1, 0}; + + auto lx = mm->add_literal(migraphx::literal(xs, datax)); + auto ly = mm->add_literal(migraphx::literal(ys, datay)); + auto cond = mm->add_parameter("cond", cond_s); + auto x = mm->add_parameter("x", xs); + auto y = mm->add_parameter("y", ys); + + auto* then_mod = p.create_module("If_5_if"); + auto l1 = then_mod->add_literal(migraphx::literal(ys, datay)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x, lx); + then_mod->add_return({a1, l1}); + + auto* else_mod = p.create_module("If_5_else"); + auto l2 = else_mod->add_literal(migraphx::literal(xs, datax)); + auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, ly); + else_mod->add_return({l2, a2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r2 = mm->add_return({ret}); + + scheduler t{}; + auto sub_modules = p.get_modules(); + std::reverse(sub_modules.begin(), sub_modules.end()); + + for(const auto& smod : sub_modules) + { + t.run_pass(*smod); + } + + EXPECT(t.has_stream(ret) == false); + EXPECT(t.has_stream(r2) == false); +} + +TEST_CASE(unused_param_test) +{ + migraphx::module mm; + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + + auto x = mm.add_parameter("x", s); + auto y = mm.add_parameter("y", s); + auto z = mm.add_parameter("z", s); + + auto r = mm.add_instruction(migraphx::make_op("add"), x, y); + mm.add_return({r}); + + scheduler t{}; + t.run_pass(mm); + + EXPECT(t.has_stream(z) == false); + EXPECT(t.has_stream(r) == false); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/serialize_program.cpp b/test/serialize_program.cpp new file mode 100644 index 0000000000000000000000000000000000000000..99bd1abaf6e46c97421a21e803965625252de1dc --- /dev/null +++ b/test/serialize_program.cpp @@ -0,0 +1,118 @@ +#include +#include +#include +#include "test.hpp" +#include + +#include + +migraphx::program create_program() +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", {migraphx::shape::int32_type}); + auto two = mm->add_literal(2); + auto add = mm->add_instruction(migraphx::make_op("add"), x, two); + mm->add_return({add}); + return p; +} + +TEST_CASE(as_value) +{ + migraphx::program p1 = create_program(); + migraphx::program p2; + p2.from_value(p1.to_value()); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(as_msgpack) +{ + migraphx::file_options options; + options.format = "msgpack"; + migraphx::program p1 = create_program(); + std::vector buffer = migraphx::save_buffer(p1, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(as_json) +{ + migraphx::file_options options; + options.format = "json"; + migraphx::program p1 = create_program(); + std::vector buffer = migraphx::save_buffer(p1, options); + migraphx::program p2 = migraphx::load_buffer(buffer, options); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(as_file) +{ + std::string filename = "migraphx_program.mxr"; + migraphx::program p1 = create_program(); + migraphx::save(p1, filename); + migraphx::program p2 = migraphx::load(filename); + std::remove(filename.c_str()); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(compiled) +{ + migraphx::program p1 = create_program(); + p1.compile(migraphx::ref::target{}); + std::vector buffer = migraphx::save_buffer(p1); + migraphx::program p2 = migraphx::load_buffer(buffer); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(unknown_format) +{ + migraphx::file_options options; + options.format = "???"; + + EXPECT(test::throws([&] { migraphx::save_buffer(create_program(), options); })); + EXPECT(test::throws([&] { migraphx::load_buffer(std::vector{}, options); })); +} + +TEST_CASE(program_with_module) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", sd); + + std::vector one(sd.elements(), 1); + std::vector two(sd.elements(), 2); + + auto* then_smod = p.create_module("then_smod"); + auto l1 = then_smod->add_literal(migraphx::literal{sd, one}); + auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1); + then_smod->add_return({r1}); + + auto* else_smod = p.create_module("else_smod"); + auto l2 = else_smod->add_literal(migraphx::literal{sd, two}); + auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2); + else_smod->add_return({r2}); + + migraphx::shape s_cond{migraphx::shape::bool_type, {1}}; + auto cond = mm->add_parameter("cond", s_cond); + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod}); + mm->add_return({ret}); + + migraphx::program p1 = p; + auto v = p.to_value(); + auto v1 = p1.to_value(); + EXPECT(v == v1); + + std::stringstream ss; + p.print_cpp(ss); + std::stringstream ss1; + p1.print_cpp(ss1); + EXPECT(ss.str() == ss1.str()); + + migraphx::program p2; + p2.from_value(v); + EXPECT(p1.sort() == p2.sort()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/serialize_test.cpp b/test/serialize_test.cpp new file mode 100755 index 0000000000000000000000000000000000000000..b199816f8bcf77b87ab259572cd7ca7d07add077 --- /dev/null +++ b/test/serialize_test.cpp @@ -0,0 +1,116 @@ +#include +#include +#include + +#include + +struct empty_type +{ +}; +struct reflectable_type +{ + enum simple_enum + { + simple1, + simple2, + simple3 + }; + enum class class_enum + { + class1, + class2, + class3 + }; + std::vector ints = {}; + std::string name = ""; + float fvalue = 0.0; + empty_type et{}; + simple_enum se = simple1; + class_enum ce = class_enum::class1; + + struct nested_type + { + int value; + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.value, "value")); + } + }; + std::vector nested_types = {}; + + template + static auto reflect(Self& self, F f) + { + return migraphx::pack(f(self.ints, "ints"), + f(self.name, "name"), + f(self.fvalue, "fvalue"), + f(self.et, "et"), + f(self.se, "se"), + f(self.ce, "ce"), + f(self.nested_types, "nested_types")); + } +}; + +TEST_CASE(serialize_reflectable_type) +{ + reflectable_type t1{{1, 2}, + "hello", + 1.0, + {}, + reflectable_type::simple1, + reflectable_type::class_enum::class2, + {{1}, {2}}}; + migraphx::value v1 = migraphx::to_value(t1); + reflectable_type t2 = migraphx::from_value(v1); + migraphx::value v2 = migraphx::to_value(t2); + migraphx::value v3 = migraphx::to_value(reflectable_type{}); + + EXPECT(v1 == v2); + EXPECT(v1 != v3); + EXPECT(v2 != v3); +} + +TEST_CASE(serialize_empty_array) +{ + std::vector ints = {}; + migraphx::value v = migraphx::to_value(ints); + EXPECT(v.is_array()); + EXPECT(v.empty()); + v.push_back(1); + EXPECT(v.size() == 1); + EXPECT(v.front().to() == 1); +} + +struct empty_struct +{ + template + static auto reflect(Self&, F) + { + return migraphx::pack(); + } +}; + +TEST_CASE(serialize_empty_struct) +{ + empty_struct es{}; + migraphx::value v = migraphx::to_value(es); + EXPECT(v.is_object()); + EXPECT(v.empty()); + v["a"] = 1; + EXPECT(v.size() == 1); + EXPECT(v.at("a").to() == 1); +} + +TEST_CASE(from_value_binary) +{ + std::vector data(10); + std::iota(data.begin(), data.end(), 0); + + migraphx::value v = migraphx::value::binary{data}; + + auto out = migraphx::from_value(v); + EXPECT(out == data); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/shape_test.cpp b/test/shape_test.cpp old mode 100644 new mode 100755 index b0cae0a62d56ec5eaf0f16489ba7563637518e96..c3f0d319b4d121e30f3e20085af3ec3e73b5a250 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1,5 +1,9 @@ #include +#include +#include +#include +#include #include #include #include @@ -47,6 +51,15 @@ TEST_CASE(test_shape_packed) EXPECT(not s.broadcasted()); } +TEST_CASE(test_shape_non_packed_single_dim) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}}; + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + TEST_CASE(test_shape_transposed1) { migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}}; @@ -92,6 +105,33 @@ TEST_CASE(test_shape_overlap3) EXPECT(not s.broadcasted()); } +TEST_CASE(test_shape_scalar1) +{ + migraphx::shape s{migraphx::shape::float_type}; + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_shape_scalar2) +{ + migraphx::shape s{migraphx::shape::float_type, {1}, {0}}; + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + +TEST_CASE(test_shape_scalar_broadcast) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 2, 3, 3}, {0, 0, 0, 0}}; + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.transposed()); + EXPECT(s.broadcasted()); +} + TEST_CASE(test_shape_broadcasted) { migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}}; @@ -145,6 +185,53 @@ TEST_CASE(test_shape_default_copy) EXPECT(!(s1 != s2)); } +TEST_CASE(test_shape_normalize_standard1) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}}; + EXPECT(s.standard()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + +TEST_CASE(test_shape_normalize_standard2) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}}; + EXPECT(s.standard()); + auto n = s.normalize_standard(); + EXPECT(n.standard()); + EXPECT(n != s); + EXPECT(n.lens() == s.lens()); + EXPECT(n.type() == s.type()); +} + +TEST_CASE(test_shape_normalize_standard3) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}}; + EXPECT(not s.standard()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + +TEST_CASE(test_shape_normalize_scalar1) +{ + migraphx::shape s{migraphx::shape::float_type}; + EXPECT(s.standard()); + EXPECT(s.scalar()); + auto n = s.normalize_standard(); + EXPECT(n != s); + EXPECT(n.standard()); + EXPECT(not n.scalar()); +} + +TEST_CASE(test_shape_normalize_scalar2) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 0}}; + EXPECT(not s.standard()); + EXPECT(s.scalar()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + TEST_CASE(test_shape4) { migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}}; @@ -287,4 +374,249 @@ TEST_CASE(test_shape4_nonpacked) EXPECT(s.index(s.elements() - 1) == 469273); } +TEST_CASE(test_serialize) +{ + migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}}; + auto v1 = migraphx::to_value(s1); + migraphx::shape s2{migraphx::shape::uint64_type, {2, 2}}; + auto v2 = migraphx::to_value(s2); + EXPECT(v1 != v2); + + auto s3 = migraphx::from_value(v1); + EXPECT(s3 == s1); + auto s4 = migraphx::from_value(v2); + EXPECT(s4 == s2); + EXPECT(s3 != s4); +} + +TEST_CASE(tuple) +{ + migraphx::shape s{{migraphx::shape{migraphx::shape::float_type}, + migraphx::shape{migraphx::shape::int8_type}}}; + EXPECT(s.type() == migraphx::shape::tuple_type); + EXPECT(s.bytes() == 4 + 1); + EXPECT(s.type_size() == 0); + EXPECT(s.type_string() == "tuple_type"); + EXPECT(s.lens().empty()); + EXPECT(s.strides().empty()); + EXPECT(not s.standard()); + EXPECT(not s.packed()); + EXPECT(not s.broadcasted()); + EXPECT(not s.transposed()); + EXPECT(not s.scalar()); + EXPECT(s.sub_shapes().size() == 2); + EXPECT(s.sub_shapes()[0].type() == migraphx::shape::float_type); + EXPECT(s.sub_shapes()[0].elements() == 1); + EXPECT(s.sub_shapes()[1].type() == migraphx::shape::int8_type); + EXPECT(s.sub_shapes()[1].elements() == 1); + EXPECT(test::throws([&] { s.visit_type([](auto) {}); })); +} + +TEST_CASE(tuple_copy) +{ + migraphx::shape s1{{migraphx::shape{migraphx::shape::float_type}, + migraphx::shape{migraphx::shape::int8_type}}}; + migraphx::shape s2{{migraphx::shape{migraphx::shape::float_type}, + migraphx::shape{migraphx::shape::int8_type}}}; + EXPECT(s1 == s2); + auto s3 = s1; + EXPECT(s3 == s1); + EXPECT(s3 == s2); + migraphx::shape s4{{migraphx::shape{migraphx::shape::int8_type}, + migraphx::shape{migraphx::shape::float_type}}}; + EXPECT(s4 != s1); + EXPECT(s4 != s2); + EXPECT(s4 != s3); +} + +TEST_CASE(tuple_print) +{ + migraphx::shape s{{migraphx::shape{migraphx::shape::float_type}, + migraphx::shape{migraphx::shape::int8_type}}}; + std::string x = migraphx::to_string(s); + EXPECT(x.front() == '['); + EXPECT(x.back() == ']'); + EXPECT(migraphx::contains(x, "float")); + EXPECT(migraphx::contains(x, "int8")); +} + +TEST_CASE(tuple_serialize) +{ + migraphx::shape s1{{migraphx::shape{migraphx::shape::float_type}, + migraphx::shape{migraphx::shape::int8_type}}}; + migraphx::shape s2{{migraphx::shape{migraphx::shape::int8_type}, + migraphx::shape{migraphx::shape::float_type}}}; + auto v1 = migraphx::to_value(s1); + auto v2 = migraphx::to_value(s2); + EXPECT(v1 != v2); + + auto s3 = migraphx::from_value(v1); + EXPECT(s3 == s1); + auto s4 = migraphx::from_value(v2); + EXPECT(s4 == s2); + EXPECT(s3 != s4); +} + +TEST_CASE(test_with_lens1) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {1, 2}}; + auto s2 = s1.with_lens({4, 3}); + EXPECT(s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {4, 3}, {1, 4}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens2) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}}; + auto s2 = s1.with_lens({3, 4}); + EXPECT(s2.standard()); + migraphx::shape s3{migraphx::shape::float_type, {3, 4}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous1) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 1, 24, 24}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(not s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous2) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 1}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous3) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 3, 1, 1}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(not s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous4) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 1, 1, 3}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous5) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 5, 24, 24}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(not s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous6) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 5}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous7) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous8) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 24, 24}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(not s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous9) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 1}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous10) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 2, 4, 1}}; + auto s2 = s1.with_lens({3, 2, 4, 1}); + EXPECT(not s2.transposed()); + migraphx::shape s3{migraphx::shape::float_type, {3, 2, 4, 1}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous11) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 1, 1, 1}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s1.standard()); + EXPECT(s2.standard()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous12) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 64, 1, 1}}; + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s1.standard()); + EXPECT(s2.standard()); + migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}}; + EXPECT(s2 == s3); +} + +TEST_CASE(test_with_lens_ambigous13) +{ + auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2}); + auto s2 = s1.with_lens({64, 3, 24, 24}); + EXPECT(s2.transposed()); + migraphx::shape s3 = + migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2}); + EXPECT(s2 == s3); +} + +TEST_CASE(cpp_type_name) +{ + EXPECT(migraphx::shape::cpp_type(migraphx::shape::int8_type) == "int8_t"); + EXPECT(migraphx::shape::cpp_type(migraphx::shape::float_type) == "float"); + EXPECT(migraphx::shape::cpp_type(migraphx::shape::half_type) == "half"); + EXPECT(test::throws([&] { migraphx::shape::cpp_type(migraphx::shape::tuple_type); })); +} + +TEST_CASE(test_with_type) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}}; + EXPECT(s.type() == migraphx::shape::float_type); + auto new_s = s.with_type(migraphx::shape::half_type); + EXPECT(s.type() == migraphx::shape::float_type); + EXPECT(s.type() != new_s.type()); + EXPECT(s.lens() == new_s.lens()); + EXPECT(s.strides() == new_s.strides()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 0d2e58462537e2c10076917bb19c731960cad2ca..f168baa20af53c7a64e0c2098e4a5aacb16c2bf9 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -6,96 +6,98 @@ #include #include #include +#include + #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}}); } TEST_CASE(simplify_add1) { - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one); - auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); - p2.add_instruction(pass_op{}, sum3); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_add2) { - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); - auto sum2 = p1.add_instruction(migraphx::op::add{}, two, y); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, x); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), two, y); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); - p2.add_instruction(pass_op{}, sum3); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_add3) { - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); - auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, x); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), one, two); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p2.add_instruction(migraphx::op::add{}, one, sum1); - auto sum3 = p2.add_instruction(migraphx::op::add{}, x, sum2); - p2.add_instruction(pass_op{}, sum3); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), one, sum1); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), x, sum2); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_add_broadcast1) @@ -103,34 +105,34 @@ TEST_CASE(simplify_add_broadcast1) migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::op::broadcast b{1, {1, 2, 3, 3}}; - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", outer); - auto y = p1.add_parameter("y", outer); - auto one = p1.add_literal({inner, {1, 1}}); - auto oneb = p1.add_instruction(b, one); - auto two = p1.add_literal({inner, {2, 2}}); - auto twob = p1.add_instruction(b, two); - auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); - auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); - p1.add_instruction(pass_op{}, sum3); + auto x = m1.add_parameter("x", outer); + auto y = m1.add_parameter("y", outer); + auto one = m1.add_literal({inner, {1, 1}}); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal({inner, {2, 2}}); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", outer); - auto y = p2.add_parameter("y", outer); - auto one = p2.add_literal({inner, {1, 1}}); - auto two = p2.add_literal({inner, {2, 2}}); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum1b = p2.add_instruction(b, sum1); - auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1b); - p2.add_instruction(pass_op{}, sum3); + auto x = m2.add_parameter("x", outer); + auto y = m2.add_parameter("y", outer); + auto one = m2.add_literal({inner, {1, 1}}); + auto two = m2.add_literal({inner, {2, 2}}); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum1b = m2.add_instruction(b, sum1); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1b); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_add_broadcast2) @@ -139,246 +141,2029 @@ TEST_CASE(simplify_add_broadcast2) migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto create_program = [&] { - migraphx::program p; - auto x = p.add_parameter("x", outer); - auto y = p.add_parameter("y", outer); - auto one = p.add_literal({inner, {1, 1}}); - auto oneb = p.add_instruction(b, one); - auto two = p.add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}); - auto sum1 = p.add_instruction(migraphx::op::add{}, x, y); - auto sum2 = p.add_instruction(migraphx::op::add{}, oneb, two); - auto sum3 = p.add_instruction(migraphx::op::add{}, sum2, sum1); - p.add_instruction(pass_op{}, sum3); - return p; + migraphx::module m; + auto x = m.add_parameter("x", outer); + auto y = m.add_parameter("y", outer); + auto one = m.add_literal({inner, {1, 1}}); + auto oneb = m.add_instruction(b, one); + auto two = m.add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}); + auto sum1 = m.add_instruction(migraphx::make_op("add"), x, y); + auto sum2 = m.add_instruction(migraphx::make_op("add"), oneb, two); + auto sum3 = m.add_instruction(migraphx::make_op("add"), sum2, sum1); + m.add_instruction(pass_op{}, sum3); + return m; }; - migraphx::program p1 = create_program(); - run_pass(p1); + migraphx::module m1 = create_program(); + run_pass(m1); - migraphx::program p2 = create_program(); - EXPECT(p1 == p2); + migraphx::module m2 = create_program(); + EXPECT(m1 == m2); } // TODO: Add test case // TEST_CASE(simplify_add4) void simplify_add4() { - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x); - auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, y); - auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two); - p1.add_instruction(pass_op{}, sum3); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), one, x); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), sum1, y); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum2, two); + m1.add_instruction(pass_op{}, sum3); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two); - auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y); - auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1); - p2.add_instruction(pass_op{}, sum3); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), one, two); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum2, sum1); + m2.add_instruction(pass_op{}, sum3); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_mul_conv1) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}})); - auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w); - auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); - auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a); - auto mul = p.add_instruction(migraphx::op::mul{}, conv, b); - p.add_instruction(pass_op{}, mul); + m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}})); + auto conv = m.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto a = m.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); + auto b = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), a); + auto mul = m.add_instruction(migraphx::make_op("mul"), conv, b); + m.add_instruction(pass_op{}, mul); EXPECT(conv->outputs().front()->name() == "mul"); - run_pass(p); + run_pass(m); auto new_conv = - std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }); + std::find_if(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }); EXPECT(new_conv->outputs().front()->name() != "mul"); } +TEST_CASE(simplify_mul_slice_conv1) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); + auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); + auto b = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a); + auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); + auto add = m1.add_instruction(migraphx::make_op("add"), mul, slice2); + m1.add_instruction(pass_op{}, add); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m2.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto wslice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w); + auto a = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); + auto b = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a); + auto mul = m2.add_instruction(migraphx::make_op("mul"), b, wslice1); + auto wslice2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2); + auto conv = m2.add_instruction(migraphx::make_op("convolution"), x, concat); + auto slice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); + auto slice2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); + auto add = m2.add_instruction(migraphx::make_op("add"), slice1, slice2); + m2.add_instruction(pass_op{}, add); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_mul_slice_conv_overlapping_slice) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); + auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); + auto b = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a); + auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {383}}, {"ends", {767}}}), conv); + auto add = m1.add_instruction(migraphx::make_op("add"), mul, slice2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_mul_slice_conv_not_all_slice) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); + auto a = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}})); + auto b = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a); + auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b); + auto c = m1.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}})); + auto add = m1.add_instruction(migraphx::make_op("add"), conv, c); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), mul, add); + m1.add_instruction(pass_op{}, concat); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(simplify_mul_add) { - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto one = p1.add_literal(1); - auto two = p1.add_literal(2); - auto sum = p1.add_instruction(migraphx::op::add{}, one, x); - auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two); - p1.add_instruction(pass_op{}, mul); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + auto sum = m1.add_instruction(migraphx::make_op("add"), one, x); + auto mul = m1.add_instruction(migraphx::make_op("mul"), sum, two); + m1.add_instruction(pass_op{}, mul); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto one = p2.add_literal(1); - auto two = p2.add_literal(2); - auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x); - auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one); - auto sum = p2.add_instruction(migraphx::op::add{}, mul1, mul2); - p2.add_instruction(pass_op{}, sum); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto mul1 = m2.add_instruction(migraphx::make_op("mul"), two, x); + auto mul2 = m2.add_instruction(migraphx::make_op("mul"), two, one); + auto sum = m2.add_instruction(migraphx::make_op("add"), mul1, mul2); + m2.add_instruction(pass_op{}, sum); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_inner_broadcast) { auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; - migraphx::program p1; + migraphx::module m1; { - auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto xb = p1.add_instruction(b, x); - auto yb = p1.add_instruction(b, y); - auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb); - p1.add_instruction(pass_op{}, sum); + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto xb = m1.add_instruction(b, x); + auto yb = m1.add_instruction(b, y); + auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb); + m1.add_instruction(pass_op{}, sum); } - run_pass(p1); + run_pass(m1); - migraphx::program p2; + migraphx::module m2; { - auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); - auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}}); - auto sum = p2.add_instruction(migraphx::op::add{}, x, y); - auto sumb = p2.add_instruction(b, sum); - p2.add_instruction(pass_op{}, sumb); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1}}); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + auto sumb = m2.add_instruction(b, sum); + m2.add_instruction(pass_op{}, sumb); } - EXPECT(p1 == p2); + EXPECT(m1 == m2); } TEST_CASE(simplify_add_conv1) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}})); + auto conv1 = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = m.add_instruction(migraphx::make_op("convolution"), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); } TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}})); + auto conv1 = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {3, 3}}}), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); // No fusion EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); } TEST_CASE(simplify_add_conv_1x1_diff_strides1) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); } TEST_CASE(simplify_add_conv_1x1_diff_strides2) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}}); + auto w = + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}}); + auto v = + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), x, w); + auto conv2 = m.add_instruction(migraphx::make_op("convolution"), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); + EXPECT(std::count_if( + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); +} + +TEST_CASE(simplify_add_conv_1x1_diff_strides_odd) +{ + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 54, 83, 83}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 54, 165, 165}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}})); + auto conv1 = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); } TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), x, w); + auto conv2 = m.add_instruction(migraphx::make_op("convolution"), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); // No fusion EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); } TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2) { - migraphx::program p; - auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}}); auto w = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}}); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto y = m.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}}); auto v = - p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); - auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w); - auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v); - auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); - p.add_instruction(pass_op{}, sum); - auto s = p.get_shape(); - run_pass(p); - EXPECT(s == p.get_shape()); + m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = m.add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 1}}}), y, v); + auto sum = m.add_instruction(migraphx::make_op("add"), conv1, conv2); + m.add_instruction(pass_op{}, sum); + auto s = m.get_output_shapes().back(); + run_pass(m); + EXPECT(s == m.get_output_shapes().back()); // No fusion EXPECT(std::count_if( - p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); + m.begin(), m.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); +} + +TEST_CASE(simplify_concat_add_relu) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto one = m1.add_literal({s, {1}}); + auto two = m1.add_literal({s, {2}}); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto one = m2.add_literal({s, {1}}); + auto two = m2.add_literal({s, {2}}); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto sum = m2.add_instruction(migraphx::make_op("add"), concat1, concat2); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + m2.add_instruction(pass_op{}, relu); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_concat_add_relu_partial) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto one = m1.add_literal({s, {1}}); + auto two = m1.add_literal({s, {2}}); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, one); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, two); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), x, y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum3, relu1, relu2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto one = m2.add_literal({s, {1}}); + auto two = m2.add_literal({s, {2}}); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), concat1, concat2); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), sum2, relu); + m2.add_instruction(pass_op{}, concat); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_concat_add_relu_partial_broadcast) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, oneb, twob); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat1); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), sum, concatb); + m2.add_instruction(pass_op{}, concat2); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concat2b = m2.add_instruction(b, concat2); + auto sum = m2.add_instruction(migraphx::make_op("add"), concat1, concat2b); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + m2.add_instruction(pass_op{}, relu); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), relu1, relu2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto one = m2.add_literal(1); + auto oneb = m2.add_instruction(b, one); + auto two = m2.add_literal(2); + auto twob = m2.add_instruction(b, two); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), oneb, twob); + auto sum = m2.add_instruction(migraphx::make_op("add"), concat1, concat2); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + m2.add_instruction(pass_op{}, relu); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_div_const) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto two = m1.add_literal(2); + m1.add_instruction(migraphx::make_op("div"), x, two); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto two = m2.add_literal(2); + auto recip = m2.insert_instruction(std::next(two), migraphx::make_op("recip"), two); + m2.add_instruction(migraphx::make_op("mul"), x, recip); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_sub_const) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto two = m1.add_literal(2); + m1.add_instruction(migraphx::make_op("sub"), x, two); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto two = m2.add_literal(2); + auto neg = m2.insert_instruction(std::next(two), migraphx::make_op("neg"), two); + m2.add_instruction(migraphx::make_op("add"), x, neg); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_rsqrt) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), x); + m1.add_instruction(migraphx::make_op("recip"), sqrt); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1}}); + m2.add_instruction(migraphx::make_op("rsqrt"), x); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_rsqrt_multi_use) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); + auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), x); + auto add = m1.add_instruction(migraphx::make_op("add"), sqrt, sqrt); + auto rsqrt = m1.add_instruction(migraphx::make_op("recip"), sqrt); + m1.add_instruction(migraphx::make_op("add"), rsqrt, add); + } + migraphx::module m2{m1}; + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_slice_concat) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {256}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), x); + auto xslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), x); + auto yslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), y); + auto yslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), y); + auto concat = m1.add_instruction( + migraphx::make_op("concat", {{"axis", 0}}), xslice1, xslice2, yslice1, yslice2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + m2.add_instruction(pass_op{}, concat); + } + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_slice_concat_non_uniform) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {256}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x); + auto xslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x); + auto xslice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x); + auto yslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y); + auto yslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y); + auto yslice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + xslice1, + xslice2, + xslice3, + yslice1, + yslice2, + yslice3); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); + m2.add_instruction(pass_op{}, concat); + } + + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_slice_concat_flipped) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {256}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), x); + auto xslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), x); + auto xslice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), x); + auto yslice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {64}}}), y); + auto yslice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {192}}, {"ends", {256}}}), y); + auto yslice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {192}}}), y); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), + xslice1, + xslice2, + xslice3, + yslice1, + yslice2, + yslice3); + m1.add_instruction(pass_op{}, concat); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1 == m2); +} + +TEST_CASE(simplify_split_add_relu) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + m1.add_instruction(pass_op{}, add); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = m2.add_parameter("input", s); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat); + auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_instruction(pass_op{}, add); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_reduce0) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + + auto one = m1.add_literal(1); + auto two = m1.add_literal(2); + + auto arx = m1.add_instruction(migraphx::make_op("contiguous"), x); + auto ary = m1.add_instruction(migraphx::make_op("contiguous"), y); + auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), x); + auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x); + auto rmax1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), arx, one); + auto rmin1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), ary, two); + auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), y); + auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y); + m1.add_return({rmax0, rmin0, rmax1, rmin1, rmax2, rmin2}); + } + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_reduce1) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + + auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), x); + auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), x); + auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), y); + auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), y); + m1.add_return({rmax0, rmin0, rmax2, rmin2}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + + auto rmn = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), input); + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmn); + auto rmx = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmx); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmn); + auto slc3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmx); + m2.add_return({slc3, slc2, slc1, slc0}); + } + + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_reduce2) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), x); + auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x); + auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), y); + auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y); + m1.add_return({rmax0, rmin0, rmax2, rmin2}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto rmn1 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto rmn2 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y); + auto rms = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input); + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rms); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rms); + m2.add_return({slc1, rmn2, slc0, rmn1}); + } + + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_add_relu_reshape) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto r = migraphx::op::reshape{{3, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto reshape1 = m1.add_instruction(r, relu1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto reshape2 = m1.add_instruction(r, relu2); + auto add = m1.add_instruction(migraphx::make_op("add"), reshape1, reshape2); + m1.add_instruction(pass_op{}, add); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = m2.add_parameter("input", s); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat); + auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp); + auto add = m2.add_instruction(migraphx::make_op("add"), slc1, slc2); + m2.add_instruction(pass_op{}, add); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_slice_different_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; + migraphx::module m1; + { + auto r = migraphx::op::reshape{{3, 2, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 2}}}), one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {3, 2, 4, 1}}}), two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto reshape1 = m1.add_instruction(r, relu1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto reshape2 = m1.add_instruction(r, relu2); + auto add = m1.add_instruction(migraphx::make_op("add"), reshape1, reshape2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_slice_missing_begining_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_slice_missing_middle_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_slice_missing_end_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_add_relu_concat_same_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), relu1, relu2); + m1.add_instruction(pass_op{}, concat); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = m2.add_parameter("input", s); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat); + auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + m2.add_instruction(pass_op{}, relu); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_add_relu_multi_axes) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}), + input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {1, 3}}, {"ends", {2, 6}}}), + input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + m1.add_instruction(pass_op{}, add); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_add_relu_used_multiple_split1) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add1 = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + auto add2 = m1.add_instruction(migraphx::make_op("add"), x, add1); + m1.add_instruction(pass_op{}, add2); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = m2.add_parameter("input", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat); + auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto add1 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto add2 = m2.add_instruction(migraphx::make_op("add"), slice, add1); + m2.add_instruction(pass_op{}, add2); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_add_relu_used_multiple_split2) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto z = m1.add_instruction(migraphx::make_op("relu"), x); + auto one = m1.add_literal(1); + auto oneb = m1.add_instruction(b, one); + auto two = m1.add_literal(2); + auto twob = m1.add_instruction(b, two); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb); + auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), y, twob); + auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2); + auto add1 = m1.add_instruction(migraphx::make_op("add"), relu1, relu2); + auto add2 = m1.add_instruction(migraphx::make_op("add"), z, add1); + m1.add_instruction(pass_op{}, add2); + } + run_pass(m1); + + migraphx::module m2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = m2.add_parameter("input", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto z = m2.add_instruction(migraphx::make_op("relu"), slice); + auto one = m2.add_literal(1); + auto two = m2.add_literal(2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two); + auto concatb = m2.add_instruction(b, concat); + auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); + auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto add1 = m2.add_instruction(migraphx::make_op("add"), x, y); + auto add2 = m2.add_instruction(migraphx::make_op("add"), z, add1); + m2.add_instruction(pass_op{}, add2); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_split_between_add) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto x = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto y = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_instruction(pass_op{}, sum); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_dot_horiz) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(s, 0)); + auto b = m1.add_literal(migraphx::generate_literal(s, 1)); + auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); + auto y = m1.add_instruction(migraphx::make_op("dot"), input, b); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto b = m2.add_literal(migraphx::generate_literal(s, 1)); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_instruction(pass_op{}, sum); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_dot_horiz_same_constant) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(s, 0)); + auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); + auto y = m1.add_instruction(migraphx::make_op("dot"), input, a); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(s, 0)); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, a); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_instruction(pass_op{}, sum); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_dot_horiz_flipped) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(s, 0)); + auto b = m1.add_literal(migraphx::generate_literal(s, 1)); + auto x = m1.add_instruction(migraphx::make_op("dot"), input, a); + auto y = m1.add_instruction(migraphx::make_op("dot"), b, input); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_instruction(pass_op{}, sum); + } + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_conv_horiz) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}}; + auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(ws, 0)); + auto b = m1.add_literal(migraphx::generate_literal(ws, 1)); + auto x = m1.add_instruction(migraphx::make_op("convolution"), input, a); + auto y = m1.add_instruction(migraphx::make_op("convolution"), input, b); + auto sum = m1.add_instruction(migraphx::make_op("add"), x, y); + m1.add_instruction(pass_op{}, sum); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(ws, 0)); + auto b = m2.add_literal(migraphx::generate_literal(ws, 1)); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b); + auto conv = m2.add_instruction(migraphx::make_op("convolution"), input, concat); + auto x = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv); + auto y = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv); + auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_instruction(pass_op{}, sum); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_group_conv_horiz) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}}; + auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto w1 = m1.add_literal(migraphx::generate_literal(ws, 1)); + auto w2 = m1.add_literal(migraphx::generate_literal(ws, 2)); + auto conv1 = m1.add_instruction( + migraphx::make_op( + "convolution", + {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}), + x, + w1); + auto conv2 = m1.add_instruction( + migraphx::make_op( + "convolution", + {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}, {"group", 32}}), + x, + w2); + m1.add_instruction(pass_op{}, conv1, conv2); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_conv_horiz_grouped) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m1.add_literal(migraphx::generate_literal(ws2, 3)); + auto convx = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a); + auto convy = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b); + auto dotx = m1.add_instruction(migraphx::make_op("dot"), input, c); + auto doty = m1.add_instruction(migraphx::make_op("dot"), input, d); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), convx, convy); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + + m1.add_instruction(pass_op{}, sum3); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m2.add_literal(migraphx::generate_literal(ws2, 3)); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d); + auto conv = m2.add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1); + auto convx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); + auto convy = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); + auto dotx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); + auto doty = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); + m2.add_instruction(pass_op{}, sum3); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_conv_horiz_grouped_extra1) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m1.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = m1.add_literal(migraphx::generate_literal(s, 4)); + auto convx = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a); + auto convy = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b); + auto dotx = m1.add_instruction(migraphx::make_op("dot"), input, c); + auto doty = m1.add_instruction(migraphx::make_op("dot"), input, d); + auto sqdiffx = m1.add_instruction(migraphx::make_op("sqdiff"), input, e); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), convx, convy); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sum3 = sqdiffx; + auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum4, sum3); + m1.add_instruction(pass_op{}, sum5); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m2.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = m2.add_literal(migraphx::generate_literal(s, 4)); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d); + auto conv = m2.add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1); + auto convx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); + auto convy = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); + auto dotx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); + auto doty = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e); + auto sum3 = sqdiffx; + auto sum4 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); + auto sum5 = m2.add_instruction(migraphx::make_op("add"), sum4, sum3); + m2.add_instruction(pass_op{}, sum5); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_conv_horiz_grouped_extra2) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto a = m1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m1.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = m1.add_literal(migraphx::generate_literal(s, 4)); + auto f = m1.add_literal(migraphx::generate_literal(s, 5)); + auto convx = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, a); + auto convy = + m1.add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, b); + auto dotx = m1.add_instruction(migraphx::make_op("dot"), input, c); + auto doty = m1.add_instruction(migraphx::make_op("dot"), input, d); + auto sqdiffx = m1.add_instruction(migraphx::make_op("sqdiff"), input, e); + auto sqdiffy = m1.add_instruction(migraphx::make_op("sqdiff"), input, f); + auto sum1 = m1.add_instruction(migraphx::make_op("add"), convx, convy); + auto sum2 = m1.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sum3 = m1.add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy); + auto sum4 = m1.add_instruction(migraphx::make_op("add"), sum1, sum2); + auto sum5 = m1.add_instruction(migraphx::make_op("add"), sum4, sum3); + m1.add_instruction(pass_op{}, sum5); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto a = m2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = m2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = m2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = m2.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = m2.add_literal(migraphx::generate_literal(s, 4)); + auto f = m2.add_literal(migraphx::generate_literal(s, 5)); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), a, b); + auto concat2 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), c, d); + auto conv = m2.add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1}}}), input, concat1); + auto convx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); + auto convy = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); + auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); + auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); + auto dotx = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); + auto doty = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); + auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); + auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e); + auto sqdiffy = m2.add_instruction(migraphx::make_op("sqdiff"), input, f); + auto sum3 = m2.add_instruction(migraphx::make_op("add"), sqdiffx, sqdiffy); + auto sum4 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); + auto sum5 = m2.add_instruction(migraphx::make_op("add"), sum4, sum3); + m2.add_instruction(pass_op{}, sum5); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(simplify_mul_slice_conv_horiz_fusion) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), conv); + auto a1 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); + auto b1 = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a1); + auto mul = m1.add_instruction(migraphx::make_op("mul"), slice1, b1); + auto a2 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2)); + auto b2 = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a2); + auto add1 = m1.add_instruction(migraphx::make_op("add"), mul, b2); + auto a3 = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); + auto b3 = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 384, 17, 17}}}), a3); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), conv); + auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, b3); + m1.add_instruction(pass_op{}, add1, add2); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}}); + auto w = m2.add_literal( + migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}})); + auto wslice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {384}}}), w); + auto a1 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1)); + auto b1 = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {384, 1024, 1, 1}}}), a1); + auto mul = m2.add_instruction(migraphx::make_op("mul"), b1, wslice1); + auto wslice2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {384}}, {"ends", {768}}}), w); + auto concat1 = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), mul, wslice2); + auto conv = m2.add_instruction(migraphx::make_op("convolution"), x, concat1); + auto a2 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2)); + auto a3 = + m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3)); + auto concat2 = m2.add_instruction(migraphx::make_op("concat"), a2, a3); + auto b4 = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 768, 17, 17}}}), concat2); + auto add = m2.add_instruction(migraphx::make_op("add"), conv, b4); + auto slice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {384}}}), add); + auto slice2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {384}}, {"ends", {768}}}), add); + m2.add_instruction(pass_op{}, slice1, slice2); + } + EXPECT(m1.sort() == m2.sort()); +} +TEST_CASE(reorder_reshape_slice) +{ + std::vector perm0 = {0, 2, 1, 3}; + std::vector perm1 = {0, 2, 3, 1}; + auto create_m1 = [&](std::size_t batch_size) { + migraphx::module m1; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}), + input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), + input); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {static_cast(batch_size), 128, 10, 64}; + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1); + auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); + m1.add_return({ret}); + + return m1; + }; + + auto create_m2 = [&](std::size_t batch_size) { + migraphx::module m2; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; + auto input = m2.add_parameter("input", s); + std::vector lens = {static_cast(batch_size), 128, 30, 64}; + auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); + + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {10}}, {"ends", {20}}}), r); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {20}}, {"ends", {30}}}), r); + + auto t0 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0); + auto t1 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1); + auto t2 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2); + + auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2); + m2.add_return({ret}); + + return m2; + }; + + auto test = [&](std::size_t batch_size) { + auto m1 = create_m1(batch_size); + run_pass(m1); + auto m2 = create_m2(batch_size); + EXPECT(m1.sort() == m2.sort()); + }; + + test(1); + test(4); + test(8); +} + +TEST_CASE(reorder_reshape_slice_move_axis1) +{ + auto create_m1 = [](std::size_t batch_size) { + migraphx::module m1; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}}; + std::vector perm0 = {0, 2, 1, 3}; + std::vector perm1 = {0, 2, 3, 1}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {static_cast(batch_size), 64, 4, 32}; + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r0); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), r1); + auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), r2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); + m1.add_return({ret}); + + return m1; + }; + + auto create_m2 = [](std::size_t batch_size) { + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}}; + std::vector perm0 = {0, 2, 1, 3}; + std::vector perm1 = {0, 2, 3, 1}; + auto input = m.add_parameter("input", s); + std::vector lens = {static_cast(batch_size), 64, 4, 96}; + auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); + auto slc0 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); + auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0); + auto slc1 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); + auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1); + auto slc2 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); + auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2); + + auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2); + m.add_return({ret}); + + return m; + }; + + auto test = [&](std::size_t batch_size) { + auto m1 = create_m1(batch_size); + auto m2 = create_m2(batch_size); + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); + }; + + test(4); + test(8); +} + +TEST_CASE(reorder_reshape_slice_move_axis2) +{ + auto create_m1 = [] { + migraphx::module m1; + migraphx::shape s{migraphx::shape::float_type, {128, 96}}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {1, 16, 8, 32}; + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1); + auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2); + m1.add_return({ret}); + + return m1; + }; + + auto create_m2 = [] { + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}}; + auto input = m.add_parameter("input", s); + std::vector lens = {1, 16, 8, 96}; + auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); + auto slc0 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); + auto slc1 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); + auto slc2 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); + + auto sum = m.add_instruction(migraphx::make_op("add"), slc0, slc1); + auto ret = m.add_instruction(migraphx::make_op("mul"), sum, slc2); + m.add_return({ret}); + + return m; + }; + + auto m1 = create_m1(); + auto m2 = create_m2(); + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reorder_reshape_slice_not_apply) +{ + auto create_p = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {128, 96}}; + auto input = m.add_parameter("input", s); + auto slc0 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), input); + auto slc1 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {64}}}), input); + auto slc2 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {96}}}), input); + + auto c0 = m.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {1, 16, 16, 16}; + auto r0 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto sum = m.add_instruction(migraphx::make_op("add"), r0, r1); + auto ret = m.add_instruction(migraphx::make_op("mul"), sum, r2); + m.add_return({ret}); + + return m; + }; + + auto m1 = create_p(); + auto m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reorder_reshape_slice_diff_dims) +{ + auto create_m1 = [](std::size_t batch_size) { + migraphx::module m1; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}}; + std::vector perm0 = {0, 2, 1, 3}; + std::vector perm1 = {0, 2, 3, 1}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {96}}}), input); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {static_cast(batch_size), 32, 3, 32}; + std::vector lens1 = {static_cast(batch_size), 48, 2, 32}; + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2); + + m1.add_return({r0, r1, r2}); + + return m1; + }; + + auto test = [&](std::size_t batch_size) { + auto m1 = create_m1(batch_size); + auto m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); + }; + + test(4); + test(8); +} + +TEST_CASE(reorder_slice_trans) +{ + std::vector perm = {0, 2, 1}; + auto create_m1 = [&](std::size_t batch_size) { + migraphx::module m1; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}), + input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), + input); + + auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0); + auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1); + auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2); + m1.add_return({ret}); + + return m1; + }; + + auto create_m2 = [&](std::size_t batch_size) { + migraphx::module m2; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; + auto input = m2.add_parameter("input", s); + auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input); + + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {640}}}), r); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {640}}, {"ends", {1280}}}), r); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1280}}, {"ends", {1920}}}), r); + + auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1); + auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2); + m2.add_return({ret}); + + return m2; + }; + + auto test = [&](std::size_t batch_size) { + auto m1 = create_m1(batch_size); + run_pass(m1); + auto m2 = create_m2(batch_size); + EXPECT(m1.sort() == m2.sort()); + }; + + test(1); + test(8); +} + +TEST_CASE(reorder_slice_trans_diff_perm) +{ + auto create_m1 = [](std::size_t batch_size) { + migraphx::module m1; + auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; + std::vector perm0 = {0, 2, 1}; + std::vector perm1 = {0, 1, 2}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}), + input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}), + input); + + auto t0 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0); + auto t1 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1); + auto t2 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); + auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); + m1.add_return({ret}); + + return m1; + }; + + auto test = [&](std::size_t batch_size) { + auto m1 = create_m1(batch_size); + run_pass(m1); + auto m2 = m1; + EXPECT(m1.sort() == m2.sort()); + }; + + test(1); + test(4); +} + +TEST_CASE(reorder_slice_ins_deps) +{ + auto create_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {4, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2}}; + std::vector datax = {0, 1, 2, 3, 4, 5, 6, 7}; + std::vector datay = {0, 1, 2, 3}; + auto inx = m.add_literal(migraphx::literal(sx, datax)); + auto iny = m.add_literal(migraphx::literal(sy, datay)); + auto slc0 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), inx); + auto slc1 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), inx); + auto n0 = m.add_instruction(migraphx::make_op("neg"), slc0); + auto a0 = m.add_instruction(migraphx::make_op("add"), n0, slc1); + auto m0 = m.add_instruction(migraphx::make_op("mul"), a0, iny); + auto r = m.add_instruction(migraphx::make_op("add"), m0, slc0); + m.add_return({r}); + + return m; + }; + + auto m = create_module(); + run_pass(m); + EXPECT(m == create_module()); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fca2678674f89543317d1bd1f437d6bb60de592c --- /dev/null +++ b/test/simplify_qdq_test.cpp @@ -0,0 +1,729 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } +bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } + +void run_pass(migraphx::module& m) +{ + migraphx::simplify_qdq sqdq; + sqdq.apply(m); +} + +migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + migraphx::instruction_ref shift) +{ + auto lens = x->get_shape().lens(); + migraphx::instruction_ref scale_mb; + if(scale->get_shape().lens().front() == 1) + scale_mb = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale); + else + scale_mb = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale); + auto shift_mb = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift); + return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); +} + +migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale) +{ + auto lens = x->get_shape().lens(); + migraphx::instruction_ref scale_mb; + if(scale->get_shape().lens().front() == 1) + scale_mb = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale); + else + scale_mb = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale); + return m.add_instruction(migraphx::make_op(name), x, scale_mb); +} + +TEST_CASE(remove_qdq) +{ + migraphx::shape sh1{migraphx::shape::float_type, {100, 100}}; + migraphx::shape sh2{migraphx::shape::float_type, {100, 100}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto add = m1.add_instruction(migraphx::make_op("add"), d1, d2); + m1.add_return({add}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + + auto add = m2.add_instruction(migraphx::make_op("add"), t1, t2); + m2.add_return({add}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(qdq_different_scales) +{ + migraphx::shape sh1{migraphx::shape::float_type, {100, 100}}; + migraphx::shape sh2{migraphx::shape::float_type, {100, 100}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale1 = m1.add_literal(0.5f); + auto scale2 = m1.add_literal(0.4f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale2, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale1, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero); + auto add = m1.add_instruction(migraphx::make_op("add"), d1, d2); + m1.add_return({add}); + } + + migraphx::module m2 = m1; + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + auto scale1 = m2.add_literal(0.25f); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); + m2.add_return({d3}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_non_zero_point) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{1}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2); + m2.add_return({dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_uint8) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::uint8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2); + m2.add_return({dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_add) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + migraphx::shape sh3{migraphx::shape::float_type, {1280, 1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto ab = m1.add_parameter("ab", sh3); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); + auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); + auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab); + m1.add_return({add}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto ab = m2.add_parameter("ab", sh3); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + auto scale1 = m2.add_literal(0.25f); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); + auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); + m2.add_return({add}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(conv) +{ + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + + migraphx::module m1; + { + auto input = m1.add_parameter("input", s7); + auto weights = m1.add_parameter("weights", s4); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); + auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); + auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto c1 = m1.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + m1.add_return({c1}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s7); + auto weights = m2.add_parameter("weights", s4); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + auto scale1 = m2.add_literal(0.25f); + + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + q1, + weights); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); + m2.add_return({d6}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(conv_multi_scale) +{ + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + migraphx::shape s8{migraphx::shape::float_type, {320}}; + + migraphx::module m1; + { + auto input = m1.add_parameter("input", s7); + auto weights = m1.add_parameter("weights", s4); + auto scale = m1.add_literal(migraphx::generate_literal(s8, 0)); + auto zero = m1.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); + auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); + auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto c1 = m1.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + m1.add_return({c1}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s7); + auto weights = m2.add_parameter("weights", s4); + auto scale = m2.add_literal(migraphx::generate_literal(s8, 0)); + auto zero = m2.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + input, + d1); + m2.add_return({c1}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(conv_bias_add) +{ + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s6{migraphx::shape::int32_type, {1280}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + + migraphx::module m1; + { + auto input = m1.add_parameter("input", s7); + auto weights = m1.add_parameter("weights", s4); + auto bias = m1.add_parameter("bias", s6); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero); + auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); + auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto c1 = m1.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + auto b1 = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); + auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, b1); + m1.add_return({a1}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s7); + auto weights = m2.add_parameter("weights", s4); + auto bias = m2.add_parameter("bias", s6); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + auto scale1 = m2.add_literal(0.25f); + + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero); + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + q1, + weights); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); + auto b1 = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); + auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); + m2.add_return({a1}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(conv_pooling_dot) +{ + migraphx::shape s2{migraphx::shape::int8_type, {1280, 1000}}; + migraphx::shape s3{migraphx::shape::int8_type, {1000}}; + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s6{migraphx::shape::int32_type, {1280}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + + migraphx::module m1; + { + auto db = m1.add_parameter("db", s2); // dot input b + auto ab = m1.add_parameter("ab", s3); // add input b + auto weights = m1.add_parameter("weights", s4); + auto bias = m1.add_parameter("bias", s6); + auto input = m1.add_parameter("input", s7); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m1, "dequantizelinear", weights, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", bias, scale, zero); + auto d3 = add_quantize_op(m1, "dequantizelinear", ab, scale, zero); + auto d4 = add_quantize_op(m1, "dequantizelinear", db, scale, zero); + auto q1 = add_quantize_op(m1, "quantizelinear", input, scale, zero); + auto d5 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto c1 = m1.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + auto bc1 = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); + auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); + auto ap = + m1.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {7, 7}}, + {"ceil_mode", 0}}), + a1); + auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); + auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); + auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4); + auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); + auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero); + auto mb1 = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); + auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1); + m1.add_return({a2}); + } + + migraphx::module m2; + { + auto db = m2.add_parameter("db", s2); // dot input b + auto ab = m2.add_parameter("ab", s3); // add input b + auto weights = m2.add_parameter("weights", s4); + auto bias = m2.add_parameter("bias", s6); + auto input = m2.add_parameter("input", s7); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + auto scale1 = m2.add_literal(0.25f); + auto scale2 = m2.add_literal(0.25f); + + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero); + auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + q1, + weights); + auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); + auto bc1 = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); + auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); + auto ap = + m2.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {7, 7}}, + {"ceil_mode", 0}}), + a1); + auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); + auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); + auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2); + auto mb1 = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); + auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); + m2.add_return({a2}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(mobilenet_snippet) +{ + migraphx::shape s2{migraphx::shape::int8_type, {1280, 1000}}; + migraphx::shape s3{migraphx::shape::int8_type, {1000}}; + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s6{migraphx::shape::int32_type, {1280}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + + auto create_module = [&]() { + migraphx::module mm; + auto db = mm.add_parameter("db", s2); // dot input b + auto ab = mm.add_parameter("ab", s3); // add input b + auto weights = mm.add_parameter("weights", s4); + auto bias = mm.add_parameter("bias", s6); + auto input = mm.add_parameter("input", s7); + auto scale = mm.add_literal(0.5f); + auto zero = mm.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(mm, "dequantizelinear", weights, scale, zero); + auto d2 = add_quantize_op(mm, "dequantizelinear", bias, scale, zero); + auto d3 = add_quantize_op(mm, "dequantizelinear", ab, scale, zero); + auto d4 = add_quantize_op(mm, "dequantizelinear", db, scale, zero); + auto q1 = add_quantize_op(mm, "quantizelinear", input, scale, zero); + auto d5 = add_quantize_op(mm, "dequantizelinear", q1, scale, zero); + auto c1 = mm.add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + auto bc1 = mm.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); + auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); + auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); + auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); + auto ap = + mm.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {7, 7}}, + {"ceil_mode", 0}}), + d6); + auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero); + auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero); + auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); + auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero); + auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero); + auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4); + auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); + auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero); + auto mb1 = + mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); + auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1); + mm.add_return({a2}); + + return mm; + }; + + auto mod1 = create_module(); + auto mod2 = create_module(); + + run_pass(mod2); + + auto match_qdq = migraphx::match::name("dequantizelinear")( + migraphx::match::arg(0)(migraphx::match::name("quantizelinear"))); + auto ins1 = migraphx::match::find_match(mod1, match_qdq); + auto ins2 = migraphx::match::find_match(mod2, match_qdq); + + EXPECT((ins1.result != mod1.end()) and (ins2.result == mod2.end())); + EXPECT(any_of(mod1, &is_convolution)); + EXPECT(none_of(mod2, &is_convolution)); + EXPECT(any_of(mod1, &is_dot)); + EXPECT(none_of(mod2, &is_dot)); +} + +TEST_CASE(conv_correctness) +{ + migraphx::shape si{migraphx::shape::float_type, {2, 3, 4, 4}}; + migraphx::shape sw{migraphx::shape::int8_type, {2, 3, 3, 3}}; + + migraphx::program p1; + { + auto* m1 = p1.get_main_module(); + auto input = m1->add_parameter("input", si); + auto weights = m1->add_parameter("weights", sw); + auto scale_i = m1->add_literal(0.5f); + auto scale_w = m1->add_literal(0.1f); + auto zero = m1->add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(*m1, "dequantizelinear", weights, scale_w, zero); + auto q1 = add_quantize_op(*m1, "quantizelinear", input, scale_i, zero); + auto d5 = add_quantize_op(*m1, "dequantizelinear", q1, scale_i, zero); + auto c1 = m1->add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + d5, + d1); + m1->add_return({c1}); + run_pass(*m1); + } + + migraphx::program p2; + { + auto* m2 = p2.get_main_module(); + auto input = m2->add_parameter("input", si); + auto weights = m2->add_parameter("weights", sw); + auto scale = m2->add_literal(0.1f); + auto zero = m2->add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(*m2, "dequantizelinear", weights, scale, zero); + auto c1 = m2->add_instruction(migraphx::make_op("convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + input, + d1); + m2->add_return({c1}); + } + + std::vector iv(si.elements(), 4); + auto input = migraphx::argument(si, iv.data()); + std::vector wv(sw.elements(), 10); + auto weights = migraphx::argument(sw, wv.data()); + p1.compile(migraphx::target(migraphx::ref::target{})); + p2.compile(migraphx::target(migraphx::ref::target{})); + + auto result1 = p1.eval({{"input", input}, {"weights", weights}}).back(); + std::vector rv1(16); + result1.visit([&](auto output) { rv1.assign(output.begin(), output.end()); }); + auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back(); + std::vector rv2(16); + result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(rv1, rv2)); +} + +TEST_CASE(dot_correctness) +{ + migraphx::shape sh1{migraphx::shape::float_type, {10, 4}}; + migraphx::shape sh2{migraphx::shape::float_type, {4, 12}}; + migraphx::shape sh3{migraphx::shape::float_type, {10, 12}}; + + migraphx::program p1; + { + auto* m1 = p1.get_main_module(); + auto a = m1->add_parameter("a", sh1); + auto b = m1->add_parameter("b", sh2); + auto scale_a = m1->add_literal(0.4f); + auto scale_b = m1->add_literal(0.5f); + auto zero = m1->add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero); + auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero); + auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero); + auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero); + auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2); + m1->add_return({dot}); + + run_pass(*m1); + } + + migraphx::program p2; + { + auto* m2 = p2.get_main_module(); + auto a = m2->add_parameter("a", sh1); + auto b = m2->add_parameter("b", sh2); + auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b); + m2->add_return({dot}); + } + + std::vector av(sh1.elements(), 10); + auto a = migraphx::argument(sh1, av.data()); + std::vector bv(sh2.elements(), 10); + auto b = migraphx::argument(sh2, bv.data()); + p1.compile(migraphx::target(migraphx::ref::target{})); + p2.compile(migraphx::target(migraphx::ref::target{})); + + auto result1 = p1.eval({{"a", a}, {"b", b}}).back(); + std::vector rv1(sh3.elements()); + result1.visit([&](auto output) { rv1.assign(output.begin(), output.end()); }); + auto result2 = p2.eval({{"a", a}, {"b", b}}).back(); + std::vector rv2(sh3.elements()); + result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify_range(rv1, rv2)); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 9a934a5fa94a70f4168c94fafc8e47e3dda1a81b..56fe8216b49b98e44ff6f36dca7a1ade8c816a26 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5,362 +5,1117 @@ #include #include #include +#include + +#include + #include -void run_pass(migraphx::program& p) +void run_pass(migraphx::module& m) { - migraphx::run_passes(p, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}}); } TEST_CASE(double_contig) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1); - auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); - p.add_instruction(pass_op{}, c2); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - EXPECT(std::distance(p.begin(), p.end()) == 4); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); + auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1); + mm->add_return({c2}); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + EXPECT(std::distance(mm->begin(), mm->end()) == 4); + auto result = p.eval({}).back(); EXPECT(result != get_2x2()); } TEST_CASE(double_transpose) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); - p.add_instruction(pass_op{}, t2); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1); + mm->add_return({t2}); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == get_2x2()); } TEST_CASE(double_transpose_contig) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1); - auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1); - auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2); - p.add_instruction(pass_op{}, c2); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t1); + auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), c1); + auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), t2); + mm->add_return({c2}); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result == get_2x2()); } TEST_CASE(single_transpose) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - p.add_instruction(pass_op{}, t1); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - run_pass(p); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - EXPECT(std::distance(p.begin(), p.end()) == 3); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + mm->add_return({t1}); + EXPECT(not mm->get_output_shapes().back().standard()); + EXPECT(mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(not mm->get_output_shapes().back().standard()); + EXPECT(mm->get_output_shapes().back().transposed()); + EXPECT(std::distance(mm->begin(), mm->end()) == 3); + auto result = p.eval({}).back(); EXPECT(result != get_2x2()); } TEST_CASE(double_transpose_sin_pass) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); - run_pass(p); - EXPECT(p.get_shape().standard()); - EXPECT(not p.get_shape().transposed()); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t1); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(mm->get_output_shapes().back().standard()); + EXPECT(not mm->get_output_shapes().back().transposed()); // TODO: Fix this - // EXPECT(std::distance(p.begin(), p.end()) == 1); - auto result = p.eval({}); + // EXPECT(std::distance(mm->begin(), mm->end()) == 1); + auto result = p.eval({}).back(); EXPECT(result == get_2x2()); } TEST_CASE(single_transpose_sin_pass) { migraphx::program p; - auto l = p.add_literal(get_2x2()); - p.add_instruction(migraphx::op::transpose{{1, 0}}, l); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - run_pass(p); - EXPECT(not p.get_shape().standard()); - EXPECT(p.get_shape().transposed()); - EXPECT(std::distance(p.begin(), p.end()) == 2); - auto result = p.eval({}); + auto* mm = p.get_main_module(); + + auto l = mm->add_literal(get_2x2()); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l); + EXPECT(not mm->get_output_shapes().back().standard()); + EXPECT(mm->get_output_shapes().back().transposed()); + run_pass(*mm); + EXPECT(not mm->get_output_shapes().back().standard()); + EXPECT(mm->get_output_shapes().back().transposed()); + EXPECT(std::distance(mm->begin(), mm->end()) == 2); + auto result = p.eval({}).back(); EXPECT(result != get_2x2()); } TEST_CASE(reshape_transpose) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}}; - auto x = p.add_parameter("x", s); - auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x); - auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1); - auto ct = p.add_instruction(migraphx::op::contiguous{}, t); - auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct); - p.add_instruction(pass_op{}, r2); - EXPECT(p.get_shape() == s); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == s); - EXPECT(std::distance(p.begin(), p.end()) == n); + auto x = m.add_parameter("x", s); + auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), x); + auto t = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), r1); + auto ct = m.add_instruction(migraphx::make_op("contiguous"), t); + auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct); + m.add_return({r2}); + EXPECT(m.get_output_shapes().back() == s); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == s); + EXPECT(std::distance(m.begin(), m.end()) == n); } TEST_CASE(transpose_contiguous) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; - auto x = p.add_parameter("x", s); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); - auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); - p.add_instruction(pass_op{}, c1); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n); + auto x = m.add_parameter("x", s); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x); + auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); + m.add_return({c1}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n); } TEST_CASE(transpose_double_contiguous) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}}; - auto x = p.add_parameter("x", s); - auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); - auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); - auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); - p.add_instruction(pass_op{}, c2); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 1); - EXPECT(p.has_instruction(t)); + auto x = m.add_parameter("x", s); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x); + auto c1 = m.add_instruction(migraphx::make_op("contiguous"), t); + auto c2 = m.add_instruction(migraphx::make_op("contiguous"), c1); + m.add_return({c2}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); + EXPECT(m.has_instruction(t)); } TEST_CASE(transpose_partial1) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; - auto x = p.add_parameter("x", s); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); - auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); - p.add_instruction(pass_op{}, t2); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 1); + auto x = m.add_parameter("x", s); + auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x); + auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1); + m.add_return({t2}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); } TEST_CASE(transpose_partial2) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; - auto x = p.add_parameter("x", s); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); - auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); - auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); - p.add_instruction(pass_op{}, t3); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 2); + auto x = m.add_parameter("x", s); + auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x); + auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1); + auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2); + m.add_return({t3}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 2); } TEST_CASE(transpose_partial3) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; - auto x = p.add_parameter("x", s); - auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); - auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); - auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); - auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3); - p.add_instruction(pass_op{}, t4); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 3); + auto x = m.add_parameter("x", s); + auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), x); + auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}}), t1); + auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t2); + auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), t3); + m.add_return({t4}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 3); } TEST_CASE(nop_transpose1) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; - auto x = p.add_parameter("x", s); - auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x); - p.add_instruction(pass_op{}, t); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 1); + auto x = m.add_parameter("x", s); + auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x); + m.add_return({t}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); } TEST_CASE(nop_transpose2) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; - auto x = p.add_parameter("x", s); - auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x); - auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1); - auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2); - auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3); - p.add_instruction(pass_op{}, t4); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 4); + auto x = m.add_parameter("x", s); + auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), x); + auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t1); + auto t3 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t2); + auto t4 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), t3); + m.add_instruction(pass_op{}, t4); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 4); } TEST_CASE(nop_transpose3) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto concat = p.add_instruction(migraphx::op::concat{3}, x, y); - auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat); - auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1); - p.add_instruction(pass_op{}, t2); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape() == out_shape); - EXPECT(std::distance(p.begin(), p.end()) == n - 1); + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), x, y); + auto t1 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 2, 3}}}), concat); + auto t2 = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t1); + m.add_return({t2}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); +} + +TEST_CASE(nop_convert) +{ + migraphx::module m; + + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; + auto x = m.add_parameter("x", s); + auto t = m.add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + x); + m.add_return({t}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back() == out_shape); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); } TEST_CASE(concat_transpose1) { - migraphx::program p; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x); - auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y); - auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt); - auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat); - p.add_instruction(pass_op{}, t); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape().lens() == out_shape.lens()); - EXPECT(std::distance(p.begin(), p.end()) == n - 3); + migraphx::module m; + + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x); + auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y); + auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xt, yt); + auto t = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), concat); + m.add_return({t}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back().lens() == out_shape.lens()); + EXPECT(std::distance(m.begin(), m.end()) == n - 3); auto new_concat = - std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); - EXPECT(bool{new_concat != p.end()}); + std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); + EXPECT(bool{new_concat != m.end()}); EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 3); } TEST_CASE(concat_transpose2) { - migraphx::program p; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x); - auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y); - auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); - auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); - p.add_instruction(pass_op{}, t); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape().lens() == out_shape.lens()); - EXPECT(std::distance(p.begin(), p.end()) == n - 2); + migraphx::module m; + + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y); + auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", -1}}), xt, yt); + auto t = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat); + m.add_return({t}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back().lens() == out_shape.lens()); + EXPECT(std::distance(m.begin(), m.end()) == n - 2); auto new_concat = - std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); - EXPECT(bool{new_concat != p.end()}); + std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); + EXPECT(bool{new_concat != m.end()}); EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 1); } TEST_CASE(concat_transpose3) { - migraphx::program p; - auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}); - auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}}); - auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x); - auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y); - auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); - auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); - p.add_instruction(pass_op{}, t); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape().lens() == out_shape.lens()); - EXPECT(std::distance(p.begin(), p.end()) == n - 2); + migraphx::module m; + + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; + auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}); + auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}}); + auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y); + auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); + auto t = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat); + m.add_return({t}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back().lens() == out_shape.lens()); + EXPECT(std::distance(m.begin(), m.end()) == n - 2); auto new_concat = - std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); - EXPECT(bool{new_concat != p.end()}); + std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); + EXPECT(bool{new_concat != m.end()}); EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 1); } +TEST_CASE(concat_transpose4) +{ + migraphx::module m; + auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}}; + auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}}; + auto x = m.add_parameter("x", sx); + auto y = m.add_parameter("y", sy); + auto xt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto yt = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), y); + auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), xt, yt); + auto t = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), concat); + m.add_return({t}); + + migraphx::module m1 = m; + run_pass(m); + + EXPECT(m1 == m); +} + TEST_CASE(nested_concat) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y); - auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); - auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2); - p.add_instruction(pass_op{}, concat3); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape().lens() == out_shape.lens()); - EXPECT(std::distance(p.begin(), p.end()) == n - 2); - EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto concat1 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto concat2 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x); + auto concat3 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2); + m.add_return({concat3}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back().lens() == out_shape.lens()); + EXPECT(std::distance(m.begin(), m.end()) == n - 2); + EXPECT(std::count_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); } TEST_CASE(nested_concat_partial) { - migraphx::program p; + migraphx::module m; + auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}}; - auto x = p.add_parameter("x", s); - auto y = p.add_parameter("y", s); - auto l = p.add_literal( + auto x = m.add_parameter("x", s); + auto y = m.add_parameter("y", s); + auto l = m.add_literal( migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}})); - auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y); - auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); - auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l); - p.add_instruction(pass_op{}, concat3); - auto out_shape = p.get_shape(); - auto n = std::distance(p.begin(), p.end()); - run_pass(p); - EXPECT(p.get_shape().lens() == out_shape.lens()); - EXPECT(std::distance(p.begin(), p.end()) == n - 2); - EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); + auto concat1 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto concat2 = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, x); + auto concat3 = + m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l); + m.add_return({concat3}); + auto out_shape = m.get_output_shapes().back(); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(m.get_output_shapes().back().lens() == out_shape.lens()); + EXPECT(std::distance(m.begin(), m.end()) == n - 2); + EXPECT(std::count_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); +} + +TEST_CASE(multibroadcast_simplify) +{ + migraphx::module m; + + std::vector s_lens{1, 2, 3, 4}; + auto s = migraphx::shape{migraphx::shape::float_type, s_lens}; + auto x = m.add_parameter("x", s); + auto y = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_lens}}), x); + m.add_instruction(migraphx::make_op("mul"), y, y); + auto n = std::distance(m.begin(), m.end()); + run_pass(m); + EXPECT(std::distance(m.begin(), m.end()) == n - 1); +} + +TEST_CASE(double_slice1) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {256}}}), x); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), slice1); + m1.add_return({slice2}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {64}}, {"ends", {96}}}), x); + m2.add_return({slice}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(double_slice2) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {32}}}), slice1); + m1.add_return({slice2}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {256}}); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {64}}}), x); + m2.add_return({slice}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(double_slice_multi_axes) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}}); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {32}}, {"ends", {128}}}), x); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {32}}}), slice1); + m1.add_return({slice2}); + } + run_pass(m1); + + migraphx::module m2; + + { + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}}); + auto slice = m2.add_instruction( + migraphx::make_op("slice", + {{"axes", {0, 1}}, {"starts", {32, 0}}, {"ends", {128, 32}}}), + x); + m2.add_return({slice}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(optimize_resize) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {1, 2, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, + 3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), gr); + m.add_return({r}); + + return m; + }; + + auto m1 = create_resize_module(); + run_pass(m1); + + auto create_optimized_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + std::vector dims = {1, 1, 2, 1, 2, 1}; + auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx); + std::vector mb_dims = {1, 2, 2, 2, 2, 3}; + auto mbx = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx); + auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx); + std::vector orig_dims = {1, 2, 4, 6}; + auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb); + m.add_return({r}); + + return m; + }; + + EXPECT(m1 == create_optimized_module()); +} + +TEST_CASE(optimize_resize_ind_not_apply) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {1, 2, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 2, 2, 2, 3, + 3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), gr); + m.add_return({r}); + + return m; + }; + + auto m1 = create_resize_module(); + run_pass(m1); + EXPECT(m1 == create_resize_module()); +} + +TEST_CASE(optimize_resize_rsp_dim_1) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2}}}), inx); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + m.add_return({r}); + + return m; + }; + + auto m = create_resize_module(); + run_pass(m); + EXPECT(m == create_resize_module()); +} + +TEST_CASE(optimize_resize_ndims_unequal) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 3, 2}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + auto iny = m.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); + m.add_return({r}); + + return m; + }; + + auto m = create_resize_module(); + run_pass(m); + EXPECT(m == create_resize_module()); +} + +TEST_CASE(optimize_resize_ind_non_brcst) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + auto iny = m.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); + m.add_return({r}); + + return m; + }; + + auto m = create_resize_module(); + run_pass(m); + EXPECT(m == create_resize_module()); +} + +TEST_CASE(optimize_resize_ind_non_const) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + auto iny = m.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + auto li = m.add_parameter("ind", si); + auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); + m.add_return({r}); + + return m; + }; + + auto m = create_resize_module(); + run_pass(m); + EXPECT(m == create_resize_module()); +} + +TEST_CASE(optimize_where_true) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto create_where_module = [&](bool cond) { + migraphx::module m; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(si.elements(), static_cast(cond)); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto return_xy = [&](bool cond) { + migraphx::module m; + auto x = m.add_parameter("X", s); + auto y = m.add_parameter("Y", s); + cond ? m.add_return({x}) : m.add_return({y}); + return m; + }; + + auto m = create_where_module(true); + run_pass(m); + EXPECT(m == return_xy(true)); + + auto m1 = create_where_module(false); + run_pass(m1); + EXPECT(m1 == return_xy(false)); +} + +TEST_CASE(where_different_cond_values) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata = {1, 1, 0, 1, 0, 1}; + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == create_where_module()); +} + +TEST_CASE(where_axis_nonzero) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == create_where_module()); +} + +TEST_CASE(where_three_concat_inputs) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == create_where_module()); +} + +TEST_CASE(where_three_inputs_diff_shapes) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}}; + auto inx = m.add_parameter("X", sx); + auto iny = m.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == create_where_module()); +} + +TEST_CASE(where_three_lens_diff) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto inx = m.add_parameter("X", sx); + auto iny = m.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}}; + std::vector idata(6, 1); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == create_where_module()); +} + +TEST_CASE(reshape_cont) +{ + auto create_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; + + auto inx = m.add_parameter("x", sx); + auto iny = m.add_parameter("y", sy); + auto mb_inx = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx); + auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); + auto rsp = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); + auto r = m.add_instruction(migraphx::make_op("add"), rsp, iny); + m.add_return({r}); + + return m; + }; + + auto m1 = create_module(); + run_pass(m1); + + auto create_opt_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; + + auto inx = m.add_parameter("x", sx); + auto iny = m.add_parameter("y", sy); + auto mb_inx = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx); + auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny); + auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny); + auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum); + m.add_return({r}); + + return m; + }; + + EXPECT(m1 == create_opt_module()); +} + +TEST_CASE(reshape_input_non_std) +{ + auto create_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}}; + + auto inx = m.add_parameter("x", sx); + auto iny = m.add_parameter("y", sy); + auto mb_inx = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx); + auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); + auto rsp = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); + auto ty = + m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), iny); + auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty); + m.add_return({r}); + + return m; + }; + + auto m1 = create_module(); + run_pass(m1); + + EXPECT(m1 == create_module()); +} + +TEST_CASE(reshape_cont_nonpw) +{ + auto create_module = [] { + migraphx::module m; + migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}}; + + auto inx = m.add_parameter("x", sx); + auto iny = m.add_parameter("y", sy); + auto mb_inx = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), inx); + auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx); + auto rsp = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx); + auto r = m.add_instruction(migraphx::make_op("convolution"), rsp, iny); + m.add_return({r}); + + return m; + }; + + auto m1 = create_module(); + run_pass(m1); + + EXPECT(m1 == create_module()); +} + +TEST_CASE(transpose_contiguous_reshape_unary) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto reshape_ins1 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); + auto transpose_ins = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); + auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); + auto reshape_ins2 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); + auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2); + m1.add_instruction(pass_op{}, relu); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto reshape_ins1 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); + auto transpose_ins = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); + auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins); + auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu); + auto reshape_ins2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); + m2.add_instruction(pass_op{}, reshape_ins2); + } + EXPECT(m1 == m2); +} + +TEST_CASE(transpose_contiguous_squeeze_unary) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}}); + auto transpose_ins = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); + auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins); + auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins); + m1.add_instruction(pass_op{}, rsqrt); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}}); + auto transpose_ins = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); + auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt); + auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins); + m2.add_instruction(pass_op{}, sq_ins); + } + EXPECT(m1 == m2); +} + +TEST_CASE(transpose_contiguous_unsqueeze_unary) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto transpose_ins = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins); + auto unsq_ins = + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); + auto round = m1.add_instruction(migraphx::make_op("round"), unsq_ins); + m1.add_instruction(pass_op{}, round); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}}); + auto transpose_ins = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto round = m2.add_instruction(migraphx::make_op("round"), transpose_ins); + auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round); + auto unsq_ins = + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); + m2.add_instruction(pass_op{}, unsq_ins); + } + EXPECT(m1 == m2); +} + +TEST_CASE(transpose_contiguous_reshape_binary_packed) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}}); + auto w1 = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + x, + w1); // (2, 256, 28, 28) + auto w2 = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}})); + auto conv2 = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + conv1, + w2); // (2, 512, 14, 14) + + auto conv2_rsp1 = m1.add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2); + auto conv2_trans = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp1); + auto conv2_cont = m1.add_instruction(migraphx::make_op("contiguous"), conv2_trans); + auto conv2_rsp2 = m1.add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), conv2_cont); + auto add_ins = m1.add_instruction(migraphx::make_op("add"), conv2_rsp2, x); + m1.add_instruction(pass_op{}, add_ins); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}}); + auto w1 = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}})); + auto conv1 = m2.add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + x, + w1); // (2, 256, 28, 28) + auto w2 = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}})); + auto conv2 = m2.add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + conv1, + w2); // (2, 512, 14, 14) + + auto conv2_rsp = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2); + auto conv2_trans = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp); + auto x_rsp = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 14, 2, 14, 2}}}), x); + auto add_ins = m2.add_instruction(migraphx::make_op("add"), conv2_trans, x_rsp); + auto add_rsp = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), add_ins); + m2.add_instruction(pass_op{}, add_rsp); + } + EXPECT(m1 == m2); +} + +TEST_CASE(transpose_contiguous_reshape_binary_broadcast) +{ + migraphx::module m1; + { + migraphx::shape sx{migraphx::shape::float_type, {4}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}}; + + auto x = m1.add_parameter("x", sx); + auto y = m1.add_parameter("y", sy); + auto x_brcst = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4, 6}}}), x); + auto y_trans = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y); + auto y_cont = m1.add_instruction(migraphx::make_op("contiguous"), y_trans); + auto y_rsp = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), y_cont); + auto r = m1.add_instruction(migraphx::make_op("add"), y_rsp, x_brcst); + m1.add_return({r}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/stringutils.cpp b/test/stringutils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b2de0ea32702f85c5b61e39b92030ba13196d266 --- /dev/null +++ b/test/stringutils.cpp @@ -0,0 +1,79 @@ +#include +#include + +TEST_CASE(interpolate_string_simple1) +{ + std::string input = "Hello ${w}!"; + auto s = migraphx::interpolate_string(input, {{"w", "world"}}); + EXPECT(s == "Hello world!"); +} + +TEST_CASE(interpolate_string_simple2) +{ + std::string input = "${hello}"; + auto s = migraphx::interpolate_string(input, {{"hello", "bye"}}); + EXPECT(s == "bye"); +} + +TEST_CASE(interpolate_string_unbalanced) +{ + std::string input = "${hello"; + EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"hello", "bye"}}); })); +} + +TEST_CASE(interpolate_string_extra_space) +{ + std::string input = "${ hello }"; + auto s = migraphx::interpolate_string(input, {{"hello", "bye"}}); + EXPECT(s == "bye"); +} + +TEST_CASE(interpolate_string_multiple) +{ + std::string input = "${h} ${w}!"; + auto s = migraphx::interpolate_string(input, {{"w", "world"}, {"h", "Hello"}}); + EXPECT(s == "Hello world!"); +} + +TEST_CASE(interpolate_string_next) +{ + std::string input = "${hh}${ww}!"; + auto s = migraphx::interpolate_string(input, {{"ww", "world"}, {"hh", "Hello"}}); + EXPECT(s == "Helloworld!"); +} + +TEST_CASE(interpolate_string_dollar_sign) +{ + std::string input = "$hello"; + auto s = migraphx::interpolate_string(input, {{"hello", "bye"}}); + EXPECT(s == "$hello"); +} + +TEST_CASE(interpolate_string_missing) +{ + std::string input = "${hello}"; + EXPECT(test::throws([&] { migraphx::interpolate_string(input, {{"h", "bye"}}); })); +} + +TEST_CASE(interpolate_string_custom1) +{ + std::string input = "****{{a}}****"; + auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{", "}}"); + EXPECT(s == "****b****"); +} + +TEST_CASE(interpolate_string_custom2) +{ + std::string input = "****{{{a}}}****"; + auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{", "}}}"); + EXPECT(s == "****b****"); +} + +TEST_CASE(interpolate_string_custom3) +{ + std::string input = "****{{{{a}}}}****"; + auto s = migraphx::interpolate_string(input, {{"a", "b"}}, "{{{{", "}}}}"); + EXPECT(s == "****b****"); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/targets.cpp b/test/targets.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2e286022ed73054d080c82f4ff6ec975a6b454f9 --- /dev/null +++ b/test/targets.cpp @@ -0,0 +1,26 @@ +#include +#include +#include +#include "test.hpp" + +TEST_CASE(make_target) +{ + for(const auto& name : migraphx::get_targets()) + { + auto t = migraphx::make_target(name); + CHECK(t.name() == name); + } +} + +TEST_CASE(make_invalid_target) +{ + EXPECT(test::throws([&] { migraphx::make_target("mi100"); })); +} + +TEST_CASE(targets) +{ + auto ts = migraphx::get_targets(); + EXPECT(ts.size() > 0); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/tf/add_bcast_test.pb b/test/tf/add_bcast_test.pb index da70f7e1914d321b2d8036c59a79b00110f17bd0..d20cc975dc2699aa3d3728b05b3bf925ba601877 100644 --- a/test/tf/add_bcast_test.pb +++ b/test/tf/add_bcast_test.pb @@ -1,9 +1,9 @@ 2 -0 Placeholder* +0 Placeholder* +dtype0* shape -:* -dtype0 +: 2 1 Placeholder* dtype0* @@ -12,4 +12,4 @@ add_bcast1Add01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/add_test.pb b/test/tf/add_test.pb index 58e24c5aebecdf86d5b740dc470d70af92143cdf..f176c1b2b9340c497f5b1db72a64dcc87aa711ff 100644 --- a/test/tf/add_test.pb +++ b/test/tf/add_test.pb @@ -1,12 +1,12 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: : 1 Placeholder* dtype0* shape:  add1Add01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/addv2_test.pb b/test/tf/addv2_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..8fcfc0e3cd90995a61497b82c66b2e3603a54a6f --- /dev/null +++ b/test/tf/addv2_test.pb @@ -0,0 +1,12 @@ + +: +0 Placeholder* +dtype0* +shape: +: +1 Placeholder* +dtype0* +shape: + +add1AddV201* +T0"¸ \ No newline at end of file diff --git a/test/tf/argmax_test.pb b/test/tf/argmax_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..b95aee9f82b96c99d32cd202530c38cc130437a3 Binary files /dev/null and b/test/tf/argmax_test.pb differ diff --git a/test/tf/argmin_test.pb b/test/tf/argmin_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..910deda40532745759f30837abbe097cbed2a098 Binary files /dev/null and b/test/tf/argmin_test.pb differ diff --git a/test/tf/assert_less_equal_test.pb b/test/tf/assert_less_equal_test.pb index c6a75ea4c7cc62c65a1bde0972c62d2bfd2e66e1..f230cb17d145db8cb863c2c2adee08ddfd027fc3 100644 Binary files a/test/tf/assert_less_equal_test.pb and b/test/tf/assert_less_equal_test.pb differ diff --git a/test/tf/batchmatmul_test.pb b/test/tf/batchmatmul_test.pb index 08394a93d2724ec69cc546e6ca7c4e1f3d13c14d..ac0f8c6e35f220bbb6f6ec3142520216a8b419e2 100644 --- a/test/tf/batchmatmul_test.pb +++ b/test/tf/batchmatmul_test.pb @@ -1,14 +1,14 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: : 1 Placeholder* dtype0* shape: -D - batchmatmul1 BatchMatMul01* +F + batchmatmul1 BatchMatMulV201* +T0* adj_x(* -adj_y(* -T0" \ No newline at end of file +adj_y("¸ \ No newline at end of file diff --git a/test/tf/batchnorm_test.pb b/test/tf/batchnorm_test.pb index 310e474d42ba040a2af9356955352c2a65938e0d..ffb2db868c4aa4be25fd377f55c246985430052a 100644 Binary files a/test/tf/batchnorm_test.pb and b/test/tf/batchnorm_test.pb differ diff --git a/test/tf/batchnormv3_test.pb b/test/tf/batchnormv3_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..bd62844c0c535cc26097da59c986f58dcee9d1af Binary files /dev/null and b/test/tf/batchnormv3_test.pb differ diff --git a/test/tf/biasadd_scalar_test.pb b/test/tf/biasadd_scalar_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..40d1141289c38cec299a4b1b742aa24176a5c695 Binary files /dev/null and b/test/tf/biasadd_scalar_test.pb differ diff --git a/test/tf/biasadd_test.pb b/test/tf/biasadd_test.pb index b708cbc3679782ef0a87c3d7f8043536aefe54c5..c426cb8be184f50bf49d2df4d12a8b14dc1a889f 100644 --- a/test/tf/biasadd_test.pb +++ b/test/tf/biasadd_test.pb @@ -1,8 +1,8 @@ ; -0 Placeholder* -shape:ô* -dtype0 +0 Placeholder* +dtype0* +shape:ô / 1 Placeholder* dtype0* @@ -10,4 +10,4 @@ : bias_add1BiasAdd01* T0* - data_formatNHWC" \ No newline at end of file + data_formatNHWC"¸ \ No newline at end of file diff --git a/test/tf/cast_test.pb b/test/tf/cast_test.pb index dd9c2488e229c5f36f3c6480d57f04dc64212e9c..d0a201ca5ecdd05dacf28ff03298ebc264a519f1 100644 Binary files a/test/tf/cast_test.pb and b/test/tf/cast_test.pb differ diff --git a/test/tf/concat_test.pb b/test/tf/concat_test.pb index f2b5317eb6f7b74962722adab9c703142fdf1429..faae6041abb300a2eba2a6a993f52df57ff7299d 100644 Binary files a/test/tf/concat_test.pb and b/test/tf/concat_test.pb differ diff --git a/test/tf/conv_add_test.pb b/test/tf/conv_add_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..b27a3c2642de6516d5d1459b0b5e0122cb23ff57 Binary files /dev/null and b/test/tf/conv_add_test.pb differ diff --git a/test/tf/conv_batch_test.pb b/test/tf/conv_batch_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..f1b555029e74eaa433ac6f29de28e09a5c880323 Binary files /dev/null and b/test/tf/conv_batch_test.pb differ diff --git a/test/tf/conv_nchw_test.pb b/test/tf/conv_nchw_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..a548dd3ae9fc6d516173f0ea28e221ae7e8f2893 Binary files /dev/null and b/test/tf/conv_nchw_test.pb differ diff --git a/test/tf/conv_relu6_test.pb b/test/tf/conv_relu6_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..6d63b53f9d48e262a0878b710a062a9e9cd9e644 Binary files /dev/null and b/test/tf/conv_relu6_test.pb differ diff --git a/test/tf/conv_relu_test.pb b/test/tf/conv_relu_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..3fb1c3b0223a30ebc99cb918174212e4f203399b Binary files /dev/null and b/test/tf/conv_relu_test.pb differ diff --git a/test/tf/conv_test.pb b/test/tf/conv_test.pb index 958b0ba73c2a1a67851d6d85f945ce6da24f88b5..bfb490d2c825fabeb56d37f209d6ebcd2251cc7a 100644 Binary files a/test/tf/conv_test.pb and b/test/tf/conv_test.pb differ diff --git a/test/tf/expanddims_test.pb b/test/tf/expanddims_test.pb index 7006d42d60020f97c602468cb3d878aac1c17342..8f9d946b62096e654d2574041413515d7ce86b1d 100644 Binary files a/test/tf/expanddims_test.pb and b/test/tf/expanddims_test.pb differ diff --git a/test/tf/gather_test.pb b/test/tf/gather_test.pb index 7ae2081fbdada86c25dce581de7692b9ff080cc0..0ddc69356c437c31644fe7a3436760fbed1b492d 100644 Binary files a/test/tf/gather_test.pb and b/test/tf/gather_test.pb differ diff --git a/test/tf/gen_tf_pb.py b/test/tf/gen_tf_pb.py index e399051d88caa62ccf4ffe09b84eddfe533ef0f1..b9fae1ed19c770b22a518d7f586a9d09f460d8e8 100644 --- a/test/tf/gen_tf_pb.py +++ b/test/tf/gen_tf_pb.py @@ -1,4 +1,6 @@ -import numpy as np +# This script generates tf pb files for MIGraphX tf operator tests. +# To generate an individual pb file, you can use the following +# command: python -c "import gen_tf_pb; gen_tf_pb.{test_name}_test()" import tensorflow as tf @@ -17,34 +19,72 @@ def tf_test(op_test): @tf_test def add_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='1') tf.add(g1_input, g2_input, name='add1') +@tf_test +def addv2_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='1') + tf.raw_ops.AddV2(x=g1_input, y=g2_input, name='add1') + + @tf_test def add_bcast_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(2, 1), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 3), name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 1), name='1') tf.math.add(g1_input, g2_input, name='add_bcast1') +@tf_test +def argmax_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(3, 4, 5, 6), + name='0') + tf.argmax(g1_input, axis=2, name='argmax1') + + +@tf_test +def argmin_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(3, 4, 5, 6), + name='0') + tf.argmin(g1_input, axis=2, name='argmin1') + + @tf_test def assert_less_equal_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(2, 3), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 3), name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 3), name='1') with tf.control_dependencies( - [tf.assert_less_equal(g1_input, g2_input)]): + [tf.compat.v1.assert_less_equal(g1_input, g2_input)]): tf.add(g1_input, g2_input, name='add1') @tf_test def batchmatmul_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 8, 4), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 2, 4, 8), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 8, 4), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 4, 8), + name='1') tf.matmul(g1_input, g2_input, transpose_a=True, @@ -55,41 +95,83 @@ def batchmatmul_test(g1): @tf_test def batchnorm_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 16, 16, 32), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 32), + name='0') g1_scale = tf.constant(1.0, dtype=tf.float32, shape=[32], name='1') - g1_offset = tf.placeholder(tf.float32, shape=(32), name='2') - g1_mean = tf.placeholder(tf.float32, shape=(32), name='3') - g1_variance = tf.placeholder(tf.float32, shape=(32), name='4') - tf.nn.fused_batch_norm(g1_input, - g1_scale, - g1_offset, - g1_mean, - g1_variance, - epsilon=0.00001, - is_training=False, - name='batchnorm1') + g1_offset = tf.compat.v1.placeholder(tf.float32, shape=(32), name='2') + g1_mean = tf.compat.v1.placeholder(tf.float32, shape=(32), name='3') + g1_variance = tf.compat.v1.placeholder(tf.float32, + shape=(32), + name='4') + tf.compat.v1.nn.fused_batch_norm(x=g1_input, + scale=g1_scale, + offset=g1_offset, + mean=g1_mean, + variance=g1_variance, + epsilon=0.00001, + is_training=False, + name='batchnorm1') + + +@tf_test +def batchnormv3_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 32), + name='0') + g1_scale = tf.constant(1.0, dtype=tf.float32, shape=[32], name='1') + g1_offset = tf.compat.v1.placeholder(tf.float32, shape=(32), name='2') + g1_mean = tf.compat.v1.placeholder(tf.float32, shape=(32), name='3') + g1_variance = tf.compat.v1.placeholder(tf.float32, + shape=(32), + name='4') + tf.raw_ops.FusedBatchNormV3(x=g1_input, + scale=g1_scale, + offset=g1_offset, + mean=g1_mean, + variance=g1_variance, + epsilon=0.00001, + is_training=False, + name='batchnorm1') @tf_test def biasadd_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 500), name='0') - g2_input = tf.placeholder(tf.float32, shape=(500), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 500), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, shape=(500), name='1') tf.nn.bias_add(g1_input, g2_input, name='bias_add1') +@tf_test +def biasadd_scalar_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(1, 1), name='0') + g2_const = tf.constant(1.0, tf.float32, shape=(1, ), name='1') + tf.nn.bias_add(g1_input, g2_const, name='bias_add1') + + @tf_test def cast_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.cast(g1_input, dtype=tf.int32, name='cast1') @tf_test def concat_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(4, 7, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(4, 2, 3), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(4, 7, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(4, 2, 3), + name='1') tf.concat([g1_input, g2_input], axis=1, name='concat1') @@ -102,7 +184,9 @@ def const_test(g1): @tf_test def conv_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 16, 16, 3), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') g1_weights = tf.constant(value=1.0, dtype=tf.float32, shape=(3, 3, 3, 32), @@ -110,46 +194,133 @@ def conv_test(g1): tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1') +@tf_test +def conv_add_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') + g1_weights = tf.constant(value=1.0, + dtype=tf.float32, + shape=(3, 3, 3, 32), + name='1') + conv = tf.nn.conv2d(g1_input, + g1_weights, [1, 1, 1, 1], + "SAME", + name='conv1') + tf.add(conv, conv, name='add1') + + +@tf_test +def conv_batch_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(None, 16, 16, 3), + name='0') + g1_weights = tf.constant(value=1.0, + dtype=tf.float32, + shape=(3, 3, 3, 32), + name='1') + tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1') + + +@tf_test +def conv_nchw_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') + g1_weights = tf.constant(value=1.0, + dtype=tf.float32, + shape=(3, 3, 3, 32), + name='1') + tf.nn.conv2d(g1_input, + g1_weights, [1, 1, 1, 1], + "SAME", + data_format='NCHW', + name='conv1') + + +@tf_test +def conv_relu_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') + g1_weights = tf.constant(value=1.0, + dtype=tf.float32, + shape=(3, 3, 3, 32), + name='1') + conv = tf.nn.conv2d(g1_input, + g1_weights, [1, 1, 1, 1], + "SAME", + name='conv1') + tf.nn.relu(conv, name='relu1') + + +@tf_test +def conv_relu6_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') + g1_weights = tf.constant(value=1.0, + dtype=tf.float32, + shape=(3, 3, 3, 32), + name='1') + conv = tf.nn.conv2d(g1_input, + g1_weights, [1, 1, 1, 1], + "SAME", + name='conv1') + tf.nn.relu6(conv, name='relu1') + + @tf_test def depthwiseconv_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 16, 16, 3), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') g1_weights = tf.constant(value=1.0, dtype=tf.float32, shape=(3, 3, 3, 1), name='1') - tf.nn.depthwise_conv2d_native(g1_input, - g1_weights, [1, 1, 1, 1], - "SAME", - name='depthwiseconv1') + tf.compat.v1.nn.depthwise_conv2d_native(g1_input, + g1_weights, [1, 1, 1, 1], + "SAME", + name='depthwiseconv1') @tf_test def expanddims_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(2, 3, 4), name='0') - tf.expand_dims(g1_input, axis=-1, name='expanddims_neg') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(2, 3, 4), + name='0') + tf.expand_dims(g1_input, axis=0, name='expanddims_neg') @tf_test def gather_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(2, 4), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 4), name='0') tf.gather(g1_input, [1, 1], axis=1, name='gather1') @tf_test def identity_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.identity(g1_input, 'identity') @tf_test def matmul_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(8, 4), name='0') - g2_input = tf.placeholder(tf.float32, shape=(4, 8), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(8, 4), name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, shape=(4, 8), name='1') tf.matmul(g1_input, g2_input, transpose_a=True, @@ -160,7 +331,9 @@ def matmul_test(g1): @tf_test def mean_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.math.reduce_mean(g1_input, axis=(2, 3), keepdims=True, name='mean1') tf.math.reduce_mean(g1_input, axis=(2, 3), @@ -171,8 +344,9 @@ def mean_test(g1): @tf_test def mean_test_nhwc(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 16, 16, 3), name='0') - tf.math.reduce_mean(g1_input, axis=(1, 2), keepdims=True, name='mean1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') tf.math.reduce_mean(g1_input, axis=(1, 2), keepdims=False, @@ -182,101 +356,177 @@ def mean_test_nhwc(g1): @tf_test def mul_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 16), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 16), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 16), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 16), + name='1') tf.multiply(g1_input, g2_input, name='mul1') +@tf_test +def multi_output_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') + tf.nn.relu(g1_input, 'relu') + tf.tanh(g1_input, 'tanh') + + +@tf_test +def noop_test(g1): + with g1.as_default(): + tf.raw_ops.NoOp(name='noop1') + + +@tf_test +def onehot_test(g1): + with g1.as_default(): + g1_input = tf.constant((1, 1, 1, 1, 1), dtype=tf.int32) + tf.one_hot(g1_input, 2, name='onehot1') + + @tf_test def pack_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(2), name='0') - g2_input = tf.placeholder(tf.float32, shape=(2), name='1') - g3_input = tf.placeholder(tf.float32, shape=(2), name='2') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2), name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, shape=(2), name='1') + g3_input = tf.compat.v1.placeholder(tf.float32, shape=(2), name='2') tf.stack([g1_input, g2_input, g3_input], axis=1, name='pack1') @tf_test def pack_test_nhwc(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 2), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 2), name='1') - g3_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 2), name='2') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 2), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 2), + name='1') + g3_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 2), + name='2') tf.stack([g1_input, g2_input, g3_input], axis=3, name='pack1') +@tf_test +def pad_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(2, 4), name='0') + paddings = tf.constant([[1, 1], [2, 2]]) + + tf.pad(g1_input, paddings, name='pad1') + + @tf_test def pooling_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 16, 16, 3), name='0') - tf.nn.avg_pool(value=g1_input, - ksize=(1, 2, 2, 1), - strides=(1, 2, 2, 1), - padding='VALID', - data_format='NHWC', - name='avg_pooling') - tf.nn.max_pool(value=g1_input, - ksize=(1, 2, 2, 1), - strides=(1, 2, 2, 1), - padding='VALID', - data_format='NHWC', - name='max_pooling') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 16, 16, 3), + name='0') + tf.compat.v1.nn.avg_pool(value=g1_input, + ksize=(1, 2, 2, 1), + strides=(1, 2, 2, 1), + padding='VALID', + data_format='NHWC', + name='avg_pooling') + tf.compat.v1.nn.max_pool(value=g1_input, + ksize=(1, 2, 2, 1), + strides=(1, 2, 2, 1), + padding='VALID', + data_format='NHWC', + name='max_pooling') @tf_test def pow_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='1') tf.pow(g1_input, g2_input, name='pow1') @tf_test def relu_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.nn.relu(g1_input, 'relu') @tf_test def relu6_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') + tf.nn.relu6(g1_input, 'relu6') + + +@tf_test +def relu6_mismatch_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float16, + shape=(1, 3, 13, 37), + name='0') tf.nn.relu6(g1_input, 'relu6') @tf_test def reshape_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(16), name='0') tf.reshape(g1_input, (1, 1, 1, 16), 'reshape') @tf_test def rsqrt_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.math.rsqrt(g1_input, 'rsqrt') +@tf_test +def shape_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') + g1.create_op(op_type='Shape', inputs=[g1_input]) + + @tf_test def slice_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(5, 10), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(5, 10), + name='0') tf.slice(g1_input, [1, 0], [2, -1], name='slice1') @tf_test def softmax_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, shape=(1, 3), name='0') tf.nn.softmax(g1_input, name='softmax') @tf_test def split_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(5, 30), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(5, 30), + name='0') split0, split1, split2 = tf.split(g1_input, 3, 1, name='split') tf.concat([split0, split1], axis=1, name='concat1') tf.concat([split1, split2], axis=1, name='concat2') @@ -285,14 +535,18 @@ def split_test(g1): @tf_test def split_test_one_output(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(5, 30), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(5, 30), + name='0') tf.split(g1_input, 1, 1, name='split') @tf_test def split_test_vector_as_input(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(5, 30), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(5, 30), + name='0') split0, split1, split2 = tf.split(g1_input, [4, 15, 11], 1, name='split') @@ -303,29 +557,39 @@ def split_test_vector_as_input(g1): @tf_test def sqdiff_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='1') - tf.squared_difference(g1_input, g2_input, name='sqdiff') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='1') + tf.compat.v1.squared_difference(g1_input, g2_input, name='sqdiff') @tf_test def squeeze_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 3, 1), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 3, 1), + name='0') tf.squeeze(g1_input, name='squeeze') @tf_test def stopgradient_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.stop_gradient(g1_input, 'stopgradient') @tf_test def stridedslice_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 1, 1, 10), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 1, 1, 10), + name='0') tf.strided_slice(g1_input, [0, 0, 0, 0], [1, 1, 1, 5], [1, 1, 1, 1], shrink_axis_mask=2, name='stridedslice1') @@ -334,7 +598,9 @@ def stridedslice_test(g1): @tf_test def stridedslice_masks_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 3, 10), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 3, 10), + name='0') tf.strided_slice(g1_input, [0, 1, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1], begin_mask=9, end_mask=15, @@ -344,20 +610,96 @@ def stridedslice_masks_test(g1): @tf_test def sub_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='0') - g2_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 3), name='1') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='0') + g2_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 2, 2, 3), + name='1') tf.subtract(g1_input, g2_input, name='sub1') @tf_test def tanh_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.tanh(g1_input, 'tanh') @tf_test def transpose_test(g1): with g1.as_default(): - g1_input = tf.placeholder(tf.float32, shape=(1, 3, 16, 16), name='0') + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(1, 3, 16, 16), + name='0') tf.transpose(g1_input, perm=[0, 2, 3, 1], name='transpose') + + +@tf_test +def variable_batch_test(g1): + with g1.as_default(): + g1_input = tf.compat.v1.placeholder(tf.float32, + shape=(0, 3, 16, 16), + name='0') + tf.identity(g1_input, name='identity') + + +if __name__ == '__main__': + add_test() + addv2_test() + add_bcast_test() + argmax_test() + argmin_test() + assert_less_equal_test() + batchmatmul_test() + batchnorm_test() + batchnormv3_test() + biasadd_test() + biasadd_scalar_test() + cast_test() + concat_test() + const_test() + conv_test() + conv_add_test() + conv_batch_test() + conv_nchw_test() + conv_relu_test() + conv_relu6_test() + depthwiseconv_test() + expanddims_test() + gather_test() + identity_test() + matmul_test() + mean_test() + mean_test_nhwc() + mul_test() + multi_output_test() + noop_test() + onehot_test() + pack_test() + pack_test_nhwc() + pad_test() + pooling_test() + pow_test() + relu_test() + relu6_test() + relu6_mismatch_test() + reshape_test() + rsqrt_test() + shape_test() + slice_test() + softmax_test() + split_test() + split_test_one_output() + split_test_vector_as_input() + sqdiff_test() + squeeze_test() + stopgradient_test() + stridedslice_test() + stridedslice_masks_test() + sub_test() + tanh_test() + transpose_test() + variable_batch_test() diff --git a/test/tf/identity_test.pb b/test/tf/identity_test.pb index 5e7081b5b85ca2fa49cbef8839117cb890cafff9..2878cddf21fbb282929861b719834a322e861a4e 100644 --- a/test/tf/identity_test.pb +++ b/test/tf/identity_test.pb @@ -1,8 +1,8 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: identityIdentity0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/matmul_test.pb b/test/tf/matmul_test.pb index 10e4a499829457bf80c9b8ee23d60f686894ebab..81cbadc0615247afa2e2dc5bb2ff24ee34e8fd00 100644 --- a/test/tf/matmul_test.pb +++ b/test/tf/matmul_test.pb @@ -13,4 +13,4 @@ F matmul1MatMul01* T0* transpose_a(* - transpose_b(" \ No newline at end of file + transpose_b("¸ \ No newline at end of file diff --git a/test/tf/mean_test.pb b/test/tf/mean_test.pb index 32ede18ba260bd4da7b226a7e1f5b01803c184b7..470ad52fe1b4f51e07bff457058d13d8b7ca95d9 100644 Binary files a/test/tf/mean_test.pb and b/test/tf/mean_test.pb differ diff --git a/test/tf/mean_test_nhwc.pb b/test/tf/mean_test_nhwc.pb index 35e427f7a71e1666fcec8ed5b7ae06487fb6b8cb..9a11f8f14c43b93545414bcbf4ab3add89aa580a 100644 Binary files a/test/tf/mean_test_nhwc.pb and b/test/tf/mean_test_nhwc.pb differ diff --git a/test/tf/mul_test.pb b/test/tf/mul_test.pb index 405acb06c99a885ffe5e6abed2c51f29e7acf9d1..6fd4d845abe9100a545f1dbe6a30bd09fdd7e351 100644 --- a/test/tf/mul_test.pb +++ b/test/tf/mul_test.pb @@ -1,12 +1,12 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: : 1 Placeholder* dtype0* shape:  mul1Mul01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/multi_output_test.pb b/test/tf/multi_output_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..717a59d3c336876aff0886b95eaf085d4f48682e --- /dev/null +++ b/test/tf/multi_output_test.pb @@ -0,0 +1,11 @@ + +: +0 Placeholder* +dtype0* +shape: + +reluRelu0* +T0 + +tanhTanh0* +T0"& \ No newline at end of file diff --git a/test/tf/noop_test.pb b/test/tf/noop_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..cf4ae5659a43762b5852e194cc137b0845a4ca8b --- /dev/null +++ b/test/tf/noop_test.pb @@ -0,0 +1,3 @@ + + +noop1NoOp"¸ \ No newline at end of file diff --git a/test/tf/onehot_test.pb b/test/tf/onehot_test.pb index 7d9a63ddecf2096a87f0c55a4171dd165b3aa12c..4cc54765e1cb359631f57d9d4d49d7eeea6721be 100644 Binary files a/test/tf/onehot_test.pb and b/test/tf/onehot_test.pb differ diff --git a/test/tf/pack_test.pb b/test/tf/pack_test.pb index 70ffa9cf88df26c6995bd75d6dac38625baf5795..fcbdd40178354637cf4e1bf8de22051387123ce5 100644 --- a/test/tf/pack_test.pb +++ b/test/tf/pack_test.pb @@ -1,8 +1,8 @@ . -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: . 1 Placeholder* dtype0* @@ -13,7 +13,7 @@ shape: 4 pack1Pack012* +N* T0* -axis* -N" \ No newline at end of file +axis"¸ \ No newline at end of file diff --git a/test/tf/pack_test_nhwc.pb b/test/tf/pack_test_nhwc.pb index 4314cc93ae588bbb41b38e0719ed254423e47c6a..43d18488d3483123d570331fcfcbfb2eaea4e0a0 100644 --- a/test/tf/pack_test_nhwc.pb +++ b/test/tf/pack_test_nhwc.pb @@ -13,7 +13,7 @@ shape: 4 pack1Pack012* +N* T0* -axis* -N" \ No newline at end of file +axis"¸ \ No newline at end of file diff --git a/test/tf/pad_test.pb b/test/tf/pad_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..c2411dba66e4f3be114b43f42ded26d270d0c6dc Binary files /dev/null and b/test/tf/pad_test.pb differ diff --git a/test/tf/pooling_test.pb b/test/tf/pooling_test.pb index 912523bcf558a9f535efe47de161cda1159b516a..5d29d7bda4ef7f1e2c46af180cd510f19b66b30c 100644 --- a/test/tf/pooling_test.pb +++ b/test/tf/pooling_test.pb @@ -4,20 +4,20 @@ dtype0* shape: u - avg_poolingAvgPool0* + avg_poolingAvgPool0* +T0* + data_formatNHWC* ksize * -paddingVALID* -T0* - data_formatNHWC* +paddingVALID* strides  u - max_poolingMaxPool0* - data_formatNHWC* -strides -* + max_poolingMaxPool0* +T0* + data_formatNHWC* ksize * -paddingVALID* -T0" \ No newline at end of file +paddingVALID* +strides +"¸ \ No newline at end of file diff --git a/test/tf/pow_test.pb b/test/tf/pow_test.pb index d39ca5df89e99c6c893b599139f8bfdb648ced63..65fd955182c584b93d995e47cd0ae57fcff8cfb8 100644 --- a/test/tf/pow_test.pb +++ b/test/tf/pow_test.pb @@ -9,4 +9,4 @@ shape:  pow1Pow01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/relu6_mismatch_test.pb b/test/tf/relu6_mismatch_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..64d952519fc1aa6ee2a958aef1b7d5a9e85e2937 --- /dev/null +++ b/test/tf/relu6_mismatch_test.pb @@ -0,0 +1,8 @@ + +: +0 Placeholder* +dtype0* +shape: % + +relu6Relu60* +T0"‚ \ No newline at end of file diff --git a/test/tf/relu6_test.pb b/test/tf/relu6_test.pb index 6644205409d6502e54d2f562806ff6fa637fba91..f5651ad51cabf5b5c4d82654ccd6570e123e69a2 100644 --- a/test/tf/relu6_test.pb +++ b/test/tf/relu6_test.pb @@ -5,4 +5,4 @@ shape:  relu6Relu60* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/relu_test.pb b/test/tf/relu_test.pb index 7437e0ac1608a0f5473049f47af5b4641a60e361..c2dd6c580895a04cf73d0573dd53276b8749929a 100644 --- a/test/tf/relu_test.pb +++ b/test/tf/relu_test.pb @@ -5,4 +5,4 @@ shape:  reluRelu0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/reshape_test.pb b/test/tf/reshape_test.pb index 905ace81b669135e916b713a33e997cb6e33eb03..7a32b964fd6e0a1a1ab723a7541c7d160ada7e66 100644 Binary files a/test/tf/reshape_test.pb and b/test/tf/reshape_test.pb differ diff --git a/test/tf/rsqrt_test.pb b/test/tf/rsqrt_test.pb index 808d1d74a5676d94dfa3b9c91961ab54f2e123b4..9b876c6b6b5e217fe655d2c0292bfbd85d8f3d01 100644 --- a/test/tf/rsqrt_test.pb +++ b/test/tf/rsqrt_test.pb @@ -1,8 +1,8 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape:  rsqrtRsqrt0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/shape_test.pb b/test/tf/shape_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..ef5e1bf076b66fe4834250ac15b97fa3aa07b5df --- /dev/null +++ b/test/tf/shape_test.pb @@ -0,0 +1,9 @@ + +: +0 Placeholder* +dtype0* +shape: +* +ShapeShape0* +T0* +out_type0"¸ \ No newline at end of file diff --git a/test/tf/slice_test.pb b/test/tf/slice_test.pb index 9549d2db16f829c37ede23fbd448be8df6ce4507..a880db5d12d37e726f8b950dfdf5c4d11861c608 100644 Binary files a/test/tf/slice_test.pb and b/test/tf/slice_test.pb differ diff --git a/test/tf/softmax_test.pb b/test/tf/softmax_test.pb index 138a0c74b1d9595930fb27c45fe597b8a603f9d0..5c733891f77a35d74b41f3abf8caf3efcaf545e5 100644 --- a/test/tf/softmax_test.pb +++ b/test/tf/softmax_test.pb @@ -6,4 +6,4 @@ :  softmaxSoftmax0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/split_test.pb b/test/tf/split_test.pb index 365897ef511abe23776471b1b69c5cd47758d472..673202bb4900d612bfd91e6b115d738b4330f84b 100644 Binary files a/test/tf/split_test.pb and b/test/tf/split_test.pb differ diff --git a/test/tf/split_test_one_output.pb b/test/tf/split_test_one_output.pb index c7c8e568a813dced522e1e2b0d6274ff366d4e31..f6c759ec8216c3f38f9f3b1b62543435b517e5f0 100644 Binary files a/test/tf/split_test_one_output.pb and b/test/tf/split_test_one_output.pb differ diff --git a/test/tf/split_test_vector_as_input.pb b/test/tf/split_test_vector_as_input.pb index afe0693f989381f7b072812ef2fcf1290d63e723..87f1ea1bd551396f233f21059790de9c9dca46aa 100644 Binary files a/test/tf/split_test_vector_as_input.pb and b/test/tf/split_test_vector_as_input.pb differ diff --git a/test/tf/sqdiff_test.pb b/test/tf/sqdiff_test.pb index 8fa95843bc9e2d372d8eea59595d6e6441f7fdbc..c598143390524acf9a908b0a93d00230dbcc6164 100644 --- a/test/tf/sqdiff_test.pb +++ b/test/tf/sqdiff_test.pb @@ -9,4 +9,4 @@ shape: * sqdiffSquaredDifference01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/squeeze_test.pb b/test/tf/squeeze_test.pb index 1f8f9aa851fa58240b942a7e206e43184498b7d9..97ee60bacae257d50212536e27ae98081ef730f6 100644 Binary files a/test/tf/squeeze_test.pb and b/test/tf/squeeze_test.pb differ diff --git a/test/tf/stopgradient_test.pb b/test/tf/stopgradient_test.pb index 06f40fdbb5125984bc7041697d185cfdc1156625..fe99dcdb09e2b1ccd527ae5c5ed42734fa2e11de 100644 --- a/test/tf/stopgradient_test.pb +++ b/test/tf/stopgradient_test.pb @@ -5,4 +5,4 @@ shape: ( stopgradient StopGradient0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/stridedslice_masks_test.pb b/test/tf/stridedslice_masks_test.pb index 0769e16c5aec299047ddfb81e97e4d46e93cc92e..8141b02c00174c5d0283212839ac7f3b7c606349 100644 Binary files a/test/tf/stridedslice_masks_test.pb and b/test/tf/stridedslice_masks_test.pb differ diff --git a/test/tf/stridedslice_test.pb b/test/tf/stridedslice_test.pb index f7c8df43573ebda2bc87abfd596ba32ee00edf6e..1ecd7fece03d056ac70048f13ca097cc321c69c3 100644 Binary files a/test/tf/stridedslice_test.pb and b/test/tf/stridedslice_test.pb differ diff --git a/test/tf/sub_test.pb b/test/tf/sub_test.pb index fe4d4255fc7a8e32c0af65b49acb7dfe950acb3a..0a463e46f938f5ec3e5809631e2e14498cf3653f 100644 --- a/test/tf/sub_test.pb +++ b/test/tf/sub_test.pb @@ -1,12 +1,12 @@ : -0 Placeholder* -shape:* -dtype0 +0 Placeholder* +dtype0* +shape: : -1 Placeholder* -shape:* -dtype0 +1 Placeholder* +dtype0* +shape:  sub1Sub01* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/tanh_test.pb b/test/tf/tanh_test.pb index ccf6d647646b0c13029692dd2a0c1370992c8a93..831e6a3126ab431dd25ac53ad06ab781d2b3b553 100644 --- a/test/tf/tanh_test.pb +++ b/test/tf/tanh_test.pb @@ -5,4 +5,4 @@ shape:  tanhTanh0* -T0" \ No newline at end of file +T0"¸ \ No newline at end of file diff --git a/test/tf/tf_test.cpp b/test/tf/tf_test.cpp index f0a18be122d130c4f96387a48b142e1ef6d2c027..1fd9198ecfca5837cfd6960d3a35bb9aae18fd3c 100644 --- a/test/tf/tf_test.cpp +++ b/test/tf/tf_test.cpp @@ -1,64 +1,142 @@ #include #include +#include #include #include #include #include #include -#include #include #include #include +#include +#include +#include +#include +#include +#include + +#include + #include "test.hpp" +migraphx::program +parse_tf(const std::string& name, + bool is_nhwc, + const std::unordered_map>& dim_params = {}, + const std::vector& output_node_names = {}) +{ + return migraphx::parse_tf(name, + migraphx::tf_options{is_nhwc, 1, dim_params, output_node_names}); +} + migraphx::program optimize_tf(const std::string& name, bool is_nhwc) { - auto prog = migraphx::parse_tf(name, is_nhwc); + auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1}); + auto* mm = prog.get_main_module(); if(is_nhwc) - migraphx::run_passes(prog, + migraphx::run_passes(*mm, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}, migraphx::eliminate_identity{}}); + + // remove the last return instruction + auto last_ins = std::prev(mm->end()); + if(last_ins != mm->end()) + { + if(last_ins->name() == "@return") + { + mm->remove_instruction(last_ins); + } + } return prog; } TEST_CASE(add_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - p.add_instruction(migraphx::op::add{}, l0, l1); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); auto prog = optimize_tf("add_test.pb", false); EXPECT(p == prog); } +TEST_CASE(addv2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto prog = optimize_tf("addv2_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(add_bcast_test) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; - auto l0 = p.add_parameter("0", s0); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); - auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0); - auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); - p.add_instruction(migraphx::op::add{}, l2, l3); + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}}); + auto l2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); auto prog = optimize_tf("add_bcast_test.pb", false); EXPECT(p == prog); } +TEST_CASE(argmax_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 5, 6, 7}}); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); + auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); + mm->add_return({l1}); + auto prog = parse_tf("argmax_test.pb", false, {{"0", {4, 5, 6, 7}}}); + + EXPECT(p == prog); +} + +TEST_CASE(argmin_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}}); + auto ins = mm->add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), l0); + auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins); + mm->add_return({l1}); + auto prog = parse_tf("argmin_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(assert_less_equal_test) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape s0{migraphx::shape::float_type, {2, 3}}; - auto l0 = p.add_parameter("0", s0); - auto l1 = p.add_parameter("1", s0); + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", s0); migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}}; - auto l2 = p.add_literal(l); - p.add_instruction(migraphx::op::add{}, l0, l1); - auto l3 = p.add_instruction(migraphx::op::identity{}, l0, l1); - p.add_instruction(migraphx::op::identity{}, l3, l2); + auto l2 = mm->add_literal(l); + mm->add_instruction(migraphx::make_op("add"), l0, l1); + auto l3 = mm->add_instruction(migraphx::make_op("identity"), l0, l1); + mm->add_instruction(migraphx::make_op("identity"), l3, l2); auto prog = optimize_tf("assert_less_equal_test.pb", false); EXPECT(p == prog); @@ -67,13 +145,17 @@ TEST_CASE(assert_less_equal_test) TEST_CASE(batchmatmul_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); - auto trans_l0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); - auto trans_l1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}}); + + auto trans_l0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0); + auto trans_l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); - p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1); + mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); auto prog = optimize_tf("batchmatmul_test.pb", false); EXPECT(p == prog); @@ -85,42 +167,94 @@ TEST_CASE(batchnorm_test) float momentum = 0.9f; migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::op::batch_norm_inference op{ epsilon, momentum, migraphx::op::batch_norm_inference::spatial}; migraphx::shape s0{migraphx::shape::float_type, {32}}; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}}); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}}); std::vector const_vals(32); std::fill(const_vals.begin(), const_vals.end(), 1.0f); - auto l2 = p.add_parameter("2", s0); - auto l3 = p.add_parameter("3", s0); - auto l4 = p.add_parameter("4", s0); - auto l1 = p.add_literal(migraphx::literal{s0, const_vals}); - p.add_instruction(op, l0, l1, l2, l3, l4); + auto l2 = mm->add_parameter("2", s0); + auto l3 = mm->add_parameter("3", s0); + auto l4 = mm->add_parameter("4", s0); + auto l1 = mm->add_literal(migraphx::literal{s0, const_vals}); + mm->add_instruction(op, l0, l1, l2, l3, l4); auto prog = optimize_tf("batchnorm_test.pb", true); EXPECT(p == prog); } +TEST_CASE(batchnormv3_test) +{ + float epsilon = 1.0e-5f; + float momentum = 0.9f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::op::batch_norm_inference op{ + epsilon, momentum, migraphx::op::batch_norm_inference::spatial}; + migraphx::shape s0{migraphx::shape::float_type, {32}}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}}); + std::vector const_vals(32); + std::fill(const_vals.begin(), const_vals.end(), 1.0f); + + auto l2 = mm->add_parameter("2", s0); + auto l3 = mm->add_parameter("3", s0); + auto l4 = mm->add_parameter("4", s0); + auto l1 = mm->add_literal(migraphx::literal{s0, const_vals}); + mm->add_instruction(op, l0, l1, l2, l3, l4); + auto prog = optimize_tf("batchnormv3_test.pb", true); + + EXPECT(p == prog); +} + TEST_CASE(biasadd_test) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}}; uint64_t axis = 1; - auto l0 = p.add_parameter("0", s0); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); - auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1); - p.add_instruction(migraphx::op::add{}, l0, l2); + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); auto prog = optimize_tf("biasadd_test.pb", true); EXPECT(p == prog); } +TEST_CASE(biasadd_scalar_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {1, 1}}; + uint64_t axis = 1; + auto l0 = mm->add_parameter("0", s0); + auto l1 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}}); + auto l2 = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l0->get_shape().lens()}}), l1); + mm->add_instruction(migraphx::make_op("add"), l0, l2); + auto prog = optimize_tf("biasadd_scalar_test.pb", true); + + EXPECT(p == prog); +} + TEST_CASE(cast_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), + l0); auto prog = optimize_tf("cast_test.pb", false); EXPECT(p == prog); @@ -130,15 +264,17 @@ TEST_CASE(concat_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); int axis = 1; // tf uses axis as the third input, and it is in int32 format // add the literal using a vector in order to set stride to 1 (like in tf parser) - p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector{axis}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector{axis}); - p.add_instruction(migraphx::op::concat{axis}, l0, l1); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1); auto prog = optimize_tf("concat_test.pb", false); EXPECT(p == prog); @@ -147,30 +283,89 @@ TEST_CASE(concat_test) TEST_CASE(const_test) { migraphx::program p; - p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector{1.0f}); + + auto* mm = p.get_main_module(); + mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector{1.0f}); auto prog = optimize_tf("constant_test.pb", false); EXPECT(p == prog); } -TEST_CASE(conv_test) +migraphx::program create_conv() { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); std::vector weight_data(3 * 3 * 3 * 32); std::fill(weight_data.begin(), weight_data.end(), 1.0f); auto l1 = - p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data); + mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data); migraphx::op::convolution op; op.padding_mode = migraphx::op::padding_mode_t::same; - op.padding = {1, 1}; + op.padding = {1, 1, 1, 1}; op.stride = {1, 1}; op.dilation = {1, 1}; - auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1); - p.add_instruction(op, l0, l2); - auto prog = optimize_tf("conv_test.pb", true); + auto l2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); + mm->add_instruction(op, l0, l2); + return p; +} + +TEST_CASE(conv_test) +{ + migraphx::program p = create_conv(); + auto prog = optimize_tf("conv_test.pb", true); + + EXPECT(p == prog); +} + +TEST_CASE(conv_add_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + auto l0 = std::prev(mm->end()); + mm->add_instruction(migraphx::make_op("add"), l0, l0); + auto prog = optimize_tf("conv_add_test.pb", true); + + EXPECT(p == prog); +} + +TEST_CASE(conv_nchw_test) +{ + migraphx::program p = create_conv(); + auto prog = optimize_tf("conv_nchw_test.pb", false); + + EXPECT(p == prog); +} + +TEST_CASE(conv_relu_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + auto l0 = std::prev(mm->end()); + mm->add_instruction(migraphx::make_op("relu"), l0); + auto prog = optimize_tf("conv_relu_test.pb", true); + + EXPECT(p == prog); +} + +TEST_CASE(conv_relu6_test) +{ + migraphx::program p = create_conv(); + auto* mm = p.get_main_module(); + std::vector input_lens{1, 32, 16, 16}; + auto l0 = std::prev(mm->end()); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + auto prog = optimize_tf("conv_relu6_test.pb", true); EXPECT(p == prog); } @@ -179,11 +374,13 @@ TEST_CASE(depthwiseconv_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); std::vector weight_data(3 * 3 * 3 * 1); std::fill(weight_data.begin(), weight_data.end(), 1.0f); auto l1 = - p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data); + mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data); migraphx::op::convolution op; op.padding_mode = migraphx::op::padding_mode_t::same; @@ -191,10 +388,11 @@ TEST_CASE(depthwiseconv_test) op.stride = {1, 1}; op.dilation = {1, 1}; op.group = 3; - auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1); - auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3); - auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4); - p.add_instruction(op, l0, l5); + auto l3 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1); + auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3); + auto l5 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 3, 3}}}), l4); + mm->add_instruction(op, l0, l5); auto prog = optimize_tf("depthwise_conv_test.pb", true); EXPECT(p == prog); @@ -204,9 +402,11 @@ TEST_CASE(expanddims_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); - p.add_literal(0); - p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); + mm->add_literal(0); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4}}}), l0); auto prog = optimize_tf("expanddims_test.pb", false); EXPECT(p == prog); @@ -217,9 +417,11 @@ TEST_CASE(expanddims_test_neg_dims) // this check makes sure the pb parses negative dim value correctly migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); - p.add_literal(-1); - p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); + mm->add_literal(-1); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4, 1}}}), l0); auto prog = optimize_tf("expanddims_neg_test.pb", false); EXPECT(p == prog); @@ -229,13 +431,15 @@ TEST_CASE(gather_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); - auto l1 = - p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}}); - p.add_literal(1); + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); + auto l1 = mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}}); + mm->add_literal(1); int axis = 1; - p.add_instruction(migraphx::op::gather{axis}, l0, l1); + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1); auto prog = optimize_tf("gather_test.pb", false); EXPECT(p == prog); @@ -244,8 +448,10 @@ TEST_CASE(gather_test) TEST_CASE(identity_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::identity{}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); auto prog = optimize_tf("identity_test.pb", false); EXPECT(p == prog); @@ -254,13 +460,17 @@ TEST_CASE(identity_test) TEST_CASE(matmul_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); - auto trans_l0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); - auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}}); - p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1); + auto trans_l0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0); + auto trans_l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + + mm->add_instruction(migraphx::make_op("dot"), trans_l0, trans_l1); auto prog = optimize_tf("matmul_test.pb", false); EXPECT(p == prog); @@ -269,14 +479,16 @@ TEST_CASE(matmul_test) TEST_CASE(mean_test) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}}; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_literal(l); - p.add_literal(l); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_literal(l); + mm->add_literal(l); migraphx::op::reduce_mean op{{2, 3}}; - p.add_instruction(op, l0); - auto l3 = p.add_instruction(op, l0); - p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); + mm->add_instruction(op, l0); + auto l3 = mm->add_instruction(op, l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), l3); auto prog = optimize_tf("mean_test.pb", false); EXPECT(p == prog); @@ -285,12 +497,15 @@ TEST_CASE(mean_test) TEST_CASE(mean_test_nhwc) { migraphx::program p; + + auto* mm = p.get_main_module(); migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); migraphx::op::reduce_mean op{{1, 2}}; - auto l2 = p.add_instruction(op, l1); - p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2); + auto l2 = mm->add_instruction(op, l1); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2}}}), l2); auto prog = optimize_tf("mean_test_nhwc.pb", true); EXPECT(p == prog); @@ -299,49 +514,81 @@ TEST_CASE(mean_test_nhwc) TEST_CASE(mul_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); - p.add_instruction(migraphx::op::mul{}, l0, l1); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); + + mm->add_instruction(migraphx::make_op("mul"), l0, l1); auto prog = optimize_tf("mul_test.pb", false); EXPECT(p == prog); } +TEST_CASE(multi_output_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = mm->add_instruction(migraphx::make_op("relu"), l0); + auto l2 = mm->add_instruction(migraphx::make_op("tanh"), l0); + mm->add_return({l1, l2}); + + EXPECT(test::throws([&] { parse_tf("multi_output_test.pb", false, {}, {"relu", "relu6"}); })); + auto prog = parse_tf("multi_output_test.pb", false, {}, {"relu", "tanh"}); + + EXPECT(p == prog); +} + TEST_CASE(onehot_test) { migraphx::program p; - auto l0 = p.add_literal( + + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}}); - p.add_literal(2); - p.add_literal(1.0f); - p.add_literal(0.0f); - auto l1 = p.add_literal( + mm->add_literal(2); + mm->add_literal(1.0f); + mm->add_literal(0.0f); + auto l1 = mm->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}}); int axis = 0; - p.add_instruction(migraphx::op::gather{axis}, l1, l0); + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l1, l0); auto prog = optimize_tf("onehot_test.pb", false); EXPECT(p == prog); } +TEST_CASE(noop_test) +{ + migraphx::program p; + auto prog = optimize_tf("noop_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(pack_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}}); - auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}}); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}}); + auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}}); std::vector args{l0, l1, l2}; std::vector unsqueezed_args; int64_t axis = 1; - std::transform(args.begin(), - args.end(), - std::back_inserter(unsqueezed_args), - [&](migraphx::instruction_ref arg) { - return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg); - }); - p.add_instruction(migraphx::op::concat{static_cast(axis)}, unsqueezed_args); + std::transform( + args.begin(), + args.end(), + std::back_inserter(unsqueezed_args), + [&](migraphx::instruction_ref arg) { + return mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg); + }); + mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(axis)}}), + unsqueezed_args); auto prog = optimize_tf("pack_test.pb", false); EXPECT(p == prog); @@ -350,12 +597,17 @@ TEST_CASE(pack_test) TEST_CASE(pack_test_nhwc) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1); - auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); - auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt0 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1); + auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); + auto lt2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2); std::vector args{lt0, lt1, lt2}; std::vector unsqueezed_args; int64_t nchw_axis = 3; @@ -364,27 +616,47 @@ TEST_CASE(pack_test_nhwc) args.end(), std::back_inserter(unsqueezed_args), [&](migraphx::instruction_ref arg) { - return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg); + return mm->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {nchw_axis}}}), arg); }); - p.add_instruction(migraphx::op::concat{static_cast(nchw_axis)}, unsqueezed_args); + mm->add_instruction(migraphx::make_op("concat", {{"axis", static_cast(nchw_axis)}}), + unsqueezed_args); auto prog = optimize_tf("pack_test_nhwc.pb", true); EXPECT(p == prog); } +TEST_CASE(pad_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}}); + std::vector pad_literals{1, 1, 2, 2}; + std::vector pads{1, 2, 1, 2}; + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {2, 2}}, pad_literals); + + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), l0); + auto prog = optimize_tf("pad_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(pooling_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - migraphx::op::pooling avg_pool_op{"average"}; - migraphx::op::pooling max_pool_op{"max"}; - avg_pool_op.padding_mode = migraphx::op::padding_mode_t::valid; - max_pool_op.padding_mode = migraphx::op::padding_mode_t::valid; - avg_pool_op.stride = {2, 2}; - max_pool_op.stride = {2, 2}; - avg_pool_op.lengths = {2, 2}; - max_pool_op.lengths = {2, 2}; - p.add_instruction(max_pool_op, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average}; + migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max}; + avg_pool_op.stride = {2, 2}; + max_pool_op.stride = {2, 2}; + avg_pool_op.lengths = {2, 2}; + max_pool_op.lengths = {2, 2}; + mm->add_instruction(avg_pool_op, l0); + mm->add_instruction(max_pool_op, l0); auto prog = optimize_tf("pooling_test.pb", true); EXPECT(p == prog); @@ -393,9 +665,11 @@ TEST_CASE(pooling_test) TEST_CASE(pow_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - p.add_instruction(migraphx::op::pow{}, l0, l1); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("pow"), l0, l1); auto prog = optimize_tf("pow_test.pb", false); EXPECT(p == prog); @@ -404,8 +678,10 @@ TEST_CASE(pow_test) TEST_CASE(relu_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::relu{}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("relu"), l0); auto prog = optimize_tf("relu_test.pb", false); EXPECT(p == prog); @@ -414,21 +690,57 @@ TEST_CASE(relu_test) TEST_CASE(relu6_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); + + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 16, 16}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); auto prog = optimize_tf("relu6_test.pb", false); EXPECT(p == prog); } +TEST_CASE(relu6_mismatch_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 13, 37}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + + auto l0_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0); + + min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + min_val); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), + max_val); + + mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val); + + auto prog = optimize_tf("relu6_mismatch_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(reshape_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}}); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}}); migraphx::shape s0{migraphx::shape::int32_type, {4}}; // in tf, the second arg is a literal that contains new dimensions - p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}}); - p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0); + mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}}); + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 1, 16}}}), l0); auto prog = optimize_tf("reshape_test.pb", false); EXPECT(p == prog); @@ -437,28 +749,45 @@ TEST_CASE(reshape_test) TEST_CASE(rsqrt_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::rsqrt{}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("rsqrt"), l0); auto prog = optimize_tf("rsqrt_test.pb", false); EXPECT(p == prog); } +TEST_CASE(shape_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}}); + auto prog = optimize_tf("shape_test.pb", false); + + EXPECT(p == prog); +} + TEST_CASE(slice_test) { migraphx::program p; + + auto* mm = p.get_main_module(); std::size_t num_axes = 2; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}}); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}}); migraphx::shape s0{migraphx::shape::int32_type, {num_axes}}; - p.add_literal(migraphx::literal{s0, {1, 0}}); - p.add_literal(migraphx::literal{s0, {2, -1}}); + mm->add_literal(migraphx::literal{s0, {1, 0}}); + mm->add_literal(migraphx::literal{s0, {2, -1}}); migraphx::op::slice op; op.starts = {1, 0}; op.ends = {3, 10}; op.axes = std::vector(num_axes); std::iota(op.axes.begin(), op.axes.end(), 0); - p.add_instruction(op, l0); + mm->add_instruction(op, l0); auto prog = optimize_tf("slice_test.pb", false); EXPECT(p == prog); @@ -467,8 +796,10 @@ TEST_CASE(slice_test) TEST_CASE(softmax_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); - p.add_instruction(migraphx::op::softmax{1}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l0); auto prog = optimize_tf("softmax_test.pb", false); EXPECT(p == prog); @@ -477,19 +808,24 @@ TEST_CASE(softmax_test) TEST_CASE(split_test) { migraphx::program p; - std::vector axes{0, 1}; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); - p.add_literal(3); // num_splits - p.add_literal(1); // split axis - p.add_literal(1); // concat axis - p.add_literal(1); // concat axis - auto l1 = p.add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 10}}, l0); - auto l2 = p.add_instruction(migraphx::op::slice{axes, {0, 10}, {5, 20}}, l0); - auto l3 = p.add_instruction(migraphx::op::slice{axes, {0, 20}, {5, 30}}, l0); - p.add_instruction(migraphx::op::concat{1}, l1, l2); - p.add_instruction(migraphx::op::concat{1}, l2, l3); - auto prog = migraphx::parse_tf("split_test.pb", false); + auto* mm = p.get_main_module(); + std::vector axes{0, 1}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + mm->add_literal(3); // num_splits + mm->add_literal(1); // split axis + mm->add_literal(1); // concat axis + mm->add_literal(1); // concat axis + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 10}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 10}}, {"ends", {5, 20}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 20}}, {"ends", {5, 30}}}), l0); + auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); + auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); + mm->add_return({l4, l5}); + auto prog = parse_tf("split_test.pb", false); EXPECT(p == prog); } @@ -497,12 +833,14 @@ TEST_CASE(split_test) TEST_CASE(split_test_one_output) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); - p.add_literal(1); // num_splits - p.add_literal(1); // split axis - p.add_instruction(migraphx::op::identity{}, l0); - auto prog = migraphx::parse_tf("split_test_one_output.pb", false); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + mm->add_literal(1); // num_splits + mm->add_literal(1); // split axis + auto l1 = mm->add_instruction(migraphx::make_op("identity"), l0); + mm->add_return({l1}); + auto prog = parse_tf("split_test_one_output.pb", false); EXPECT(p == prog); } @@ -510,21 +848,26 @@ TEST_CASE(split_test_one_output) TEST_CASE(split_test_vector_as_input) { migraphx::program p; + + auto* mm = p.get_main_module(); std::vector axes{0, 1}; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}}); // split sizes - p.add_literal( + mm->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}}); - p.add_literal(1); // split axis - p.add_literal(1); // concat axis - p.add_literal(1); // concat axis - auto l1 = p.add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 4}}, l0); - auto l2 = p.add_instruction(migraphx::op::slice{axes, {0, 4}, {5, 19}}, l0); - auto l3 = p.add_instruction(migraphx::op::slice{axes, {0, 19}, {5, 30}}, l0); - p.add_instruction(migraphx::op::concat{1}, l1, l2); - p.add_instruction(migraphx::op::concat{1}, l2, l3); - - auto prog = migraphx::parse_tf("split_test_vector_as_input.pb", false); + mm->add_literal(1); // split axis + mm->add_literal(1); // concat axis + mm->add_literal(1); // concat axis + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 0}}, {"ends", {5, 4}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 4}}, {"ends", {5, 19}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", axes}, {"starts", {0, 19}}, {"ends", {5, 30}}}), l0); + auto l4 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l2); + auto l5 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l3); + mm->add_return({l4, l5}); + auto prog = parse_tf("split_test_vector_as_input.pb", false); EXPECT(p == prog); } @@ -532,9 +875,11 @@ TEST_CASE(split_test_vector_as_input) TEST_CASE(sqdiff_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - p.add_instruction(migraphx::op::sqdiff{}, l0, l1); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + mm->add_instruction(migraphx::make_op("sqdiff"), l0, l1); auto prog = optimize_tf("sqdiff_test.pb", false); EXPECT(p == prog); @@ -543,8 +888,10 @@ TEST_CASE(sqdiff_test) TEST_CASE(squeeze_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}}); - p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}}); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 3}}}), l0); auto prog = optimize_tf("squeeze_test.pb", false); EXPECT(p == prog); @@ -553,8 +900,10 @@ TEST_CASE(squeeze_test) TEST_CASE(stopgradient_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); - p.add_instruction(migraphx::op::identity{}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); auto prog = optimize_tf("stopgradient_test.pb", false); EXPECT(p == prog); @@ -563,17 +912,20 @@ TEST_CASE(stopgradient_test) TEST_CASE(stridedslice_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); - auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); std::size_t num_axes = 4; migraphx::op::slice op; op.starts = {0, 0, 0, 0}; op.ends = {1, 1, 1, 5}; op.axes = std::vector(num_axes); std::iota(op.axes.begin(), op.axes.end(), 0); - auto l2 = p.add_instruction(op, l1); + auto l2 = mm->add_instruction(op, l1); auto shrink_axis = 1; - p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); auto prog = optimize_tf("stridedslice_test.pb", true); EXPECT(p == prog); @@ -582,7 +934,9 @@ TEST_CASE(stridedslice_test) TEST_CASE(stridedslice_masks_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); std::size_t num_axes = 4; migraphx::op::slice op; op.starts = {0, 1, 1, 0}; @@ -590,14 +944,20 @@ TEST_CASE(stridedslice_masks_test) op.axes = std::vector(num_axes); std::iota(op.axes.begin(), op.axes.end(), 0); // add literals for starts, ends, and strides in tf (NHWC format) - p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector{0, 1, 1, 0}); - p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector{0, 0, 0, 0}); - p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector{1, 1, 1, 1}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{0, 1, 1, 0}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{0, 0, 0, 0}); + mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, + std::vector{1, 1, 1, 1}); - auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); - auto l2 = p.add_instruction(op, l1); - p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2); - auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto l2 = mm->add_instruction(op, l1); + auto l3 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); + mm->add_return({l3}); + auto prog = parse_tf("stridedslice_masks_test.pb", true); EXPECT(p == prog); } @@ -605,10 +965,13 @@ TEST_CASE(stridedslice_masks_test) TEST_CASE(sub_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - p.add_instruction(migraphx::op::sub{}, l0, l1); - auto prog = migraphx::parse_tf("sub_test.pb", false); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); + auto l2 = mm->add_instruction(migraphx::make_op("sub"), l0, l1); + mm->add_return({l2}); + auto prog = parse_tf("sub_test.pb", false); EXPECT(p == prog); } @@ -616,10 +979,12 @@ TEST_CASE(sub_test) TEST_CASE(tanh_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); - p.add_instruction(migraphx::op::sub{}, l0, l1); - auto prog = migraphx::parse_tf("sub_test.pb", false); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto l1 = mm->add_instruction(migraphx::make_op("tanh"), l0); + mm->add_return({l1}); + auto prog = parse_tf("tanh_test.pb", false); EXPECT(p == prog); } @@ -627,13 +992,27 @@ TEST_CASE(tanh_test) TEST_CASE(transpose_test) { migraphx::program p; - auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); migraphx::shape s0{migraphx::shape::int32_type, {4}}; - p.add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); - p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0); + mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}}); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); auto prog = optimize_tf("transpose_test.pb", false); EXPECT(p == prog); } +TEST_CASE(variable_batch_test) +{ + migraphx::program p; + + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + mm->add_instruction(migraphx::make_op("identity"), l0); + auto prog = optimize_tf("variable_batch_test.pb", false); + + EXPECT(p == prog); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/tf/transpose_test.pb b/test/tf/transpose_test.pb index 7c32c4eec7f2fb2dc49617bb7ead43d1eb193e10..fba0589ddb75e571c2500434b5dd9239344104ac 100644 Binary files a/test/tf/transpose_test.pb and b/test/tf/transpose_test.pb differ diff --git a/test/tf/variable_batch_test.pb b/test/tf/variable_batch_test.pb new file mode 100644 index 0000000000000000000000000000000000000000..592bc29a662460d83038475867f73a8fa57bf953 Binary files /dev/null and b/test/tf/variable_batch_test.pb differ diff --git a/test/validate.cpp b/test/validate.cpp index b7ca3c38b57e66709b44c8f573eefedd16751600..6d588587261b07a81d1a483f212ea2ad1b19974a 100644 --- a/test/validate.cpp +++ b/test/validate.cpp @@ -7,34 +7,34 @@ TEST_CASE(simple_test) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - p.add_instruction(sum_op{}, one, two); - EXPECT(bool{p.validate() == p.end()}); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(sum_op{}, one, two); + EXPECT(bool{mm->validate() == mm->end()}); auto result = p.eval({}); - EXPECT(result == migraphx::literal{3}); - EXPECT(result != migraphx::literal{4}); + EXPECT(result.back() == migraphx::literal{3}); + EXPECT(result.back() != migraphx::literal{4}); } TEST_CASE(out_of_order) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto ins = p.add_instruction(sum_op{}, one, two); - p.move_instruction(two, p.end()); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto ins = mm->add_instruction(sum_op{}, one, two); + mm->move_instruction(two, mm->end()); EXPECT(bool{p.validate() == ins}); } TEST_CASE(incomplete_args) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto ins = p.add_instruction(sum_op{}, one, two); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto ins = mm->add_instruction(sum_op{}, one, two); ins->clear_arguments(); EXPECT(bool{p.validate() == ins}); } @@ -47,12 +47,12 @@ MIGRAPHX_ROB(access_ins_arguments, TEST_CASE(invalid_args) { migraphx::program p; - - auto one = p.add_literal(1); - auto two = p.add_literal(2); - auto ins = p.add_instruction(sum_op{}, one, two); + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + auto ins = mm->add_instruction(sum_op{}, one, two); access_ins_arguments(*ins).clear(); - EXPECT(bool{p.validate() == p.begin()}); + EXPECT(bool{mm->validate() == mm->begin()}); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/value_test.cpp b/test/value_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..468450cd225f3d269cd7f189d859bccc0cb7701a --- /dev/null +++ b/test/value_test.cpp @@ -0,0 +1,938 @@ +#include +#include +#include +#include + +enum class enum_type +{ + a, + b, + c +}; + +TEST_CASE(value_default_construct) +{ + migraphx::value v; + EXPECT(v.is_null()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_null) +{ + migraphx::value v = nullptr; + EXPECT(v.is_null()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_assign_null) +{ + migraphx::value v; + v = nullptr; + EXPECT(v.is_null()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_int1) +{ + EXPECT(migraphx::value(1).is_int64()); + migraphx::value v(1); + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 1); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_int2) +{ + migraphx::value v = 1; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 1); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_string) +{ + migraphx::value v = "one"; + EXPECT(v.is_string()); + EXPECT(v.get_string() == "one"); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_key_string_literal_pair) +{ + // Use parens instead {} to construct to test the key-pair constructor + migraphx::value v("key", "one"); + EXPECT(v.is_string()); + EXPECT(v.get_string() == "one"); + EXPECT(v.get_key() == "key"); +} + +TEST_CASE(value_construct_float) +{ + migraphx::value v = 1.0; + EXPECT(v.is_float()); + EXPECT(migraphx::float_equal(v.get_float(), 1.0)); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_bool) +{ + migraphx::value v = true; + EXPECT(v.is_bool()); + EXPECT(v.get_bool() == true); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_enum1) +{ + migraphx::value v = enum_type::a; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == static_cast(enum_type::a)); + EXPECT(bool{v.to() == enum_type::a}); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_enum2) +{ + migraphx::value v = enum_type::b; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == static_cast(enum_type::b)); + EXPECT(bool{v.to() == enum_type::b}); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_enum3) +{ + migraphx::value v = enum_type::c; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == static_cast(enum_type::c)); + EXPECT(bool{v.to() == enum_type::c}); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_empty_object) +{ + migraphx::value v = migraphx::value::object{}; + EXPECT(v.is_object()); + EXPECT(v.get_object().empty()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_construct_empty_array) +{ + migraphx::value v = migraphx::value::array{}; + EXPECT(v.is_array()); + EXPECT(v.get_array().empty()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_assign_int) +{ + migraphx::value v; + v = 0; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 0); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_copy_construct) +{ + migraphx::value v1(1); + migraphx::value v2 = v1; // NOLINT + EXPECT(v1 == v2); +} + +TEST_CASE(value_copy_assign) +{ + migraphx::value v1(1); + migraphx::value v2; + v2 = v1; + EXPECT(v1 == v2); +} + +TEST_CASE(value_reassign) +{ + migraphx::value v1(1); + migraphx::value v2 = v1; + v1 = 2; + EXPECT(v1 != v2); +} + +TEST_CASE(value_copy_assign_key) +{ + migraphx::value v1("key", 1); + migraphx::value v2; + v2 = v1; + EXPECT(v2.get_key() == "key"); + EXPECT(v1 == v2); +} + +TEST_CASE(value_copy_assign_keyless) +{ + migraphx::value v1(1); + migraphx::value v2("key", nullptr); + v2 = v1; + EXPECT(v2.get_key() == "key"); + EXPECT(v1 != v2); + EXPECT(v1.without_key() == v2.without_key()); +} + +TEST_CASE(value_assign_key_string_literal_pair) +{ + migraphx::value v = migraphx::value::object{}; + v["key"] = "one"; + EXPECT(v["key"].is_string()); + EXPECT(v["key"].get_string() == "one"); + EXPECT(v["key"].get_key() == "key"); +} + +TEST_CASE(value_construct_array) +{ + migraphx::value v = {1, 2, 3}; + EXPECT(v.is_array()); + EXPECT(v.get_array().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front() == migraphx::value(1)); + EXPECT(v[1] == migraphx::value(2)); + EXPECT(v.at(1) == migraphx::value(2)); + EXPECT(v.back() == migraphx::value(3)); + EXPECT(test::throws([&] { v.at("???"); })); + [=] { + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front() == migraphx::value(1)); + EXPECT(v[1] == migraphx::value(2)); + EXPECT(v.at(1) == migraphx::value(2)); + EXPECT(v.back() == migraphx::value(3)); + }(); +} + +TEST_CASE(value_insert_array) +{ + migraphx::value v; + v.insert(v.end(), 1); + v.insert(v.end(), 2); + v.insert(v.end(), 3); + EXPECT(v.is_array()); + EXPECT(v.get_array().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front() == migraphx::value(1)); + EXPECT(v[1] == migraphx::value(2)); + EXPECT(v.at(1) == migraphx::value(2)); + EXPECT(v.back() == migraphx::value(3)); +} + +TEST_CASE(value_key_array) +{ + std::vector values = {1, 2, 3}; + migraphx::value v("key", values); + EXPECT(v.is_array()); + EXPECT(v.get_key() == "key"); + EXPECT(v.get_array().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front() == migraphx::value(1)); + EXPECT(v[1] == migraphx::value(2)); + EXPECT(v.at(1) == migraphx::value(2)); + EXPECT(v.back() == migraphx::value(3)); +} + +TEST_CASE(value_key_array_empty) +{ + std::vector values{}; + migraphx::value v("key", values); + EXPECT(v.is_array()); + EXPECT(v.get_key() == "key"); + EXPECT(v.get_array().size() == 0); + EXPECT(v.size() == 0); + EXPECT(v.empty()); +} + +TEST_CASE(value_construct_key_int1) +{ + migraphx::value v("one", 1); + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 1); + EXPECT(v.get_key() == "one"); +} + +TEST_CASE(value_construct_key_int2) +{ + migraphx::value v = {"one", 1}; + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 1); + EXPECT(v.get_key() == "one"); +} + +TEST_CASE(value_construct_key_pair) +{ + migraphx::value v = std::make_pair("one", 1); + EXPECT(v.is_int64()); + EXPECT(v.get_int64() == 1); + EXPECT(v.get_key() == "one"); +} + +TEST_CASE(value_construct_object) +{ + migraphx::value v = {{"one", 1}, {"two", migraphx::value(2)}, {"three", 3}}; + EXPECT(v.is_object()); + EXPECT(v.get_object().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front().get_int64() == 1); + EXPECT(v.front().get_key() == "one"); + EXPECT(v[1].is_int64()); + EXPECT(v[1].get_int64() == 2); + EXPECT(v[1].get_key() == "two"); + EXPECT(v.back().is_int64()); + EXPECT(v.back().get_int64() == 3); + EXPECT(v.back().get_key() == "three"); + + EXPECT(v.contains("one")); + EXPECT(v.contains("two")); + EXPECT(v.contains("three")); + EXPECT(not v.contains("four")); + + EXPECT(v.at("one").is_int64()); + EXPECT(v.at("one").get_int64() == 1); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_int64() == 2); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("three").is_int64()); + EXPECT(v.at("three").get_int64() == 3); + EXPECT(v.at("three").get_key() == "three"); + + EXPECT(v["one"].is_int64()); + EXPECT(v["one"].get_int64() == 1); + EXPECT(v["one"].get_key() == "one"); + EXPECT(v["two"].is_int64()); + EXPECT(v["two"].get_int64() == 2); + EXPECT(v["two"].get_key() == "two"); + EXPECT(v["three"].is_int64()); + EXPECT(v["three"].get_int64() == 3); + EXPECT(v["three"].get_key() == "three"); +} + +TEST_CASE(value_key_object) +{ + std::unordered_map values = { + {"one", 1}, {"two", migraphx::value(2)}, {"three", 3}}; + migraphx::value v("key", values); + EXPECT(v.get_key() == "key"); + EXPECT(v.is_object()); + EXPECT(v.get_object().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + + EXPECT(v.contains("one")); + EXPECT(v.contains("two")); + EXPECT(v.contains("three")); + EXPECT(not v.contains("four")); + + EXPECT(v.at("one").is_int64()); + EXPECT(v.at("one").get_int64() == 1); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_int64() == 2); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("three").is_int64()); + EXPECT(v.at("three").get_int64() == 3); + EXPECT(v.at("three").get_key() == "three"); + + EXPECT(v["one"].is_int64()); + EXPECT(v["one"].get_int64() == 1); + EXPECT(v["one"].get_key() == "one"); + EXPECT(v["two"].is_int64()); + EXPECT(v["two"].get_int64() == 2); + EXPECT(v["two"].get_key() == "two"); + EXPECT(v["three"].is_int64()); + EXPECT(v["three"].get_int64() == 3); + EXPECT(v["three"].get_key() == "three"); +} + +TEST_CASE(value_key_object_empty) +{ + std::unordered_map values{}; + migraphx::value v("key", values); + EXPECT(v.get_key() == "key"); + EXPECT(v.is_object()); + EXPECT(v.get_object().size() == 0); + EXPECT(v.size() == 0); + EXPECT(v.empty()); + EXPECT(not v.contains("one")); +} + +TEST_CASE(value_bracket_object) +{ + migraphx::value v; + v["one"] = 1; + v["two"] = migraphx::value(2); + v["three"] = 3; + + EXPECT(v.is_object()); + EXPECT(v.get_object().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front().get_int64() == 1); + EXPECT(v.front().get_key() == "one"); + EXPECT(v[1].is_int64()); + EXPECT(v[1].get_int64() == 2); + EXPECT(v[1].get_key() == "two"); + EXPECT(v.back().is_int64()); + EXPECT(v.back().get_int64() == 3); + EXPECT(v.back().get_key() == "three"); + + EXPECT(v.contains("one")); + EXPECT(v.contains("two")); + EXPECT(v.contains("three")); + EXPECT(not v.contains("four")); + + EXPECT(v.at("one").is_int64()); + EXPECT(v.at("one").get_int64() == 1); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_int64() == 2); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("three").is_int64()); + EXPECT(v.at("three").get_int64() == 3); + EXPECT(v.at("three").get_key() == "three"); +} + +TEST_CASE(value_insert_object) +{ + migraphx::value v; + v.insert({"one", 1}); + v.insert({"two", migraphx::value(2)}); + v.insert({"three", 3}); + EXPECT(v.is_object()); + EXPECT(v.get_object().size() == 3); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front().get_int64() == 1); + EXPECT(v.front().get_key() == "one"); + EXPECT(v[1].is_int64()); + EXPECT(v[1].get_int64() == 2); + EXPECT(v[1].get_key() == "two"); + EXPECT(v.back().is_int64()); + EXPECT(v.back().get_int64() == 3); + EXPECT(v.back().get_key() == "three"); + + EXPECT(v.contains("one")); + EXPECT(v.contains("two")); + EXPECT(v.contains("three")); + EXPECT(not v.contains("four")); + + EXPECT(v.at("one").is_int64()); + EXPECT(v.at("one").get_int64() == 1); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_int64() == 2); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("three").is_int64()); + EXPECT(v.at("three").get_int64() == 3); + EXPECT(v.at("three").get_key() == "three"); + + EXPECT(v["one"].is_int64()); + EXPECT(v["one"].get_int64() == 1); + EXPECT(v["one"].get_key() == "one"); + EXPECT(v["two"].is_int64()); + EXPECT(v["two"].get_int64() == 2); + EXPECT(v["two"].get_key() == "two"); + EXPECT(v["three"].is_int64()); + EXPECT(v["three"].get_int64() == 3); + EXPECT(v["three"].get_key() == "three"); +} + +TEST_CASE(value_emplace_object) +{ + migraphx::value v; + v.emplace("one", 1); + v.emplace("two", migraphx::value(2)); + v.emplace("three", 3); + EXPECT(v.is_object()); + EXPECT(v.size() == 3); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.front().is_int64()); + EXPECT(v.front().get_int64() == 1); + EXPECT(v.front().get_key() == "one"); + EXPECT(v[1].is_int64()); + EXPECT(v[1].get_int64() == 2); + EXPECT(v[1].get_key() == "two"); + EXPECT(v.back().is_int64()); + EXPECT(v.back().get_int64() == 3); + EXPECT(v.back().get_key() == "three"); + + EXPECT(v.contains("one")); + EXPECT(v.contains("two")); + EXPECT(v.contains("three")); + EXPECT(not v.contains("four")); + + EXPECT(v.at("one").is_int64()); + EXPECT(v.at("one").get_int64() == 1); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_int64() == 2); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("three").is_int64()); + EXPECT(v.at("three").get_int64() == 3); + EXPECT(v.at("three").get_key() == "three"); + + EXPECT(v["one"].is_int64()); + EXPECT(v["one"].get_int64() == 1); + EXPECT(v["one"].get_key() == "one"); + EXPECT(v["two"].is_int64()); + EXPECT(v["two"].get_int64() == 2); + EXPECT(v["two"].get_key() == "two"); + EXPECT(v["three"].is_int64()); + EXPECT(v["three"].get_int64() == 3); + EXPECT(v["three"].get_key() == "three"); +} + +TEST_CASE(value_bracket_convert_throws) +{ + migraphx::value v1; + EXPECT(test::throws([&] { v1["key"].to(); })); +} + +TEST_CASE(value_construct_object_string_value) +{ + migraphx::value v = {{"one", "onev"}, {"two", "twov"}}; + EXPECT(v.is_object()); + EXPECT(v.size() == 2); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.at("one").is_string()); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("one").get_string() == "onev"); + EXPECT(v.at("two").is_string()); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("two").get_string() == "twov"); +} + +TEST_CASE(value_construct_object_string_mixed_value) +{ + migraphx::value v = {{"one", "onev"}, {"two", 2}}; + EXPECT(v.is_object()); + EXPECT(v.size() == 2); + EXPECT(not v.empty()); + EXPECT(v.data() != nullptr); + EXPECT(v.at("one").is_string()); + EXPECT(v.at("one").get_key() == "one"); + EXPECT(v.at("one").get_string() == "onev"); + EXPECT(v.at("two").is_int64()); + EXPECT(v.at("two").get_key() == "two"); + EXPECT(v.at("two").get_int64() == 2); +} + +template +auto compare_predicate(const Expression& e) +{ + bool result = e.value(); + return test::make_predicate(test::as_string(e) + " => " + test::as_string(result), + [=] { return result; }); +} + +TEST_CASE(value_compare) +{ + EXPECT(migraphx::value(1) == migraphx::value(1)); + EXPECT(migraphx::value("key", 1) == migraphx::value("key", 1)); + EXPECT(migraphx::value(1) != migraphx::value(2)); + EXPECT(migraphx::value("key", 1) != migraphx::value("key", 2)); + EXPECT(migraphx::value("key1", 1) != migraphx::value("key2", 1)); + EXPECT(migraphx::value(1) < migraphx::value(2)); + EXPECT(migraphx::value(1) <= migraphx::value(2)); + EXPECT(migraphx::value(1) <= migraphx::value(1)); + EXPECT(migraphx::value(2) > migraphx::value(1)); + EXPECT(migraphx::value(2) >= migraphx::value(1)); + EXPECT(migraphx::value(1) >= migraphx::value(1)); + EXPECT(migraphx::value(1) != migraphx::value("1")); + EXPECT(migraphx::value(1) != migraphx::value()); +} + +// NOLINTNEXTLINE +#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__)) + +// NOLINTNEXTLINE +#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \ + EXPECT(_(x <= y) or _(x >= y)); \ + EXPECT(_(x < y) or _(x > y) or _(x == y)); \ + EXPECT((_(x < y) or _(x > y)) == _(x != y)); \ + EXPECT(_(x < y) == _(y > x)); \ + EXPECT(_(x <= y) == _(y >= x)); \ + EXPECT(_(x < y) != _(x >= y)); \ + EXPECT(_(x > y) != _(x <= y)); \ + EXPECT(_(x == y) != _(x != y)) + +// NOLINTNEXTLINE +#define EXPECT_TOTALLY_ORDERED(x, y) \ + EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \ + EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x) + +// NOLINTNEXTLINE(readability-function-size) +TEST_CASE(value_compare_ordered) +{ + EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value()); + EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1)); + EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2)); + EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1)); + EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2)); + EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2)); + EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2)); + EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2")); + EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2")); + EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1})); + EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2})); + EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1})); + EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1")); + EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value()); +} + +TEST_CASE(value_to_from_string) +{ + migraphx::value v = "1"; + EXPECT(v.to() == "1"); + EXPECT(v.to() == 1); + EXPECT(migraphx::float_equal(v.to(), 1.0)); +} + +TEST_CASE(value_to_from_int) +{ + migraphx::value v = 1; + EXPECT(v.to() == "1"); + EXPECT(v.to() == 1); + EXPECT(migraphx::float_equal(v.to(), 1.0)); +} + +TEST_CASE(value_to_from_float) +{ + migraphx::value v = 1.5; + EXPECT(v.to() == "1.5"); + EXPECT(v.to() == 1); + EXPECT(migraphx::float_equal(v.to(), 1.5)); +} + +TEST_CASE(value_to_from_pair) +{ + migraphx::value v = {"one", 1}; + EXPECT(bool{v.to>() == + std::pair("one", "1")}); + EXPECT(bool{v.to>() == std::pair("one", 1)}); + EXPECT( + bool{v.to>() == std::pair("one", 1.0)}); +} + +TEST_CASE(value_to_struct) +{ + migraphx::value v = 1; + struct local + { + int i = 0; + local() = default; + local(int ii) : i(ii) {} + }; + EXPECT(v.to().i == 1); +} + +TEST_CASE(value_to_error1) +{ + migraphx::value v = {1, 2, 3}; + EXPECT(test::throws([&] { v.to(); })); +} + +TEST_CASE(value_to_error2) +{ + migraphx::value v = 1; + struct local + { + }; + EXPECT(test::throws([&] { v.to(); })); +} + +TEST_CASE(value_to_error_parse) +{ + migraphx::value v = "abc"; + EXPECT(test::throws([&] { v.to(); })); +} + +TEST_CASE(value_to_vector) +{ + migraphx::value v = {1, 2, 3}; + std::vector a = {1, 2, 3}; + EXPECT(v.to_vector() == a); +} + +TEST_CASE(not_array) +{ + migraphx::value v = 1; + EXPECT(v.size() == 0); + EXPECT(not v.contains("???")); + EXPECT(test::throws([&] { v.at(0); })); + EXPECT(test::throws([&] { v.at("???"); })); + EXPECT(v.data() == nullptr); + [=] { + EXPECT(test::throws([&] { v.at(0); })); + EXPECT(test::throws([&] { v.at("???"); })); + EXPECT(v.data() == nullptr); + }(); +} + +TEST_CASE(print) +{ + std::stringstream ss; + migraphx::value v = {1, {{"one", 1}, {"two", 2}}, {1, 2}, {}}; + ss << v; + EXPECT(ss.str() == "{1, {one: 1, two: 2}, {1, 2}, null}"); +} + +TEST_CASE(value_clear) +{ + migraphx::value values = {1, 2, 3}; + EXPECT(values.is_array()); + EXPECT(values.size() == 3); + values.clear(); + EXPECT(values.empty()); + + values.push_back(3); + EXPECT(values.size() == 1); + EXPECT(values.at(0).to() == 3); +} + +TEST_CASE(value_clear_non_array) +{ + migraphx::value values = 1.0; + EXPECT(test::throws([&] { values.clear(); })); +} + +TEST_CASE(value_clear_object) +{ + migraphx::value values = {{"a", 1}, {"b", 2}}; + EXPECT(values.is_object()); + EXPECT(values.size() == 2); + values.clear(); + EXPECT(values.empty()); + + values["c"] = 3; + EXPECT(values.size() == 1); + EXPECT(values.at("c").to() == 3); +} + +TEST_CASE(value_clear_empty_array) +{ + migraphx::value values = migraphx::value::array{}; + EXPECT(values.empty()); + values.clear(); + EXPECT(values.empty()); +} + +TEST_CASE(value_clear_empty_object) +{ + migraphx::value values = migraphx::value::object{}; + EXPECT(values.empty()); + values.clear(); + EXPECT(values.empty()); +} + +TEST_CASE(value_resize) +{ + migraphx::value values = {1, 2, 3}; + EXPECT(values.is_array()); + EXPECT(values.size() == 3); + values.resize(5); + EXPECT(values.size() == 5); + + EXPECT(values.at(3).is_null()); + EXPECT(values.at(4).is_null()); +} + +TEST_CASE(value_resize_with_value) +{ + migraphx::value values = {1, 2, 3}; + EXPECT(values.is_array()); + EXPECT(values.size() == 3); + values.resize(5, 7); + EXPECT(values.size() == 5); + + EXPECT(values.at(3).to() == 7); + EXPECT(values.at(4).to() == 7); +} + +TEST_CASE(value_resize_empty_array) +{ + migraphx::value values = migraphx::value::array{}; + EXPECT(values.is_array()); + EXPECT(values.empty()); + values.resize(3); + EXPECT(values.size() == 3); + + EXPECT(values.at(0).is_null()); + EXPECT(values.at(1).is_null()); + EXPECT(values.at(2).is_null()); +} + +TEST_CASE(value_resize_object) +{ + migraphx::value values = migraphx::value::object{}; + EXPECT(values.is_object()); + EXPECT(test::throws([&] { values.resize(4); })); +} + +TEST_CASE(value_resize_n_object) +{ + migraphx::value values = migraphx::value::object{}; + EXPECT(values.is_object()); + EXPECT(test::throws([&] { values.resize(4, ""); })); +} + +TEST_CASE(value_assign_construct_from_vector) +{ + std::vector v = {1, 2, 3}; + migraphx::value values = v; + EXPECT(values.to_vector() == v); +} + +TEST_CASE(value_construct_from_vector) +{ + std::vector v = {1, 2, 3}; + migraphx::value values(v); + EXPECT(values.to_vector() == v); +} + +TEST_CASE(value_assign_from_vector) +{ + std::vector v = {1, 2, 3}; + migraphx::value values{}; + values = v; + EXPECT(values.to_vector() == v); +} + +TEST_CASE(value_init_from_vector) +{ + std::vector v = {1, 2, 3}; + migraphx::value values = {{"a", v}}; + EXPECT(values.at("a").to_vector() == v); +} + +TEST_CASE(value_binary_default) +{ + migraphx::value v; + v = migraphx::value::binary{}; + EXPECT(v.is_binary()); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_binary) +{ + migraphx::value v; + std::vector data(20); + std::iota(data.begin(), data.end(), 0); + v = migraphx::value::binary{data}; + EXPECT(v.is_binary()); + EXPECT(v.get_binary().size() == data.size()); + EXPECT(v.get_binary() == data); + EXPECT(v.get_key().empty()); +} + +TEST_CASE(value_binary_object) +{ + std::vector data(20); + std::iota(data.begin(), data.end(), 0); + migraphx::value v = {{"data", migraphx::value::binary{data}}}; + + EXPECT(v["data"].is_binary()); + EXPECT(v["data"].get_binary().size() == data.size()); + EXPECT(v["data"].get_binary() == data); +} + +TEST_CASE(value_binary_object_conv) +{ + std::vector data(20); + std::iota(data.begin(), data.end(), 0); + migraphx::value v = {{"data", migraphx::value::binary{data}}}; + + EXPECT(v["data"].is_binary()); + EXPECT(v["data"].get_binary().size() == data.size()); + EXPECT(migraphx::equal(v["data"].get_binary(), data)); +} + +template +bool is_null_type(T) +{ + return false; +} + +bool is_null_type(std::nullptr_t) { return true; } + +TEST_CASE(visit_null) +{ + migraphx::value v; + EXPECT(v.is_null()); + bool visited = false; + v.visit([&](auto&& x) { visited = is_null_type(x); }); + EXPECT(visited); +} + +TEST_CASE(value_or_convert) +{ + migraphx::value v = 1; + EXPECT(v.is_int64()); + EXPECT(v.value_or(3) == 1); +} + +TEST_CASE(value_or_null) +{ + migraphx::value v; + EXPECT(v.is_null()); + EXPECT(v.value_or(3) == 3); +} + +TEST_CASE(value_get_default) +{ + migraphx::value v = {{"key", 1}}; + EXPECT(v.get("key", 3) == 1); + EXPECT(v.get("missing", 3) == 3); +} + +TEST_CASE(value_get_default_vector) +{ + std::vector ints = {1, 2, 3}; + std::vector fallback = {-1}; + migraphx::value v = {{"key", ints}}; + EXPECT(v.get("key", fallback) == ints); + EXPECT(v.get("missing", fallback) == fallback); + EXPECT(v.get("missing", {-1}) == fallback); +} + +TEST_CASE(value_get_default_string_literal) +{ + migraphx::value v = {{"key", "hello"}}; + EXPECT(v.get("key", "none") == "hello"); + EXPECT(v.get("missing", "none") == "none"); +} + +TEST_CASE(value_get_default_string_literal_vector) +{ + std::vector strings = {"1", "2", "3"}; + std::vector fallback = {"none"}; + migraphx::value v = {{"key", strings}}; + EXPECT(v.get("key", fallback) == strings); + EXPECT(v.get("missing", fallback) == fallback); + EXPECT(v.get("missing", {"none"}) == fallback); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b7341fa66df1f50d5b8a2739745145b0701a19e --- /dev/null +++ b/test/verify/CMakeLists.txt @@ -0,0 +1,21 @@ + +file(GLOB VERIFY_TESTS ${CONFIGURE_DEPENDS} *.cpp) + +add_executable(test_verify ${VERIFY_TESTS}) +add_dependencies(tests test_verify) +add_dependencies(check test_verify) +target_link_libraries(test_verify migraphx migraphx_all_targets) +target_include_directories(test_verify PUBLIC ../include) +rocm_clang_tidy_check(test_verify) + +foreach(SECTION general rnn) + add_test_command(test_verify_${SECTION} test_verify ${SECTION}) + set_tests_properties(test_verify_${SECTION} PROPERTIES + COST 100 + ) + if(MIGRAPHX_ENABLE_GPU) + set_tests_properties(test_verify_${SECTION} PROPERTIES + RESOURCE_LOCK gpu + ) + endif() +endforeach() diff --git a/test/verify/auto_print.cpp b/test/verify/auto_print.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12c37479057feb2ecf6ae1ef1d48caa2c1cae957 --- /dev/null +++ b/test/verify/auto_print.cpp @@ -0,0 +1,67 @@ +#include "auto_print.hpp" +#include +#include +#include + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +using handler_map = std::map>; + +static handler_map create_handlers() +{ + handler_map m; + for(const auto& name : migraphx::get_targets()) + m[name] = [] {}; + return m; +} + +std::function& auto_print::get_handler(const std::string& name) +{ + // NOLINTNEXTLINE + static handler_map handlers = create_handlers(); + return handlers.at(name); +} + +void auto_print::set_terminate_handler(const std::string& name) +{ + // NOLINTNEXTLINE + static std::string pname; + pname = name; + std::set_terminate(+[] { + std::cout << "FAILED: " << pname << std::endl; + try + { + std::rethrow_exception(std::current_exception()); + } + catch(const std::exception& e) + { + std::cout << " what(): " << e.what() << std::endl; + } + std::cout << std::endl; + for(const auto& tname : migraphx::get_targets()) + get_handler(tname)(); + }); +} + +static bool in_exception() +{ +#if __cplusplus >= 201703L + return std::uncaught_exceptions() > 0; +#else + return std::uncaught_exception(); +#endif +} + +auto_print::~auto_print() +{ + if(in_exception()) + { + std::cout << std::endl; + for(const auto& tname : migraphx::get_targets()) + get_handler(tname)(); + } + get_handler(name) = [] {}; +} diff --git a/test/verify/auto_print.hpp b/test/verify/auto_print.hpp new file mode 100755 index 0000000000000000000000000000000000000000..ed7ff0e3d6107242c5322032fd2a54760d5e166c --- /dev/null +++ b/test/verify/auto_print.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_TEST_AUTO_PRINT_HPP +#define MIGRAPHX_GUARD_TEST_AUTO_PRINT_HPP + +#include +#include + +struct auto_print +{ + static std::function& get_handler(const std::string& name); + static void set_terminate_handler(const std::string& name); + std::string name; + template + auto_print(T& x, std::string s) : name(std::move(s)) + { + get_handler(name) = [&x] { std::cout << x << std::endl; }; + } + + ~auto_print(); +}; + +#endif diff --git a/test/verify/batch_quant_dot_1.cpp b/test/verify/batch_quant_dot_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d4ecab005aa5212d7e11976e77a144dc7ae1021 --- /dev/null +++ b/test/verify/batch_quant_dot_1.cpp @@ -0,0 +1,28 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct batch_quant_dot_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto tl1 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); + return p; + } +}; diff --git a/test/verify/batch_quant_dot_2.cpp b/test/verify/batch_quant_dot_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf37c39ed7add279b1098c719cb3304f66501980 --- /dev/null +++ b/test/verify/batch_quant_dot_2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct batch_quant_dot_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); + return p; + } +}; diff --git a/test/verify/batch_quant_dot_3.cpp b/test/verify/batch_quant_dot_3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..334a379ded57f289049fc966c1a7973f4937bc76 --- /dev/null +++ b/test/verify/batch_quant_dot_3.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct batch_quant_dot_3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2); + return p; + } +}; diff --git a/test/verify/batch_quant_dot_4.cpp b/test/verify/batch_quant_dot_4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6946d2f5f99eeb5266af59bb551c4df84bc6139 --- /dev/null +++ b/test/verify/batch_quant_dot_4.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct batch_quant_dot_4 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl1 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1); + auto tl2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2); + mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2); + return p; + } +}; diff --git a/test/verify/batch_quant_dot_5.cpp b/test/verify/batch_quant_dot_5.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbb04cb579e83bb498649d8ef5d2f8f0127aee1d --- /dev/null +++ b/test/verify/batch_quant_dot_5.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct batch_quant_dot_5 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl1 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1); + auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1); + auto tl2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); + auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); + mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2); + return p; + } +}; diff --git a/test/verify/gemm_2args_bmv.cpp b/test/verify/gemm_2args_bmv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ebbd1b713216dc3a993ab4d6b5bbe97c7f15d2e --- /dev/null +++ b/test/verify/gemm_2args_bmv.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_bmv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); + auto bul2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5, 1}}}), ul2); + + mm->add_instruction(migraphx::make_op("dot"), l1, bul2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_1.cpp b/test/verify/gemm_2args_mm_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d97ea8d506a26d00e2acc6b47545125096926d43 --- /dev/null +++ b/test/verify/gemm_2args_mm_1.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto bl2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2); + + mm->add_instruction(migraphx::make_op("dot"), l1, bl2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_2.cpp b/test/verify/gemm_2args_mm_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca7964ded8a140ad1c7c0753b4cefaa3eb529079 --- /dev/null +++ b/test/verify/gemm_2args_mm_2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto bl2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), l2); + + mm->add_instruction(migraphx::make_op("dot"), l1, bl2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_3.cpp b/test/verify/gemm_2args_mm_3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c00d4bf762f3cb08a3f5ef04d53c6ed752c1acf --- /dev/null +++ b/test/verify/gemm_2args_mm_3.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto bl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), bl1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_4.cpp b/test/verify/gemm_2args_mm_4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4d664ac885460d64871cf6098b3a4b320390cbd --- /dev/null +++ b/test/verify/gemm_2args_mm_4.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_4 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto bl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), bl1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_5.cpp b/test/verify/gemm_2args_mm_5.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e63b9c94a61b4dd6f344ada2e03f319988b88eb6 --- /dev/null +++ b/test/verify/gemm_2args_mm_5.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_5 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), bl1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_6.cpp b/test/verify/gemm_2args_mm_6.cpp new file mode 100644 index 0000000000000000000000000000000000000000..897279e9b213cf8249d0d67ab39505b80c844b1f --- /dev/null +++ b/test/verify/gemm_2args_mm_6.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_6 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + auto bl2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), l2); + + mm->add_instruction(migraphx::make_op("dot"), bl1, bl2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mm_7.cpp b/test/verify/gemm_2args_mm_7.cpp new file mode 100644 index 0000000000000000000000000000000000000000..543c98e526e2f0dc18886f404e3b4160b0b250da --- /dev/null +++ b/test/verify/gemm_2args_mm_7.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mm_7 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), bl1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_mv.cpp b/test/verify/gemm_2args_mv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa9496b53b6d0ae14369b9069e2f0db07671ff25 --- /dev/null +++ b/test/verify/gemm_2args_mv.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_mv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); + + mm->add_instruction(migraphx::make_op("dot"), l1, ul2); + + return p; + } +}; diff --git a/test/verify/gemm_2args_vbm.cpp b/test/verify/gemm_2args_vbm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca2cdd34ab8d20e99ed7161a39afb5870bf516a0 --- /dev/null +++ b/test/verify/gemm_2args_vbm.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_vbm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); + auto bul1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 5}}}), ul1); + + auto l2 = mm->add_parameter("2", m2_shape); + + auto res = mm->add_instruction(migraphx::make_op("dot"), bul1, l2); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res); + + return p; + } +}; diff --git a/test/verify/gemm_2args_vm.cpp b/test/verify/gemm_2args_vm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aeca564eb0c64d2cc6079f14b99b06a7b2b8cf27 --- /dev/null +++ b/test/verify/gemm_2args_vm.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_vm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + + auto res = mm->add_instruction(migraphx::make_op("dot"), ul1, l2); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); + + return p; + } +}; diff --git a/test/verify/gemm_2args_vv.cpp b/test/verify/gemm_2args_vv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ce5fac970d7f77ccc76eea6ba3c00f666a1e27c --- /dev/null +++ b/test/verify/gemm_2args_vv.cpp @@ -0,0 +1,27 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_2args_vv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {8}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {8}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); + auto l2 = mm->add_parameter("2", m2_shape); + auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); + float alpha = 0.23f; + auto res = migraphx::add_apply_alpha_beta(*mm, {ul1, ul2}, migraphx::make_op("dot"), alpha); + auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres); + + return p; + } +}; diff --git a/test/verify/gemm_add.cpp b/test/verify/gemm_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..875f7eb7e705a9aa8b7e172a896600de81fe9f6a --- /dev/null +++ b/test/verify/gemm_add.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include +struct gemm_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + + auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); + mm->add_instruction(migraphx::make_op("add"), dot, l3); + return p; + } +}; diff --git a/test/verify/gemm_literal.cpp b/test/verify/gemm_literal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..158ba51013e974697d51a091ef79cde71ae1870d --- /dev/null +++ b/test/verify/gemm_literal.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_literal : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::float_type, {2, 4}}; + migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}}; + + auto a = mm->add_literal(migraphx::generate_literal(a_shape)); + auto b = mm->add_parameter("b", b_shape); + mm->add_instruction(migraphx::op::dot{}, a, b); + + return p; + } +}; diff --git a/test/verify/gemm_multi_3args.cpp b/test/verify/gemm_multi_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47e9ddb049d1d98dcfa989aedc4d82a21aac7b09 --- /dev/null +++ b/test/verify/gemm_multi_3args.cpp @@ -0,0 +1,26 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_3args : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; + migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; + + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + float alpha = 0.35; + float beta = 0.41; + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta); + return p; + } +}; diff --git a/test/verify/gemm_multi_3args_alpha0.cpp b/test/verify/gemm_multi_3args_alpha0.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e71f69cb20304af2a94d6929ff7d536add3ba2cc --- /dev/null +++ b/test/verify/gemm_multi_3args_alpha0.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include +struct gemm_multi_3args_alpha0 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + + float alpha = 0.0f; + float beta = 1.0f; + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta); + return p; + } +}; diff --git a/test/verify/gemm_multi_3args_beta0.cpp b/test/verify/gemm_multi_3args_beta0.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3fb4bb73aa00fe4da29cf8efe98824caddebfbe --- /dev/null +++ b/test/verify/gemm_multi_3args_beta0.cpp @@ -0,0 +1,26 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_3args_beta0 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + + float alpha = 1.0f; + float beta = 0.0f; + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta); + return p; + } +}; diff --git a/test/verify/gemm_multi_3args_c25.cpp b/test/verify/gemm_multi_3args_c25.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a985c7dbef6c28f7eb1f2dedf993aa504dee17ef --- /dev/null +++ b/test/verify/gemm_multi_3args_c25.cpp @@ -0,0 +1,26 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_3args_c25 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}}; + migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}}; + + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + float alpha = 0.35; + float beta = 0.41; + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta); + return p; + } +}; diff --git a/test/verify/gemm_multi_dim_2.cpp b/test/verify/gemm_multi_dim_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35958c2fc92d50928af0915e7375db0bb1a87a62 --- /dev/null +++ b/test/verify/gemm_multi_dim_2.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_dim_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), l1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_multi_dim_2_3.cpp b/test/verify/gemm_multi_dim_2_3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40e2d173e1f70c4f14cf5b32ea71f29130e9c43c --- /dev/null +++ b/test/verify/gemm_multi_dim_2_3.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_dim_2_3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + + mm->add_instruction(migraphx::make_op("dot"), l1, l2); + + return p; + } +}; diff --git a/test/verify/gemm_multi_transpose.cpp b/test/verify/gemm_multi_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28e3f91bb4f91e3a564d9be060d578125c56f0e7 --- /dev/null +++ b/test/verify/gemm_multi_transpose.cpp @@ -0,0 +1,26 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct gemm_multi_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), l2); + + float alpha = 1.0f; + float beta = 1.0f; + migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("dot"), alpha, beta); + return p; + } +}; diff --git a/test/verify/main.cpp b/test/verify/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72fac4def1466e3b47dbc9424edce467cef44786 --- /dev/null +++ b/test/verify/main.cpp @@ -0,0 +1,50 @@ +#include "run_verify.hpp" +#include +#include + +#ifdef HAVE_GPU +#include +#include +#endif +#ifdef HAVE_CPU +#include +#endif + +inline void check_gpu_streams(const migraphx::program& p) +{ +#ifdef HAVE_GPU + const auto* mm = p.get_main_module(); + auto races = migraphx::gpu::analyze_streams(*mm); + for(auto&& race : races) + { + std::cout << "FAILED: " << std::endl; + std::cout << "Race condition detected for: "; + mm->debug_print(race.ins); + std::cout << "Should happen after: "; + mm->debug_print(race.before); + } +#else + (void)p; +#endif +} + +void validate_gpu(const migraphx::program& p, const migraphx::parameter_map& m) +{ + check_gpu_streams(p); + + // Ensure the program doesn't modify the context in a dry run + auto ctx = p.get_context(); + assert(&ctx != &p.get_context()); + EXPECT(is_shared(ctx, p.get_context())); + p.dry_run(m); + EXPECT(is_shared(ctx, p.get_context())); +} + +int main(int argc, const char* argv[]) +{ + run_verify rv; + rv.add_validation_for("gpu", &validate_gpu); + rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"}); + rv.disable_test_for("gpu", {"test_conv_bn_add"}); + rv.run(argc, argv); +} diff --git a/test/verify/quant_conv.cpp b/test/verify/quant_conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7ca40fe97a8e16141d5a8a2c112aeaa151505c8 --- /dev/null +++ b/test/verify/quant_conv.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); + return p; + } +}; diff --git a/test/verify/quant_conv_default_mode.cpp b/test/verify/quant_conv_default_mode.cpp new file mode 100755 index 0000000000000000000000000000000000000000..652e903c2d26d621be247dd4ea33fa8008f17534 --- /dev/null +++ b/test/verify/quant_conv_default_mode.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv_default_mode : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction( + migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, + pa, + pc); + return p; + } +}; diff --git a/test/verify/quant_conv_int8x4_default.cpp b/test/verify/quant_conv_int8x4_default.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58dec575d539f183088a758c6a9b698878fafad9 --- /dev/null +++ b/test/verify/quant_conv_int8x4_default.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv_int8x4_default : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction( + migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, + pa, + pc); + return p; + } +}; diff --git a/test/verify/quant_conv_padding.cpp b/test/verify/quant_conv_padding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..069717f3a471c8d02b24a3f6bfc56a578a866f2a --- /dev/null +++ b/test/verify/quant_conv_padding.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv_padding : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction( + migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), + pa, + pc); + return p; + } +}; diff --git a/test/verify/quant_conv_padding_stride.cpp b/test/verify/quant_conv_padding_stride.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9760708ae8dd1aa0659108a3f9a7d9be1327d76c --- /dev/null +++ b/test/verify/quant_conv_padding_stride.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv_padding_stride : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction( + migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), + pa, + pc); + + return p; + } +}; diff --git a/test/verify/quant_conv_valid_mode.cpp b/test/verify/quant_conv_valid_mode.cpp new file mode 100755 index 0000000000000000000000000000000000000000..04e2c123961129f7a085c82ac81872eb8cde1055 --- /dev/null +++ b/test/verify/quant_conv_valid_mode.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct quant_conv_valid_mode : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + auto pa = mm->add_parameter("a", a_shape); + migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + auto pc = mm->add_parameter("c", c_shape); + mm->add_instruction( + migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, + pa, + pc); + return p; + } +}; diff --git a/test/verify/quant_dot_3args_1.cpp b/test/verify/quant_dot_3args_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5acc99c23898508eb1b710c60d4e57025dbc4d47 --- /dev/null +++ b/test/verify/quant_dot_3args_1.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct quant_dot_3args_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); + return p; + } +}; diff --git a/test/verify/quant_dot_3args_2.cpp b/test/verify/quant_dot_3args_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40fab33ae7c2862a3ca4d7a9009ad34d16429a06 --- /dev/null +++ b/test/verify/quant_dot_3args_2.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct quant_dot_3args_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_parameter("b", m2_shape); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); + return p; + } +}; diff --git a/test/verify/quant_dot_3args_3.cpp b/test/verify/quant_dot_3args_3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e66a74a133ee555819e48e42566827831e1aaeea --- /dev/null +++ b/test/verify/quant_dot_3args_3.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct quant_dot_3args_3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); + return p; + } +}; diff --git a/test/verify/quant_dot_3args_4.cpp b/test/verify/quant_dot_3args_4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f60080a18637bc234cb3c37115cd748bae17eddd --- /dev/null +++ b/test/verify/quant_dot_3args_4.cpp @@ -0,0 +1,28 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct quant_dot_3args_4 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + auto l3 = mm->add_parameter("c", m3_shape); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); + return p; + } +}; diff --git a/test/verify/quant_dot_3args_5.cpp b/test/verify/quant_dot_3args_5.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c86d2970ffbfcb51414902be04dc215f039bd83 --- /dev/null +++ b/test/verify/quant_dot_3args_5.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct quant_dot_3args_5 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}}; + + auto l1 = mm->add_parameter("a", m1_shape); + auto tl1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); + auto l2 = mm->add_parameter("b", m2_shape); + auto tl2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); + return p; + } +}; diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7c14a303ad4bced7c9fff9d67c06d5b75b8cb70b --- /dev/null +++ b/test/verify/run_verify.cpp @@ -0,0 +1,222 @@ +#include "run_verify.hpp" +#include "auto_print.hpp" +#include "verify_program.hpp" +#include "test.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST) + +// An improved async, that doesn't block +template +std::future::type> detach_async(Function&& f, + bool parallel = true) +{ + if(parallel) + { + using result_type = typename std::result_of::type; + std::packaged_task task(std::forward(f)); + auto fut = task.get_future(); + std::thread(std::move(task)).detach(); + return fut; + } + return std::async(std::launch::deferred, std::forward(f)); +} + +inline void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false) +{ + auto name = t.name(); + auto shapes = p.get_output_shapes(); + std::stringstream ss; + migraphx::compile_options options; + if(show_trace) + options.trace = migraphx::tracer{std::cout}; + p.compile(t, options); + if(shapes.size() != p.get_output_shapes().size()) + { + std::cout << ss.str() << std::endl; + throw std::runtime_error("Compiling program with " + name + + " alters its number of outputs"); + } + + auto num = shapes.size(); + for(std::size_t i = 0; i < num; ++i) + { + if(p.get_output_shapes()[i].lens() != shapes[i].lens()) + { + std::cout << ss.str() << std::endl; + throw std::runtime_error("Compiling program with " + name + " alters its shape"); + } + } +} + +target_info run_verify::get_target_info(const std::string& name) const +{ + auto it = info.find(name); + if(it != info.end()) + return it->second; + else + return {}; +} + +void run_verify::validate(const migraphx::target& t, + const migraphx::program& p, + const migraphx::parameter_map& m) const +{ + auto ti = get_target_info(t.name()); + if(ti.validate) + ti.validate(p, m); +} + +std::vector run_verify::run_ref(migraphx::program p, + migraphx::parameter_map inputs) const +{ + migraphx::ref::target t{}; + auto_print pp{p, t.name()}; + compile_check(p, t); + return p.eval(std::move(inputs)); +} +std::pair> run_verify::run_target( + const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const +{ + auto_print pp{p, t.name()}; + auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); + compile_check(p, t, (trace_target == t.name())); + migraphx::parameter_map m; + for(auto&& input : inputs) + { + m[input.first] = t.copy_to(input.second); + } + for(auto&& x : p.get_parameter_shapes()) + { + if(m.count(x.first) == 0) + { + m[x.first] = t.allocate(x.second); + } + } + validate(t, p, m); + p.eval(m); + + auto tres = p.eval(m); + std::vector res(tres.size()); + std::transform( + tres.begin(), tres.end(), res.begin(), [&](auto& argu) { return t.copy_from(argu); }); + + return std::make_pair(std::move(p), res); +} + +template +auto get_hash(const T& x) +{ + return std::hash{}(x); +} + +void run_verify::verify(const std::string& name, const migraphx::program& p) const +{ + using result_future = + std::future>>; + auto_print::set_terminate_handler(name); + if(migraphx::enabled(MIGRAPHX_DUMP_TEST{})) + migraphx::save(p, name + ".mxr"); + std::vector target_names; + for(const auto& tname : migraphx::get_targets()) + { + if(tname == "ref") + continue; + + // if tests disabled, skip running it + target_info ti = get_target_info(tname); + if(migraphx::contains(ti.disabled_tests, name)) + continue; + + target_names.push_back(tname); + } + if(not target_names.empty()) + { + std::vector> results; + migraphx::parameter_map m; + for(auto&& x : p.get_parameter_shapes()) + { + m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first)); + } + + auto gold_f = detach_async([=] { return run_ref(p, m); }); + for(const auto& tname : target_names) + { + target_info ti = get_target_info(tname); + auto t = migraphx::make_target(tname); + results.emplace_back(tname, + detach_async([=] { return run_target(t, p, m); }, ti.parallel)); + } + + assert(gold_f.valid()); + auto gold = gold_f.get(); + + for(auto&& pp : results) + { + assert(pp.second.valid()); + auto tname = pp.first; + auto x = pp.second.get(); + auto cp = x.first; + auto result = x.second; + + bool passed = true; + passed &= (gold.size() == result.size()); + std::size_t num = gold.size(); + for(std::size_t i = 0; ((i < num) and passed); ++i) + { + passed &= migraphx::verify_args(tname, gold[i], result[i]); + } + + if(not passed or migraphx::enabled(MIGRAPHX_TRACE_TEST{})) + { + std::cout << p << std::endl; + std::cout << "ref:\n" << p << std::endl; + std::cout << tname << ":\n" << cp << std::endl; + std::cout << std::endl; + } + EXPECT(passed); + } + } + std::set_terminate(nullptr); +} + +void run_verify::run(int argc, const char* argv[]) const +{ + std::unordered_map> labels; + for(auto&& p : get_programs()) + { + labels[p.section].push_back(p.name); + test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); }); + } + test::driver d{}; + d.get_case_names = [&](const std::string& name) -> std::vector { + if(labels.count(name) > 0) + return labels.at(name); + return {name}; + }; + d.run(argc, argv); +} + +void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; } +void run_verify::add_validation_for(const std::string& name, target_info::validation_function v) +{ + info[name].validate = std::move(v); +} + +void run_verify::disable_test_for(const std::string& name, const std::vector& tests) +{ + auto& disabled_tests = info[name].disabled_tests; + disabled_tests.insert(disabled_tests.end(), tests.begin(), tests.end()); +} diff --git a/test/verify/run_verify.hpp b/test/verify/run_verify.hpp new file mode 100755 index 0000000000000000000000000000000000000000..37ad1560f71a940c975d14e64c5d91ba18c168ac --- /dev/null +++ b/test/verify/run_verify.hpp @@ -0,0 +1,40 @@ +#ifndef MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP +#define MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP + +#include +#include +#include + +struct target_info +{ + using validation_function = + std::function; + bool parallel = true; + validation_function validate; + std::vector disabled_tests; +}; + +struct run_verify +{ + std::vector run_ref(migraphx::program p, + migraphx::parameter_map inputs) const; + std::pair> + run_target(const migraphx::target& t, + migraphx::program p, + const migraphx::parameter_map& inputs) const; + void validate(const migraphx::target& t, + const migraphx::program& p, + const migraphx::parameter_map& m) const; + void verify(const std::string& name, const migraphx::program& p) const; + void run(int argc, const char* argv[]) const; + + target_info get_target_info(const std::string& name) const; + void disable_parallel_for(const std::string& name); + void add_validation_for(const std::string& name, target_info::validation_function v); + void disable_test_for(const std::string& name, const std::vector& tests); + + private: + std::map info{}; +}; + +#endif diff --git a/test/verify/test_abs.cpp b/test/verify/test_abs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eff867c9e42b9d2a84c25372466636f80aa2b270 --- /dev/null +++ b/test/verify/test_abs.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_abs : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("abs"), x); + return p; + } +}; diff --git a/test/verify/test_acos.cpp b/test/verify/test_acos.cpp new file mode 100755 index 0000000000000000000000000000000000000000..c6b31d3024c3cdd6de8bba461d2710e2bfa5456e --- /dev/null +++ b/test/verify/test_acos.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_acos : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("acos"), x); + return p; + } +}; diff --git a/test/verify/test_acosh.cpp b/test/verify/test_acosh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53b6219e5f0e2211f6cf91ad72f3bf273520992c --- /dev/null +++ b/test/verify/test_acosh.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_acosh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + auto min_val = mm->add_literal(1.1f); + auto max_val = mm->add_literal(100.0f); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), max_val); + auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val); + mm->add_instruction(migraphx::make_op("acosh"), cx); + return p; + } +}; diff --git a/test/verify/test_add.cpp b/test/verify/test_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..889de84b7a6c2501276e58ce0dfbb140820bd35d --- /dev/null +++ b/test/verify/test_add.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("add"), x, y); + return p; + } +}; diff --git a/test/verify/test_add_broadcast.cpp b/test/verify/test_add_broadcast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b0d935f5913e7ccb9120bf03040bede2a1c65c5 --- /dev/null +++ b/test/verify/test_add_broadcast.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", x->get_shape().lens()}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_broadcast2.cpp b/test/verify/test_add_broadcast2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..60a91e7d08f35754834627a41c0181eed098fc0f --- /dev/null +++ b/test/verify/test_add_broadcast2.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 4}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_broadcast3.cpp b/test/verify/test_add_broadcast3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7eb01e1c8fced694e7e708c09ae0b4b94200e51b --- /dev/null +++ b/test/verify/test_add_broadcast3.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 4, 5}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {4}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_broadcast4.cpp b/test/verify/test_add_broadcast4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ac9d807b265cf9e97bf24e38f4a95f6f07983be --- /dev/null +++ b/test/verify/test_add_broadcast4.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast4 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 3, 5}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {3}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_broadcast5.cpp b/test/verify/test_add_broadcast5.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ed3481f87658b6c16510db64a6684aa0b507486 --- /dev/null +++ b/test/verify/test_add_broadcast5.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast5 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 4, 8}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {4}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", x->get_shape().lens()}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_broadcast6.cpp b/test/verify/test_add_broadcast6.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f9a8c5a965543ec2e35c402927158869905c2de --- /dev/null +++ b/test/verify/test_add_broadcast6.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_add_broadcast6 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 64, 568, 1328}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {64}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 568, 1328}}}), y); + mm->add_instruction(migraphx::make_op("add"), x, by); + return p; + } +}; diff --git a/test/verify/test_add_gelu.cpp b/test/verify/test_add_gelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ccd6ab3c3d18cb37a0e890bb587356fd3f9df45a --- /dev/null +++ b/test/verify/test_add_gelu.cpp @@ -0,0 +1,33 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_gelu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 1, 5}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, input_lens}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, input_lens}); + auto half = mm->add_literal(0.5f); + auto one = mm->add_literal(1.0f); + auto sqrt2 = mm->add_literal(static_cast(M_SQRT2)); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto half_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half); + auto mul_half = mm->add_instruction(migraphx::make_op("mul"), add, half_mbcast); + auto sqrt2_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2); + auto div = mm->add_instruction(migraphx::make_op("div"), add, sqrt2_mbcast); + auto erf = mm->add_instruction(migraphx::make_op("erf"), div); + auto one_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one); + auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast); + mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one); + return p; + } +}; diff --git a/test/verify/test_add_half.cpp b/test/verify/test_add_half.cpp new file mode 100644 index 0000000000000000000000000000000000000000..324c3a41889ac152db36f60f55816836aedeaa45 --- /dev/null +++ b/test/verify/test_add_half.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("add"), x, y); + return p; + } +}; diff --git a/test/verify/test_add_relu.cpp b/test/verify/test_add_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bbc97539ba12d8c42b4998463e708ee1e0cf6d75 --- /dev/null +++ b/test/verify/test_add_relu.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("relu"), add); + return p; + } +}; diff --git a/test/verify/test_add_relu_add.cpp b/test/verify/test_add_relu_add.cpp new file mode 100755 index 0000000000000000000000000000000000000000..90dd38ebd0bab32968429702eba1c8f8f853ad3b --- /dev/null +++ b/test/verify/test_add_relu_add.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_relu_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); + auto a = mm->add_instruction(migraphx::make_op("add"), x, y); + auto b = mm->add_instruction(migraphx::make_op("relu"), a); + auto c = mm->add_instruction(migraphx::make_op("add"), b, z); + mm->add_instruction(migraphx::make_op("relu"), c); + return p; + } +}; diff --git a/test/verify/test_add_sigmoid.cpp b/test/verify/test_add_sigmoid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5e9b82ecaeb5f9c32483a764cc33cbe2f3ef0c5 --- /dev/null +++ b/test/verify/test_add_sigmoid.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_sigmoid : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("sigmoid"), add); + return p; + } +}; diff --git a/test/verify/test_add_tanh.cpp b/test/verify/test_add_tanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73df5b3ee52dc702c0dae1cf15d7a8b7da9dbd13 --- /dev/null +++ b/test/verify/test_add_tanh.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_add_tanh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("tanh"), add); + return p; + } +}; diff --git a/test/verify/test_and.cpp b/test/verify/test_and.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac49df4935270e2d4b218a2d1582ead203d22faf --- /dev/null +++ b/test/verify/test_and.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_and : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("logical_and"), x, y); + return p; + } +}; diff --git a/test/verify/test_arg_ops.cpp b/test/verify/test_arg_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f312b25fc6f241da3a2c9a32d1f6698366079c3 --- /dev/null +++ b/test/verify/test_arg_ops.cpp @@ -0,0 +1,93 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include +#include + +template +struct test_arg_ops : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}}; + auto param = mm->add_parameter("data", s); + switch(NonStdShape) + { + case 0: + param = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param); + break; + case 1: + param = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 1025}}}), param); + break; + case 2: + param = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}), param); + break; + default: break; + } + mm->add_instruction(T{Axis}, param); + return p; + } +}; +// transpose argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// transpose argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// broadcast argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// broadcast argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// slice argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// slice argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// default case, standard shape argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// default case, standard shape argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; diff --git a/test/verify/test_asin.cpp b/test/verify/test_asin.cpp new file mode 100755 index 0000000000000000000000000000000000000000..fac47e46054bb943a181340a275ec2e9a3e59ef2 --- /dev/null +++ b/test/verify/test_asin.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_asin : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("asin"), x); + return p; + } +}; diff --git a/test/verify/test_asinh.cpp b/test/verify/test_asinh.cpp new file mode 100755 index 0000000000000000000000000000000000000000..bdf2ba26ff5d89667941bb2c93bea3d23fdcbc1f --- /dev/null +++ b/test/verify/test_asinh.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_asinh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("asinh"), x); + return p; + } +}; diff --git a/test/verify/test_atan.cpp b/test/verify/test_atan.cpp new file mode 100755 index 0000000000000000000000000000000000000000..d42428d4cd0d774c22857dac69679f1ff66d7c5c --- /dev/null +++ b/test/verify/test_atan.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_atan : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("atan"), x); + return p; + } +}; diff --git a/test/verify/test_atanh.cpp b/test/verify/test_atanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92b4a23104d03112f57b188b3c64f13fd6453d97 --- /dev/null +++ b/test/verify/test_atanh.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_atanh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + auto min_val = mm->add_literal(-0.95f); + auto max_val = mm->add_literal(0.95f); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), max_val); + auto cx = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val); + mm->add_instruction(migraphx::make_op("atanh"), cx); + return p; + } +}; diff --git a/test/verify/test_avg_pooling_1d.cpp b/test/verify/test_avg_pooling_1d.cpp new file mode 100755 index 0000000000000000000000000000000000000000..56d1d5bc0ae21f3ff8dbb5eaca9eb58a3fc3e125 --- /dev/null +++ b/test/verify/test_avg_pooling_1d.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_avg_pooling_1d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {0}, {1}, {3}}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_avg_pooling_3d.cpp b/test/verify/test_avg_pooling_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3f0f37c03c2f3464148f156c9d656d780682ade --- /dev/null +++ b/test/verify/test_avg_pooling_3d.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_avg_pooling_3d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + auto op = migraphx::op::pooling{ + migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_avg_pooling_3d_opt.cpp b/test/verify/test_avg_pooling_3d_opt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf0482b1c63bb5ab47e5f10c56d311c497db23b2 --- /dev/null +++ b/test/verify/test_avg_pooling_3d_opt.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_avg_pooling_3d_opt : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}}); + auto op = migraphx::op::pooling{ + migraphx::op::pooling_mode::average, {0, 0, 0}, {1, 1, 1}, {3, 3, 3}}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_avg_pooling_ceil_3d.cpp b/test/verify/test_avg_pooling_ceil_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..46f8e453f1e87f954f67663ea1c468c23b2639fd --- /dev/null +++ b/test/verify/test_avg_pooling_ceil_3d.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_avg_pooling_ceil_3d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + auto op = migraphx::op::pooling{ + migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_batchnorm_1d.cpp b/test/verify/test_batchnorm_1d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad59e0fb22ca7c237bfa7480603e9a1b9d90771d --- /dev/null +++ b/test/verify/test_batchnorm_1d.cpp @@ -0,0 +1,29 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_batchnorm_1d : verify_program +{ + const size_t size = 3; + const size_t channels = 3; + const size_t batches = 4; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, size}}; + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_1d_per_actv.cpp b/test/verify/test_batchnorm_1d_per_actv.cpp new file mode 100755 index 0000000000000000000000000000000000000000..8d572e2a4106230c89e53c5340ae82d0e62913a6 --- /dev/null +++ b/test/verify/test_batchnorm_1d_per_actv.cpp @@ -0,0 +1,43 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_batchnorm_1d_per_actv : verify_program +{ + const size_t d1 = 5; + const size_t channels = 2; + const size_t batches = 3; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1}}; + migraphx::shape vars{migraphx::shape::float_type, {channels, d1}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op( + "batch_norm_inference", + {{"epsilon", 1.0e-5}, + {"momentum", 0.96f}, + {"bn_mode", + migraphx::to_value(migraphx::op::batch_norm_inference::per_activation)}}), + x, + scale, + bias, + mean, + variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_2d_per_actv.cpp b/test/verify/test_batchnorm_2d_per_actv.cpp new file mode 100755 index 0000000000000000000000000000000000000000..4820854851571acb075c9a413b2160745bdd33f3 --- /dev/null +++ b/test/verify/test_batchnorm_2d_per_actv.cpp @@ -0,0 +1,44 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_batchnorm_2d_per_actv : verify_program +{ + const size_t d1 = 2; + const size_t d2 = 4; + const size_t channels = 2; + const size_t batches = 3; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2}}; + migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op( + "batch_norm_inference", + {{"epsilon", 1.0e-6}, + {"momentum", 0.9f}, + {"bn_mode", + migraphx::to_value(migraphx::op::batch_norm_inference::per_activation)}}), + x, + scale, + bias, + mean, + variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_3d.cpp b/test/verify/test_batchnorm_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6677c4959c86cabd3e67d3b58b1b96347b195f92 --- /dev/null +++ b/test/verify/test_batchnorm_3d.cpp @@ -0,0 +1,31 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_batchnorm_3d : verify_program +{ + const size_t d1 = 2; + const size_t d2 = 2; + const size_t d3 = 2; + const size_t channels = 2; + const size_t batches = 2; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}}; + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_3d_per_actv.cpp b/test/verify/test_batchnorm_3d_per_actv.cpp new file mode 100755 index 0000000000000000000000000000000000000000..e675feff0b87ff70ca1c7926b595acb709cf3ad0 --- /dev/null +++ b/test/verify/test_batchnorm_3d_per_actv.cpp @@ -0,0 +1,45 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_batchnorm_3d_per_actv : verify_program +{ + const size_t d1 = 2; + const size_t d2 = 4; + const size_t d3 = 5; + const size_t channels = 2; + const size_t batches = 3; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}}; + migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2, d3}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op( + "batch_norm_inference", + {{"epsilon", 1.0e-6}, + {"momentum", 0.8f}, + {"bn_mode", + migraphx::to_value(migraphx::op::batch_norm_inference::per_activation)}}), + x, + scale, + bias, + mean, + variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_inference.cpp b/test/verify/test_batchnorm_inference.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cce9d9679851676ee5db5350b2c106e21e9f494a --- /dev/null +++ b/test/verify/test_batchnorm_inference.cpp @@ -0,0 +1,30 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_batchnorm_inference : verify_program +{ + const size_t width = 3; + const size_t height = 3; + const size_t channels = 3; + const size_t batches = 4; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + return p; + } +}; diff --git a/test/verify/test_batchnorm_inference_2.cpp b/test/verify/test_batchnorm_inference_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7124420da704e5c44ce14ad54b498e9780bd2365 --- /dev/null +++ b/test/verify/test_batchnorm_inference_2.cpp @@ -0,0 +1,30 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_batchnorm_inference_2 : verify_program +{ + const size_t width = 14; + const size_t height = 14; + const size_t channels = 256; + const size_t batches = 1; + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto x = mm->add_parameter("x", s); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + return p; + } +}; diff --git a/test/verify/test_ceil.cpp b/test/verify/test_ceil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..892d69524bbd66292bb5e01b7893edb5c942e8be --- /dev/null +++ b/test/verify/test_ceil.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ceil : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("ceil"), param); + return p; + }; +}; diff --git a/test/verify/test_clip.cpp b/test/verify/test_clip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8854a0a68e45517775ec1b5583cbcd9b4fed60c5 --- /dev/null +++ b/test/verify/test_clip.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_clip : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val); + max_val = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val); + mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val); + return p; + } +}; diff --git a/test/verify/test_concat_axis_0.cpp b/test/verify/test_concat_axis_0.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f70e6f38513d4dfdb3144362865a85bc9c02c1e --- /dev/null +++ b/test/verify/test_concat_axis_0.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_axis_0 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 0; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; + auto l0 = mm->add_parameter("x", s0); + auto l1 = mm->add_parameter("y", s1); + auto l2 = mm->add_parameter("z", s2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_concat_axis_1.cpp b/test/verify/test_concat_axis_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c66220920172cd7f2ca3318792ef478dee83206f --- /dev/null +++ b/test/verify/test_concat_axis_1.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_axis_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 1; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; + auto l0 = mm->add_parameter("x", s0); + auto l1 = mm->add_parameter("y", s1); + auto l2 = mm->add_parameter("z", s2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_concat_axis_neg_1.cpp b/test/verify/test_concat_axis_neg_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f515687164d9caa2a78444dde482bf74de11c16 --- /dev/null +++ b/test/verify/test_concat_axis_neg_1.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_axis_neg_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = -1; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; + auto l0 = mm->add_parameter("x", s0); + auto l1 = mm->add_parameter("y", s1); + auto l2 = mm->add_parameter("z", s2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_concat_pooling.cpp b/test/verify/test_concat_pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f455836d718993165c3f4c5aefadebecaacfabcf --- /dev/null +++ b/test/verify/test_concat_pooling.cpp @@ -0,0 +1,32 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_concat_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}}); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), input); + auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 3}}), transpose); + auto concat_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), concat); + + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {0, 0}}, + {"stride", {1, 1}}, + {"lengths", {8, 8}}}), + concat_t); + mm->add_instruction(migraphx::make_op("relu"), pooling); + return p; + } +}; diff --git a/test/verify/test_concat_relu.cpp b/test/verify/test_concat_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..797ed55f179ca7b9062523cc90bcb5b27947b4fa --- /dev/null +++ b/test/verify/test_concat_relu.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 0; + migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 2}}; + auto l0 = mm->add_parameter("x", s0); + auto l1 = mm->add_parameter("y", s1); + auto l2 = mm->add_parameter("z", s2); + auto r0 = mm->add_instruction(migraphx::make_op("relu"), l0); + auto r1 = mm->add_instruction(migraphx::make_op("relu"), l1); + auto r2 = mm->add_instruction(migraphx::make_op("relu"), l2); + auto c0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), r0, r1, r2); + mm->add_instruction(migraphx::make_op("relu"), c0); + return p; + } +}; diff --git a/test/verify/test_concat_transpose.cpp b/test/verify/test_concat_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06bec2d1f5cd6852d0dbc9d105afbe0d0d07adc7 --- /dev/null +++ b/test/verify/test_concat_transpose.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 1; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; + auto l0 = mm->add_parameter("x", s0); + auto lp1 = mm->add_parameter("y", s1); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp1); + auto l2 = mm->add_parameter("z", s2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_concat_transpose2.cpp b/test/verify/test_concat_transpose2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd590271978010df487ce4f1e88c682b731a5703 --- /dev/null +++ b/test/verify/test_concat_transpose2.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_transpose2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 1; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; + auto l0 = mm->add_parameter("x", s0); + auto l1 = mm->add_parameter("y", s1); + auto lp2 = mm->add_parameter("z", s2); + auto l2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_concat_transpose3.cpp b/test/verify/test_concat_transpose3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa0523f6ab2e4c0f08f7979d579f1ce2bc3c372d --- /dev/null +++ b/test/verify/test_concat_transpose3.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_concat_transpose3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + int axis = 1; + migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; + migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; + migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; + auto l0 = mm->add_parameter("x", s0); + auto lp1 = mm->add_parameter("y", s1); + auto l1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp1); + auto lp2 = mm->add_parameter("z", s2); + auto l2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), lp2); + mm->add_instruction(migraphx::make_op("concat", {{"axis", axis}}), l0, l1, l2); + return p; + } +}; diff --git a/test/verify/test_contiguous.cpp b/test/verify/test_contiguous.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b1dc740c59c04c15066a5ff72a8bcadb1acccde --- /dev/null +++ b/test/verify/test_contiguous.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_contiguous : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("contiguous"), x); + assert(p.get_output_shapes().back().standard()); + return p; + } +}; diff --git a/test/verify/test_contiguous_broadcast.cpp b/test/verify/test_contiguous_broadcast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da4bddcaff9f45ec4ac89a84069d6c657c7d68e8 --- /dev/null +++ b/test/verify/test_contiguous_broadcast.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_contiguous_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("contiguous"), x); + assert(p.get_output_shapes().back().standard()); + return p; + } +}; diff --git a/test/verify/test_contiguous_broadcast_transpose.cpp b/test/verify/test_contiguous_broadcast_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8a076b441aabecddf2e54c31b6d348d76ffdde3 --- /dev/null +++ b/test/verify/test_contiguous_broadcast_transpose.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_contiguous_broadcast_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("contiguous"), x); + assert(p.get_output_shapes().back().standard()); + return p; + } +}; diff --git a/test/verify/test_conv.cpp b/test/verify/test_conv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a6e597b50c7210f00281d87d6b7a8675d97cd7ef --- /dev/null +++ b/test/verify/test_conv.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("convolution"), input, weights); + return p; + } +}; diff --git a/test/verify/test_conv2.cpp b/test/verify/test_conv2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41ca9da24f28052d59c261d7b9ff817cd3d3b0cf --- /dev/null +++ b/test/verify/test_conv2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}); + mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + input, + weights); + return p; + } +}; diff --git a/test/verify/test_conv3d.cpp b/test/verify/test_conv3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..768e42ec4be717720e1e68bc9429b47c301414e7 --- /dev/null +++ b/test/verify/test_conv3d.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv3d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3, 3}}); + mm->add_instruction( + migraphx::make_op( + "convolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + input, + weights); + return p; + } +}; diff --git a/test/verify/test_conv_add.cpp b/test/verify/test_conv_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..755e840eda5778e9c0abd8b9f190265786c0b51e --- /dev/null +++ b/test/verify/test_conv_add.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); + auto w = mm->add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); + auto v = mm->add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2)); + auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); + auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); + mm->add_instruction(migraphx::make_op("exp"), sum); + return p; + } +}; diff --git a/test/verify/test_conv_add_1x1_diff_strides.cpp b/test/verify/test_conv_add_1x1_diff_strides.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75f34ffbc3453406d43833a91812e08096864643 --- /dev/null +++ b/test/verify/test_conv_add_1x1_diff_strides.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_add_1x1_diff_strides : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); + auto w = mm->add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); + auto v = mm->add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2)); + auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto conv2 = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); + auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); + mm->add_instruction(migraphx::make_op("exp"), sum); + return p; + } +}; diff --git a/test/verify/test_conv_bias_clipped_relu.cpp b/test/verify/test_conv_bias_clipped_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56f6d3e491dff76c8ed9fe6cb061db9e43f2fe57 --- /dev/null +++ b/test/verify/test_conv_bias_clipped_relu.cpp @@ -0,0 +1,36 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_conv_bias_clipped_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + auto bias = mm->add_literal(l0); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + auto bcast_add = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + bias); + auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val); + max_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), max_val); + mm->add_instruction(migraphx::make_op("clip"), bias_add, min_val, max_val); + return p; + } +}; diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c702269dd032f91a7a21952d58ed3b95ce13a27 --- /dev/null +++ b/test/verify/test_conv_bn.cpp @@ -0,0 +1,32 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_bn : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; + migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; + migraphx::shape vars{migraphx::shape::float_type, {64}}; + auto x = mm->add_parameter("x", xs); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); + return p; + } +}; diff --git a/test/verify/test_conv_bn_add.cpp b/test/verify/test_conv_bn_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd9239b8da0bb6ac0fdf79a86a7e99407b74fc31 --- /dev/null +++ b/test/verify/test_conv_bn_add.cpp @@ -0,0 +1,45 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_bn_add : verify_program +{ + static migraphx::instruction_ref add_bn(migraphx::module& m, + migraphx::instruction_ref x, + std::size_t channels, + std::size_t seed = 1) + { + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed))); + auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed))); + auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed))); + auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); + return m.add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + } + + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::size_t ichannels = 64; + std::size_t ochannels = 256; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); + auto w = mm->add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); + auto v = mm->add_literal(migraphx::generate_literal( + {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); + auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); + auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w); + auto bn1 = add_bn(*mm, conv1, ochannels, 1); + auto relu2 = mm->add_instruction(migraphx::make_op("relu"), y); + auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), relu2, v); + auto bn2 = add_bn(*mm, conv2, ochannels, 1); + auto sum = mm->add_instruction(migraphx::make_op("add"), bn1, bn2); + mm->add_instruction(migraphx::make_op("relu"), sum); + return p; + } +}; diff --git a/test/verify/test_conv_bn_relu_pooling.cpp b/test/verify/test_conv_bn_relu_pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..504ef63d69a2931042a48c034a993e1f117b4054 --- /dev/null +++ b/test/verify/test_conv_bn_relu_pooling.cpp @@ -0,0 +1,40 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_conv_bn_relu_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; + migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; + migraphx::shape vars{migraphx::shape::float_type, {64}}; + auto x = mm->add_parameter("x", xs); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); + auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); + auto bn = mm->add_instruction( + migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); + auto relu = mm->add_instruction(migraphx::make_op("relu"), bn); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1}}, + {"stride", {2, 2}}, + {"lengths", {3, 3}}}), + relu); + return p; + } +}; diff --git a/test/verify/test_conv_bn_relu_pooling2.cpp b/test/verify/test_conv_bn_relu_pooling2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..233d740e79c7042f8c8725370770ccd1b36ca6b5 --- /dev/null +++ b/test/verify/test_conv_bn_relu_pooling2.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_conv_bn_relu_pooling2 : verify_program +{ + static migraphx::instruction_ref + add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels) + { + auto* mm = p.get_main_module(); + migraphx::shape vars{migraphx::shape::float_type, {channels}}; + auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + channels))); + auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + channels))); + auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + channels))); + auto variance = + mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + channels))); + return mm->add_instruction( + migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance); + } + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; + migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; + migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; + migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; + auto x1 = mm->add_parameter("x1", xs1); + auto w1 = mm->add_parameter("w1", ws1); + auto conv1 = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), + x1, + w1); + auto bn1 = add_bn(p, conv1, 2048); + auto x2 = mm->add_parameter("x2", xs2); + auto w2 = mm->add_parameter("w2", ws2); + auto conv2 = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x2, + w2); + auto bn2 = add_bn(p, conv2, 2048); + auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2); + auto relu = mm->add_instruction(migraphx::make_op("relu"), add); + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1}}, + {"stride", {2, 2}}, + {"lengths", {3, 3}}}), + relu); + return p; + } +}; diff --git a/test/verify/test_conv_pooling.cpp b/test/verify/test_conv_pooling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a231d7f435234871d0a8b8c277a07c2f02a4779 --- /dev/null +++ b/test/verify/test_conv_pooling.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_conv_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + auto pooling = mm->add_instruction( + migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); + mm->add_instruction(migraphx::make_op("relu"), pooling); + return p; + } +}; diff --git a/test/verify/test_conv_relu.cpp b/test/verify/test_conv_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d66511bfac17ac5da053f488b12326fa5e7b03c6 --- /dev/null +++ b/test/verify/test_conv_relu.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_instruction(migraphx::make_op("relu"), conv); + return p; + } +}; diff --git a/test/verify/test_conv_relu_half.cpp b/test/verify/test_conv_relu_half.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f14d2faf9714dcf445a04fa4886b5c7509f5d0e9 --- /dev/null +++ b/test/verify/test_conv_relu_half.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_conv_relu_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_instruction(migraphx::make_op("relu"), conv); + return p; + } +}; diff --git a/test/verify/test_convert.cpp b/test/verify/test_convert.cpp new file mode 100755 index 0000000000000000000000000000000000000000..c211faa9b145b4078f8bab0f2f2e2d1e5388ce2b --- /dev/null +++ b/test/verify/test_convert.cpp @@ -0,0 +1,31 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_convert : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::int8_type, {8, 24}}; + migraphx::shape sb{migraphx::shape::int8_type, {24, 6}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto ia = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + pa); + auto ib = mm->add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), + pb); + mm->add_instruction(migraphx::make_op("dot"), ia, ib); + + return p; + }; +}; diff --git a/test/verify/test_cos.cpp b/test/verify/test_cos.cpp new file mode 100755 index 0000000000000000000000000000000000000000..01731f58f0773ef688fb30fa9c8b32095b4a0554 --- /dev/null +++ b/test/verify/test_cos.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_cos : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {8}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("cos"), x); + return p; + } +}; diff --git a/test/verify/test_cosh.cpp b/test/verify/test_cosh.cpp new file mode 100755 index 0000000000000000000000000000000000000000..dac24999439522dc86956f58f692fae80c727838 --- /dev/null +++ b/test/verify/test_cosh.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_cosh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("cosh"), x); + return p; + } +}; diff --git a/test/verify/test_deconv.cpp b/test/verify/test_deconv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f00ed1bebabfb9c7e6ca667e9ba937d9f6d82e2 --- /dev/null +++ b/test/verify/test_deconv.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_deconv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); + mm->add_instruction(migraphx::make_op("deconvolution"), input, weights); + return p; + } +}; diff --git a/test/verify/test_deconv_1d.cpp b/test/verify/test_deconv_1d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b009b551e553e6d809f2052660ec3e4348b036ad --- /dev/null +++ b/test/verify/test_deconv_1d.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_deconv_1d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}}); + mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + input, + weights); + return p; + } +}; diff --git a/test/verify/test_deconv_2x3.cpp b/test/verify/test_deconv_2x3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c0929ddb806248e5bdac1b7c409d21b134f3aafd --- /dev/null +++ b/test/verify/test_deconv_2x3.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_deconv_2x3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 6, 7}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 4, 3, 3}}); + mm->add_instruction( + migraphx::make_op("deconvolution", + {{"padding", {1, 1}}, {"stride", {2, 3}}, {"dilation", {1, 1}}}), + input, + weights); + return p; + } +}; diff --git a/test/verify/test_deconv_3d.cpp b/test/verify/test_deconv_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8651dac93d2f8ff112f82488209c5a07685e4838 --- /dev/null +++ b/test/verify/test_deconv_3d.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_deconv_3d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3, 3}}); + mm->add_instruction( + migraphx::make_op( + "deconvolution", + {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), + input, + weights); + return p; + } +}; diff --git a/test/verify/test_dequantizelinear.cpp b/test/verify/test_dequantizelinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88619d60ac1984c85a24a8a9065b654de9c02810 --- /dev/null +++ b/test/verify/test_dequantizelinear.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_dequantizelinear : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sx{migraphx::shape::int8_type, {2, 2, 2}}; + migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}}; + migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}}; + auto input1 = mm->add_parameter("x", sx); + auto input2 = mm->add_parameter("x_scale", ss); + auto input3 = mm->add_parameter("x_zero_point", sz); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), input1, input2, input3); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_div.cpp b/test/verify/test_div.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d8c03869801b33d5a9e7b53d5b8b907a2b679520 --- /dev/null +++ b/test/verify/test_div.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_div : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto diff = mm->add_instruction(migraphx::make_op("div"), x, y); + mm->add_instruction(migraphx::make_op("div"), diff, z); + return p; + } +}; diff --git a/test/verify/test_div2.cpp b/test/verify/test_div2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b9d42de7a51705438ca576d89f4c08db02251fe --- /dev/null +++ b/test/verify/test_div2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_div2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape b{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", b); + auto zb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z); + auto diff = mm->add_instruction(migraphx::make_op("div"), x, y); + mm->add_instruction(migraphx::make_op("div"), diff, zb); + return p; + } +}; diff --git a/test/verify/test_elu.cpp b/test/verify/test_elu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a316da8e3a4d61978ed095df7be25649c1232e2 --- /dev/null +++ b/test/verify/test_elu.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_elu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 1.0}}), x); + return p; + } +}; diff --git a/test/verify/test_equal.cpp b/test/verify/test_equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a4582de9bbb8c274e62da96474c0ba4685cfd8c --- /dev/null +++ b/test/verify/test_equal.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_equal : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto input1 = mm->add_parameter("x", s); + auto input2 = mm->add_parameter("y", s); + auto r = mm->add_instruction(migraphx::make_op("equal"), input1, input2); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_equal_brcst.cpp b/test/verify/test_equal_brcst.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f726973ba38d7fe14ac2a2908ccabde49c011be9 --- /dev/null +++ b/test/verify/test_equal_brcst.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_equal_brcst : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = mm->add_parameter("x", s0); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_parameter("y", s1); + auto bl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); + auto r = mm->add_instruction(migraphx::make_op("equal"), l0, bl1); + mm->add_return({r}); + + return p; + }; +}; diff --git a/test/verify/test_erf.cpp b/test/verify/test_erf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..63602dadf1ab24b6b5447a8579bce632d3caab40 --- /dev/null +++ b/test/verify/test_erf.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_erf : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("erf"), param); + return p; + } +}; diff --git a/test/verify/test_exp.cpp b/test/verify/test_exp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75e875956a05d968e8e185a82e02af6595906b94 --- /dev/null +++ b/test/verify/test_exp.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_exp : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); + mm->add_instruction(migraphx::make_op("exp"), x); + return p; + } +}; diff --git a/test/verify/test_floor.cpp b/test/verify/test_floor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d6cd9df379779f9e9d4bf84c2200c8e132805c64 --- /dev/null +++ b/test/verify/test_floor.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_floor : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("floor"), param); + return p; + }; +}; diff --git a/test/verify/test_fp32_fp16_add.cpp b/test/verify/test_fp32_fp16_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..26340c04e1ec0afd19a44f721f71e67ac7783838 --- /dev/null +++ b/test/verify/test_fp32_fp16_add.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_fp32_fp16_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); + mm->add_instruction(migraphx::make_op("add"), diff, p1); + migraphx::quantize_fp16(p, {"add"}); + + return p; + }; +}; diff --git a/test/verify/test_fp32_fp16_ladd.cpp b/test/verify/test_fp32_fp16_ladd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0ebc9583480af6887bb28f205595f4120ba1930 --- /dev/null +++ b/test/verify/test_fp32_fp16_ladd.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_fp32_fp16_ladd : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data(2 * 3); + std::iota(data.begin(), data.end(), 1.0f); + auto l1 = mm->add_literal(migraphx::literal(s, data)); + auto l2 = mm->add_parameter("p2", s); + mm->add_instruction(migraphx::make_op("add"), l1, l2); + migraphx::quantize_fp16(p, {"add"}); + return p; + }; +}; diff --git a/test/verify/test_fp32_fp16_lall.cpp b/test/verify/test_fp32_fp16_lall.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d44aa1952ee8688064b3d2ea391ddf1f61768c80 --- /dev/null +++ b/test/verify/test_fp32_fp16_lall.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_fp32_fp16_lall : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + std::vector data(2 * 3); + std::iota(data.begin(), data.end(), 1.0f); + auto l1 = mm->add_literal(migraphx::literal(s, data)); + auto l2 = mm->add_parameter("p2", s); + mm->add_instruction(migraphx::make_op("add"), l1, l2); + migraphx::quantize_fp16(p, {"all"}); + return p; + }; +}; diff --git a/test/verify/test_fp32_fp16_sub.cpp b/test/verify/test_fp32_fp16_sub.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35fd2b63ab85b6a345f3a4dff4b34f71be4bc135 --- /dev/null +++ b/test/verify/test_fp32_fp16_sub.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_fp32_fp16_sub : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + auto p1 = mm->add_parameter("x", s); + auto p2 = mm->add_parameter("y", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), p1, p2); + auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); + mm->add_instruction(migraphx::make_op("add"), diff, p1); + migraphx::quantize_fp16(p, {"sub"}); + + return p; + }; +}; diff --git a/test/verify/test_gather.cpp b/test/verify/test_gather.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b53194d1a3e81c53aeadfc1c394d084277cf1a9 --- /dev/null +++ b/test/verify/test_gather.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; + std::vector indices{1, 2, 2, 1}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = 0; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gather_1d_index.cpp b/test/verify/test_gather_1d_index.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73f69c9a1b6908ae856915b552480bd82d56ba59 --- /dev/null +++ b/test/verify/test_gather_1d_index.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_1d_index : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type, {1}}; + std::vector indices{1}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gather_neg_axis.cpp b/test/verify/test_gather_neg_axis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..691a60bc79caeca80cfee660b42c406360114b66 --- /dev/null +++ b/test/verify/test_gather_neg_axis.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_neg_axis : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; + std::vector indices{1, 2, 2, 1}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gather_neg_indices.cpp b/test/verify/test_gather_neg_indices.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8edf381695d3b998998ce71ad521c9e877b3ebda --- /dev/null +++ b/test/verify/test_gather_neg_indices.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_neg_indices : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; + std::vector indices{-2, -1, -1, -2}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gather_scalar_index.cpp b/test/verify/test_gather_scalar_index.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61aa03862649ff5c5192d3c7e15229cd7c856680 --- /dev/null +++ b/test/verify/test_gather_scalar_index.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_scalar_index : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type}; + std::vector indices{1}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = -1; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gather_scalar_output.cpp b/test/verify/test_gather_scalar_output.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c88794ea3c7363656b6ba1fc091b89bbe11ed2d --- /dev/null +++ b/test/verify/test_gather_scalar_output.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_scalar_output : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + migraphx::shape s_indices{migraphx::shape::int32_type}; + std::vector indices{1}; + auto a0 = mm->add_parameter("data", s); + auto a1 = mm->add_literal(migraphx::literal{s_indices, indices}); + int axis = 0; + mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_batch_dims_1.cpp b/test/verify/test_gathernd_batch_dims_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c902f80c82fb89fcfff44e5b7052744ff08e5c56 --- /dev/null +++ b/test/verify/test_gathernd_batch_dims_1.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_batch_dims_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + std::vector indices{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_batch_dims_2.cpp b/test/verify/test_gathernd_batch_dims_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94b914293d79d916b5d32d999fe34724567fd491 --- /dev/null +++ b/test/verify/test_gathernd_batch_dims_2.cpp @@ -0,0 +1,21 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_batch_dims_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}}; + std::vector indices{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 2; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_default.cpp b/test/verify/test_gathernd_default.cpp new file mode 100644 index 0000000000000000000000000000000000000000..020210e7c869151b0870a171d2d64a2ff08a377d --- /dev/null +++ b/test/verify/test_gathernd_default.cpp @@ -0,0 +1,20 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_default : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; + std::vector indices{0, 0, 1, 1}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + mm->add_instruction(migraphx::make_op("gathernd"), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gathernd_negative_indices.cpp b/test/verify/test_gathernd_negative_indices.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1e1f2b2796a69610afd116cbc88df63483cbc19 --- /dev/null +++ b/test/verify/test_gathernd_negative_indices.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gathernd_negative_indices : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}}; + std::vector indices{-1, 0}; + auto a0 = mm->add_parameter("data", ds); + auto a1 = mm->add_literal(migraphx::literal{is, indices}); + int batch_dims = 1; + mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1); + return p; + } +}; diff --git a/test/verify/test_gelu.cpp b/test/verify/test_gelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a54ac25f5c7a36b7ad290c99b151e198e8c58a65 --- /dev/null +++ b/test/verify/test_gelu.cpp @@ -0,0 +1,31 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gelu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 1, 5}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, input_lens}); + auto half = mm->add_literal(0.5f); + auto one = mm->add_literal(1.0f); + auto sqrt2 = mm->add_literal(static_cast(M_SQRT2)); + auto half_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half); + auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast); + auto sqrt2_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2); + auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast); + auto erf = mm->add_instruction(migraphx::make_op("erf"), div); + auto one_mbcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one); + auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast); + mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one); + return p; + } +}; diff --git a/test/verify/test_gemm.cpp b/test/verify/test_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45b86fd0387ccda3caa583e626b839027b7111d9 --- /dev/null +++ b/test/verify/test_gemm.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); + mm->add_instruction(migraphx::make_op("dot"), a, b); + return p; + } +}; diff --git a/test/verify/test_gemm_copy.cpp b/test/verify/test_gemm_copy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed56ccae857cad58b5f927d153d049b2f9093e02 --- /dev/null +++ b/test/verify/test_gemm_copy.cpp @@ -0,0 +1,25 @@ + +#include +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_copy : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; + migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; + migraphx::shape sc{migraphx::shape::float_type, {1, 8}}; + auto pa = mm->add_parameter("a", sa); + auto pb = mm->add_parameter("b", sb); + auto pc = mm->add_parameter("c", sc); + auto dr = + migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f); + mm->add_instruction(migraphx::make_op("add"), dr, dr); + return p; + } +}; diff --git a/test/verify/test_gemm_ex.cpp b/test/verify/test_gemm_ex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c24099f6127c0271643844b7579a5a2a3e47fbf5 --- /dev/null +++ b/test/verify/test_gemm_ex.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_ex : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); + mm->add_instruction(migraphx::make_op("dot"), a, b); + return p; + } +}; diff --git a/test/verify/test_gemm_half.cpp b/test/verify/test_gemm_half.cpp new file mode 100644 index 0000000000000000000000000000000000000000..521c52cd075fb5e44f0524c6e4c2c3d0af0de1e8 --- /dev/null +++ b/test/verify/test_gemm_half.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}}); + mm->add_instruction(migraphx::make_op("dot"), a, b); + return p; + } +}; diff --git a/test/verify/test_gemm_ld.cpp b/test/verify/test_gemm_ld.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00f7d1d5e286deea4f5187321184ecf4a67b6f42 --- /dev/null +++ b/test/verify/test_gemm_ld.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_ld //: verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = + mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}, {10, 1}}); + auto b = + mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}, {20, 1}}); + mm->add_instruction(migraphx::make_op("dot"), a, b); + return p; + } +}; diff --git a/test/verify/test_gemm_transposea.cpp b/test/verify/test_gemm_transposea.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dfe795ec7c48f5e3dadca2b3f06000d441ef5c22 --- /dev/null +++ b/test/verify/test_gemm_transposea.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_transposea : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); + auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); + mm->add_instruction(migraphx::make_op("dot"), at, b); + return p; + } +}; diff --git a/test/verify/test_gemm_transposea_ex.cpp b/test/verify/test_gemm_transposea_ex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86c8bf68323e0770c439e667428ce57ff4a32240 --- /dev/null +++ b/test/verify/test_gemm_transposea_ex.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_transposea_ex : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); + auto at = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a); + mm->add_instruction(migraphx::make_op("dot"), at, b); + return p; + } +}; diff --git a/test/verify/test_gemm_transposeab.cpp b/test/verify/test_gemm_transposeab.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3335e48d720a325484170ce81e7043b52340180 --- /dev/null +++ b/test/verify/test_gemm_transposeab.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_transposeab : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); + auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); + auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), at, bt); + return p; + } +}; diff --git a/test/verify/test_gemm_transposeb.cpp b/test/verify/test_gemm_transposeb.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b4acd6ce1bce9f236a8266153d4b861c0f9a3f6 --- /dev/null +++ b/test/verify/test_gemm_transposeb.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_transposeb : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); + auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, bt); + return p; + } +}; diff --git a/test/verify/test_gemm_transposeb_ex.cpp b/test/verify/test_gemm_transposeb_ex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3d7bb6653af31ae40d8bb0a09cebeb1ea4d0a9c --- /dev/null +++ b/test/verify/test_gemm_transposeb_ex.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_transposeb_ex : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); + auto bt = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, bt); + return p; + } +}; diff --git a/test/verify/test_global_avg_pooling.cpp b/test/verify/test_global_avg_pooling.cpp new file mode 100755 index 0000000000000000000000000000000000000000..3addcfdff807a439b526c5df874a96b8c665a6cd --- /dev/null +++ b/test/verify/test_global_avg_pooling.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_global_avg_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_global_max_pooling.cpp b/test/verify/test_global_max_pooling.cpp new file mode 100755 index 0000000000000000000000000000000000000000..157aca8286c2e6a53572dfd4aebde4c0edc04663 --- /dev/null +++ b/test/verify/test_global_max_pooling.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_global_max_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_greater.cpp b/test/verify/test_greater.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02487b5f1a0af4c867fe684441355487b4a6c0c1 --- /dev/null +++ b/test/verify/test_greater.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_greater : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto input1 = mm->add_parameter("x", s); + auto input2 = mm->add_parameter("y", s); + auto r = mm->add_instruction(migraphx::make_op("greater"), input1, input2); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_greater_brcst.cpp b/test/verify/test_greater_brcst.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11fb9809d61e54ccf036b2e3625135a3f415fb5e --- /dev/null +++ b/test/verify/test_greater_brcst.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_greater_brcst : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = mm->add_parameter("x", s0); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_parameter("y", s1); + auto bl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); + auto r = mm->add_instruction(migraphx::make_op("greater"), l0, bl1); + mm->add_return({r}); + + return p; + }; +}; diff --git a/test/verify/test_group_conv.cpp b/test/verify/test_group_conv.cpp new file mode 100755 index 0000000000000000000000000000000000000000..58a79ae32c3d30d3f0dfb688f1eb2b861af8d72c --- /dev/null +++ b/test/verify/test_group_conv.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_group_conv : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}}); + auto weights = + mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}}); + migraphx::op::convolution op; + op.group = 4; + mm->add_instruction(op, input, weights); + return p; + } +}; diff --git a/test/verify/test_gru_bidirct.cpp b/test/verify/test_gru_bidirct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..63f0460c11e5c7acb95a1c2829da9a170ff14f96 --- /dev/null +++ b/test/verify/test_gru_bidirct.cpp @@ -0,0 +1,60 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_bidirct_3args.cpp b/test/verify/test_gru_bidirct_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd93459a2ffcb8df701cc64c258e680d2bdce33e --- /dev/null +++ b/test/verify/test_gru_bidirct_3args.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_bidirct_3args_und.cpp b/test/verify/test_gru_bidirct_3args_und.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d87fd8255aeac3c9b6d259666080aff2a4ab6a06 --- /dev/null +++ b/test/verify/test_gru_bidirct_3args_und.cpp @@ -0,0 +1,52 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct_3args_und : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_bidirct_default_actv.cpp b/test/verify/test_gru_bidirct_default_actv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d8604dc96fbc3d080fe8344c7f0e6ad0abc0872 --- /dev/null +++ b/test/verify/test_gru_bidirct_default_actv.cpp @@ -0,0 +1,46 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct_default_actv : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_bidirct_default_actv1.cpp b/test/verify/test_gru_bidirct_default_actv1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d88280820b8390db5e1f625607346d2af871eda1 --- /dev/null +++ b/test/verify/test_gru_bidirct_default_actv1.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct_default_actv1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_bidirct_seq1.cpp b/test/verify/test_gru_bidirct_seq1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a60e8094c1d0bbfe16a14f4346fb677830c1456 --- /dev/null +++ b/test/verify/test_gru_bidirct_seq1.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_bidirct_seq1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward.cpp b/test/verify/test_gru_forward.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f250ef613a579acf4287ffd07f16d0eebecf42e0 --- /dev/null +++ b/test/verify/test_gru_forward.cpp @@ -0,0 +1,60 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({lho, hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward_3args.cpp b/test/verify/test_gru_forward_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed98e02beedf352c545dd5bd68741537fd691757 --- /dev/null +++ b/test/verify/test_gru_forward_3args.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward_3args_und.cpp b/test/verify/test_gru_forward_3args_und.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4def6f3e194c0fca285ae6833c58d512034bd62 --- /dev/null +++ b/test/verify/test_gru_forward_3args_und.cpp @@ -0,0 +1,52 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward_3args_und : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward_default_actv.cpp b/test/verify/test_gru_forward_default_actv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e92424e9000d4f467523d86fc07c3182f07d7da --- /dev/null +++ b/test/verify/test_gru_forward_default_actv.cpp @@ -0,0 +1,46 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward_default_actv : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward_default_actv1.cpp b/test/verify/test_gru_forward_default_actv1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11e72e1e8f914c1a56acf544def41fce8c790226 --- /dev/null +++ b/test/verify/test_gru_forward_default_actv1.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward_default_actv1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_forward_seq1.cpp b/test/verify/test_gru_forward_seq1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b761dd788ab1b71396390ee7e96482ad47fac0c9 --- /dev/null +++ b/test/verify/test_gru_forward_seq1.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_forward_seq1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_reverse_3args.cpp b/test/verify/test_gru_reverse_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c4dde284b33f7c7b4d9b1a958f5862bd6e6d995 --- /dev/null +++ b/test/verify/test_gru_reverse_3args.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_reverse_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_reverse_last.cpp b/test/verify/test_gru_reverse_last.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5796e1da7e8b45ed3a2e70d83dd390a9c9981255 --- /dev/null +++ b/test/verify/test_gru_reverse_last.cpp @@ -0,0 +1,59 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_reverse_last : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto output = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_gru_two_outputs.cpp b/test/verify/test_gru_two_outputs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5325615fb9b6871099af2721d5fd4d7bb4a3b4a --- /dev/null +++ b/test/verify/test_gru_two_outputs.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_gru_two_outputs : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, last_hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_hsqrt.cpp b/test/verify/test_hsqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12469581469878f02476f9b6122eaa5e9731df6f --- /dev/null +++ b/test/verify/test_hsqrt.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_hsqrt : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); + mm->add_instruction(migraphx::make_op("sqrt"), param_abs); + return p; + } +}; diff --git a/test/verify/test_if_literal.cpp b/test/verify/test_if_literal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..750b1441ddbb835b7d9d2a57b80a5a49d8fd1bb5 --- /dev/null +++ b/test/verify/test_if_literal.cpp @@ -0,0 +1,34 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_if_literal : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + + migraphx::shape s{migraphx::shape::float_type, {5}}; + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_return({l1}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + else_mod->add_return({l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_if_lp.cpp b/test/verify/test_if_lp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a4c10960b8ac571b4fd7c8200be6d0edd9cceca --- /dev/null +++ b/test/verify/test_if_lp.cpp @@ -0,0 +1,36 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_if_lp : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + migraphx::shape s{migraphx::shape::float_type, {5}}; + auto cond = mm->add_parameter("cond", cond_s); + auto x = mm->add_parameter("x", s); + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {1, 2, 3, 4, 5}; + auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); + then_mod->add_return({l1, x}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {5, 4, 3, 2, 1}; + auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); + auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2); + else_mod->add_return({s2, l2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret); + mm->add_return({r0, r1}); + + return p; + } +}; diff --git a/test/verify/test_if_param.cpp b/test/verify/test_if_param.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47918b3a7db66d79796277da329f8a6ea7249375 --- /dev/null +++ b/test/verify/test_if_param.cpp @@ -0,0 +1,37 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_if_param : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape cond_s{migraphx::shape::bool_type}; + auto cond = mm->add_parameter("cond", cond_s); + migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; + auto x = mm->add_parameter("x", ds); + auto y = mm->add_parameter("y", ds); + + auto* then_mod = p.create_module("If_0_if"); + std::vector data1 = {0.384804, -1.77948, -0.453775, 0.477438, -1.06333, -1.12893}; + auto l1 = then_mod->add_literal(migraphx::literal(ds, data1)); + auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x, l1); + then_mod->add_return({a1}); + + auto* else_mod = p.create_module("If_0_else"); + std::vector data2 = {-0.258047, 0.360394, 0.536804, -0.577762, 1.0217, 1.02442}; + auto l2 = else_mod->add_literal(migraphx::literal(ds, data2)); + auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); + else_mod->add_return({a2}); + + auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); + auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_isnan_broadcast.cpp b/test/verify/test_isnan_broadcast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..059278cd688564c7144a57d5a5d2daa8875c9cb9 --- /dev/null +++ b/test/verify/test_isnan_broadcast.cpp @@ -0,0 +1,24 @@ +#include +#include "verify_program.hpp" +#include +#include +#include + +struct test_isnan_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2}}); + auto s0 = migraphx::shape{migraphx::shape::float_type, {2, 2}}; + x = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", s0.lens()}}), x); + std::vector data0{2, std::numeric_limits::quiet_NaN()}; + migraphx::shape s1{migraphx::shape::float_type, {1, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s1, data0}); + x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); + mm->add_instruction(migraphx::make_op("isnan"), x); + return p; + } +}; diff --git a/test/verify/test_isnan_float.cpp b/test/verify/test_isnan_float.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5cc46b1a4b513235d0a6faa0b1d5b3c4bef2af7a --- /dev/null +++ b/test/verify/test_isnan_float.cpp @@ -0,0 +1,19 @@ +#include +#include "verify_program.hpp" +#include +#include +#include + +struct test_isnan_float : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2}}); + auto l0 = mm->add_literal(std::numeric_limits::quiet_NaN()); + x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); + mm->add_instruction(migraphx::make_op("isnan"), x); + return p; + } +}; diff --git a/test/verify/test_isnan_half.cpp b/test/verify/test_isnan_half.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1168d0306d84eae9e00c5d2deffbbd81bbe47e11 --- /dev/null +++ b/test/verify/test_isnan_half.cpp @@ -0,0 +1,20 @@ +#include +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_isnan_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2}}); + auto l0 = mm->add_literal(std::numeric_limits::quiet_NaN()); + x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); + mm->add_instruction(migraphx::make_op("isnan"), x); + return p; + } +}; diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp new file mode 100755 index 0000000000000000000000000000000000000000..0780b5aa85a5cd670aaaf8c31edd3351f98cb861 --- /dev/null +++ b/test/verify/test_layernorm.cpp @@ -0,0 +1,100 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +migraphx::instruction_ref +add_layernorm(migraphx::module& m, migraphx::instruction_ref x, std::vector dims) +{ + auto scale = + m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}}); + auto bias = + m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}}); + auto epsilon = m.add_literal(1e-12f); + auto exponent = m.add_literal(2.0f); + + auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x); + auto mean_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); + auto exponent_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); + auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); + auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow); + auto epsilon_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon); + auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); + auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon); + auto sqrt_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), sqrt); + auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast); + auto scale_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale); + auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div); + auto bias_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); + return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); +} + +struct test_layernorm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 1, 5}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + add_layernorm(*mm, x, dims); + return p; + } +}; + +struct test_layernorm2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 4, 24}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + add_layernorm(*mm, x, dims); + return p; + } +}; + +struct test_layernorm_triadd : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 4, 24}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims}); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + add_layernorm(*mm, add2, dims); + return p; + } +}; + +struct test_layernorm_triadd_large : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 384, 1024}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims}); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z); + add_layernorm(*mm, add2, dims); + return p; + } +}; diff --git a/test/verify/test_leaky_relu.cpp b/test/verify/test_leaky_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6503e6fe1643d87d43167e9beb035e41bba18a08 --- /dev/null +++ b/test/verify/test_leaky_relu.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_leaky_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.01}}), x); + return p; + } +}; diff --git a/test/verify/test_less.cpp b/test/verify/test_less.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75e420bc2d9901f2d3ae49dbfa0a1226035c6958 --- /dev/null +++ b/test/verify/test_less.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_less : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto input1 = mm->add_parameter("x", s); + auto input2 = mm->add_parameter("y", s); + auto r = mm->add_instruction(migraphx::make_op("less"), input1, input2); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_less_brcst.cpp b/test/verify/test_less_brcst.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe47d53e87feee245f43b508becc5f60660daba6 --- /dev/null +++ b/test/verify/test_less_brcst.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_less_brcst : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; + auto l0 = mm->add_parameter("x", s0); + migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; + auto l1 = mm->add_parameter("y", s1); + auto bl1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s0.lens()}}), l1); + auto r = mm->add_instruction(migraphx::make_op("less"), l0, bl1); + mm->add_return({r}); + + return p; + }; +}; diff --git a/test/verify/test_literals.cpp b/test/verify/test_literals.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d2f226422c9d1fac7463f6bc4299fb5862db68c --- /dev/null +++ b/test/verify/test_literals.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_literals : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_literal( + generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); + auto weights = mm->add_literal( + generate_literal(migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}})); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + mm->add_instruction(migraphx::make_op("relu"), conv); + return p; + } +}; diff --git a/test/verify/test_log.cpp b/test/verify/test_log.cpp new file mode 100755 index 0000000000000000000000000000000000000000..9c910475dae19ba4e0349ab0b927e6ed2501874c --- /dev/null +++ b/test/verify/test_log.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_log : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); + mm->add_instruction(migraphx::make_op("log"), x); + return p; + } +}; diff --git a/test/verify/test_logsoftmax.cpp b/test/verify/test_logsoftmax.cpp new file mode 100755 index 0000000000000000000000000000000000000000..a642ce183c1f13c6b516da9b076db109919f7e05 --- /dev/null +++ b/test/verify/test_logsoftmax.cpp @@ -0,0 +1,29 @@ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_logsoftmax : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{T, {10, 4, 2080, 6}}; + auto param = mm->add_parameter("0", s); + mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", Axis}}), param); + + return p; + } +}; + +template struct test_logsoftmax<0, migraphx::shape::float_type>; +template struct test_logsoftmax<1, migraphx::shape::float_type>; +template struct test_logsoftmax<2, migraphx::shape::float_type>; +template struct test_logsoftmax<3, migraphx::shape::float_type>; +template struct test_logsoftmax<1, migraphx::shape::half_type>; +template struct test_logsoftmax<0, migraphx::shape::half_type>; +template struct test_logsoftmax<2, migraphx::shape::half_type>; +template struct test_logsoftmax<3, migraphx::shape::half_type>; diff --git a/test/verify/test_logsoftmax1.cpp b/test/verify/test_logsoftmax1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..231a32761108ed9480835b973a5276b83dd8be67 --- /dev/null +++ b/test/verify/test_logsoftmax1.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_logsoftmax1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}}); + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), x); + auto r = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", 0}}), tx); + mm->add_return({r}); + return p; + } +}; diff --git a/test/verify/test_loop.cpp b/test/verify/test_loop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c52234005c6ec01576ad7aff37c9ede28c88b98a --- /dev/null +++ b/test/verify/test_loop.cpp @@ -0,0 +1,45 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_loop : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape si{migraphx::shape::int64_type}; + migraphx::shape s{migraphx::shape::int64_type, {1}}; + migraphx::shape sc{migraphx::shape::bool_type}; + int64_t iter_num = 10; + auto in_iter = mm->add_literal(migraphx::literal(si, {iter_num})); + auto in_cond = mm->add_parameter("ccond", sc); + int64_t value = 5; + auto in_val = mm->add_literal(migraphx::literal(s, {value})); + + auto* body = p.create_module("loop_module"); + auto iter = body->add_parameter("iter_num", si); + body->add_parameter("cond", sc); + auto in_v = body->add_parameter("input", s); + std::vector vd = {3}; + auto l = body->add_literal(migraphx::literal(si, vd)); + auto ad = body->add_instruction(migraphx::make_op("add"), iter, l); + auto val = body->add_instruction(migraphx::make_op("add"), in_v, ad); + auto eq = body->add_instruction(migraphx::make_op("equal"), iter, l); + auto beq = body->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), eq); + auto neq = body->add_instruction(migraphx::make_op("not"), beq); + body->add_return({neq, val, val}); + + auto rl = mm->add_instruction( + migraphx::make_op("loop", {{"max_iterations", 8}}), {in_iter, in_cond, in_val}, {body}); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rl); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rl); + mm->add_return({r0, r1}); + + return p; + } +}; diff --git a/test/verify/test_lstm_bidirct_3args.cpp b/test/verify/test_lstm_bidirct_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7d833d41ed9edaa705aaa25098bf2559c8ff8e7 --- /dev/null +++ b/test/verify/test_lstm_bidirct_3args.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_3args_und.cpp b/test/verify/test_lstm_bidirct_3args_und.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4798accdee0c0b4cd9d83c1e0b32ea5400386e47 --- /dev/null +++ b/test/verify/test_lstm_bidirct_3args_und.cpp @@ -0,0 +1,55 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_3args_und : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und, + und, + und); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_default_actv.cpp b/test/verify/test_lstm_bidirct_default_actv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4dc856048ccfa510b123bb13ddedf4bec00ffa5 --- /dev/null +++ b/test/verify/test_lstm_bidirct_default_actv.cpp @@ -0,0 +1,46 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_default_actv : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_default_actv1.cpp b/test/verify/test_lstm_bidirct_default_actv1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77594619afc815122dc7e94abcaf8cfb2d3bcf6a --- /dev/null +++ b/test/verify/test_lstm_bidirct_default_actv1.cpp @@ -0,0 +1,60 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_default_actv1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + std::vector sl_data(batch_size, 2); + auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_default_actv2.cpp b/test/verify/test_lstm_bidirct_default_actv2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ce34ad4a668aeba0e673b9de55d2c611b5976fa --- /dev/null +++ b/test/verify/test_lstm_bidirct_default_actv2.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_default_actv2 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_hs.cpp b/test/verify/test_lstm_bidirct_hs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf4d719f004227354560cc54d724da7e2554ca84 --- /dev/null +++ b/test/verify/test_lstm_bidirct_hs.cpp @@ -0,0 +1,60 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_hs : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + std::vector sl_data{3, 2}; + auto sql = mm->add_literal(migraphx::literal{migraphx::literal{sl_shape, sl_data}}); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_last.cpp b/test/verify/test_lstm_bidirct_last.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41a2782bd5a9547e61539e9e367843465650a731 --- /dev/null +++ b/test/verify/test_lstm_bidirct_last.cpp @@ -0,0 +1,66 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_last : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto output = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_seq1.cpp b/test/verify/test_lstm_bidirct_seq1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a64a3bd75519a0355476e67d4f913c1b73c34af --- /dev/null +++ b/test/verify/test_lstm_bidirct_seq1.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_seq1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_3args.cpp b/test/verify/test_lstm_forward_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..706fc2c7754ebfe77f7a7db302982975702d1fa7 --- /dev/null +++ b/test/verify/test_lstm_forward_3args.cpp @@ -0,0 +1,49 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_3args_und.cpp b/test/verify/test_lstm_forward_3args_und.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f6d66f66665adb356a6e3258d734271a9010ffa3 --- /dev/null +++ b/test/verify/test_lstm_forward_3args_und.cpp @@ -0,0 +1,55 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_3args_und : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + und, + und, + und, + und, + und); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_default_actv.cpp b/test/verify/test_lstm_forward_default_actv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e7cd5a2d4d4bd8ca3905967ccdb7f0241f0b4f1 --- /dev/null +++ b/test/verify/test_lstm_forward_default_actv.cpp @@ -0,0 +1,46 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_default_actv : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_default_actv1.cpp b/test/verify/test_lstm_forward_default_actv1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0339e6f8205aeaebccd32a57c5a0ea3511bde72 --- /dev/null +++ b/test/verify/test_lstm_forward_default_actv1.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_default_actv1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_hs.cpp b/test/verify/test_lstm_forward_hs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9bf0edfbb8b2b3c7919052ea1e299833f75cc071 --- /dev/null +++ b/test/verify/test_lstm_forward_hs.cpp @@ -0,0 +1,65 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_hs : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_last.cpp b/test/verify/test_lstm_forward_last.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ee831c94603b543491e09d83f13ebb05486a737 --- /dev/null +++ b/test/verify/test_lstm_forward_last.cpp @@ -0,0 +1,67 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_last : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto len = mm->add_literal(migraphx::literal(l_shape, {1, 2})); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + auto output = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + len, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_seq1.cpp b/test/verify/test_lstm_forward_seq1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..255f906d817b68fdec6716d5b889c53072aeb414 --- /dev/null +++ b/test/verify/test_lstm_forward_seq1.cpp @@ -0,0 +1,49 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_seq1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_reverse_3args.cpp b/test/verify/test_lstm_reverse_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95cc70e703cb29207e61eb8f0154a3a9853feb8b --- /dev/null +++ b/test/verify/test_lstm_reverse_3args.cpp @@ -0,0 +1,49 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_reverse_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_reverse_3args_cell_output.cpp b/test/verify/test_lstm_reverse_3args_cell_output.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c16a39fab79eecd232eaa9101225d460682d7070 --- /dev/null +++ b/test/verify/test_lstm_reverse_3args_cell_output.cpp @@ -0,0 +1,50 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_reverse_3args_cell_output : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_reverse_last.cpp b/test/verify/test_lstm_reverse_last.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a1fb43e2280427f4533e29046a26b2a753994ba --- /dev/null +++ b/test/verify/test_lstm_reverse_last.cpp @@ -0,0 +1,66 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_reverse_last : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto output = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_three_outputs.cpp b/test/verify/test_lstm_three_outputs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65386e6459fcc44eda59d6b9b7dcaba95bca5bc2 --- /dev/null +++ b/test/verify/test_lstm_three_outputs.cpp @@ -0,0 +1,52 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_three_outputs : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_return({hs, last_hs, last_cell}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_two_outputs.cpp b/test/verify/test_lstm_two_outputs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..455c51bf7cfe07b03fb739725384823353382a73 --- /dev/null +++ b/test/verify/test_lstm_two_outputs.cpp @@ -0,0 +1,51 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_two_outputs : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, last_hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_max_pooling_ceil_3d.cpp b/test/verify/test_max_pooling_ceil_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44532907af23892e650a934b389560e72ce50801 --- /dev/null +++ b/test/verify/test_max_pooling_ceil_3d.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_max_pooling_ceil_3d : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); + auto op = migraphx::op::pooling{ + migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true}; + mm->add_instruction(op, input); + return p; + } +}; diff --git a/test/verify/test_mul.cpp b/test/verify/test_mul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eafc8a7bd4933a027655c2649ebc552bfa83154d --- /dev/null +++ b/test/verify/test_mul.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_mul : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("mul"), x, y); + return p; + } +}; diff --git a/test/verify/test_mul_add.cpp b/test/verify/test_mul_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0d47b7643a7e368ef24e38f8cf0ff1ec5aed1af --- /dev/null +++ b/test/verify/test_mul_add.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_mul_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape bs{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto a = mm->add_parameter("a", bs); + auto b = mm->add_parameter("b", bs); + auto ab = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), a); + auto bb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), b); + auto mul = mm->add_instruction(migraphx::make_op("mul"), x, ab); + mm->add_instruction(migraphx::make_op("add"), mul, bb); + return p; + } +}; diff --git a/test/verify/test_multinomial.cpp b/test/verify/test_multinomial.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c89fbe7e316e3ac164639f4de1f6bda5f6517207 --- /dev/null +++ b/test/verify/test_multinomial.cpp @@ -0,0 +1,37 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_multinomial : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + size_t sample_size = 10; + size_t batch_size = 2; + float seed = 0.0f; + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector rand_samples(batch_size * sample_size); + std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); + migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}}; + auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); + + migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}}; + auto input = mm->add_parameter("input", s); + + auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); + auto mb_maxes = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {batch_size, 5}}}), maxes); + auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes); + cdf = mm->add_instruction(migraphx::make_op("exp"), cdf); + cdf = mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); + + mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit); + return p; + } +}; diff --git a/test/verify/test_neg.cpp b/test/verify/test_neg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..454c83eb2876513dafe7602ed17c725d77391c01 --- /dev/null +++ b/test/verify/test_neg.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_neg : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto input = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("neg"), input); + return p; + }; +}; diff --git a/test/verify/test_nms.cpp b/test/verify/test_nms.cpp new file mode 100644 index 0000000000000000000000000000000000000000..099e2e9e6eac72ca0dbc42ee1a27a4ba7048e79f --- /dev/null +++ b/test/verify/test_nms.cpp @@ -0,0 +1,36 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_nms : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}}; + + migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}}; + std::vector scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3}; + + auto boxes_l = mm->add_parameter("boxes", boxes_s); + auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec)); + auto max_out_l = mm->add_literal(int64_t{4}); + auto iou_threshold = mm->add_literal(0.5f); + auto score_threshold = mm->add_literal(0.0f); + + auto r = + mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), + boxes_l, + scores_l, + max_out_l, + iou_threshold, + score_threshold); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_nonstd_gather.cpp b/test/verify/test_nonstd_gather.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c1f41f71c14c42f1b4d1b37f870e7abfd68a0ea --- /dev/null +++ b/test/verify/test_nonstd_gather.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_nonstd_gather : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; + std::vector indices{1, 1, 0, 2}; + auto d = mm->add_parameter("data", s); + auto td = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d); + auto ind = mm->add_literal(migraphx::literal{s_indices, indices}); + auto tind = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind); + auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), td, tind); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_nonzero.cpp b/test/verify/test_nonzero.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e43dc937f2dfa5875e6ba82527c1ad1b216b8a70 --- /dev/null +++ b/test/verify/test_nonzero.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_nonzero : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; + auto x = mm->add_parameter("data", s); + auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_nonzero_half.cpp b/test/verify/test_nonzero_half.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15f7ea7672e4d701005b5ff63f46a014f2d6e3fd --- /dev/null +++ b/test/verify/test_nonzero_half.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_nonzero_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {3, 4, 3, 5}}; + auto x = mm->add_parameter("data", s); + auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_not.cpp b/test/verify/test_not.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c58d4f4edf12c4dfc15700705f4925f08ea9907 --- /dev/null +++ b/test/verify/test_not.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_not : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {4}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("not"), x); + return p; + } +}; diff --git a/test/verify/test_or.cpp b/test/verify/test_or.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a5257d56cefd0b1e77c6b8e30cfdd91a5c871e7 --- /dev/null +++ b/test/verify/test_or.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_or : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("logical_or"), x, y); + return p; + } +}; diff --git a/test/verify/test_pad.cpp b/test/verify/test_pad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd3bf9b7ba8c21a94712b82cb22769a0093c3cd7 --- /dev/null +++ b/test/verify/test_pad.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}}; + std::vector pads0 = {0, 0, 0, 0, 0, 0, 1, 1}; + std::vector pads1 = {0, 0, 0, 0, 1, 1, 1, 1}; + std::vector pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; + std::vector pads3 = {1, 0, 1, 0, 1, 0, 2, 0}; + auto l0 = mm->add_parameter("x", s0); + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads0}}), l0); + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads1}}), l0); + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads2}}), l0); + mm->add_instruction(migraphx::make_op("pad", {{"pads", pads3}}), l0); + return p; + } +}; diff --git a/test/verify/test_pad_highest.cpp b/test/verify/test_pad_highest.cpp new file mode 100755 index 0000000000000000000000000000000000000000..55c6b06d4baa6392bcf7afa6eb4d25ba55d8fa6e --- /dev/null +++ b/test/verify/test_pad_highest.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad_highest : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data0(4); + std::iota(data0.begin(), data0.end(), 0); + migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + migraphx::op::pad op{}; + op.value = std::numeric_limits::max(); + op.pads = {0, 0, 1, 1}; + mm->add_instruction(op, l0); + return p; + } +}; diff --git a/test/verify/test_pad_int8.cpp b/test/verify/test_pad_int8.cpp new file mode 100755 index 0000000000000000000000000000000000000000..e5bf3f5c3f2257cfb65c7c918ebe53cba7babf6f --- /dev/null +++ b/test/verify/test_pad_int8.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad_int8 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data0 = {0, 1, 2, 3}; + migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + migraphx::op::pad op{}; + op.value = std::numeric_limits::lowest(); + op.pads = {0, 0, 1, 1}; + mm->add_instruction(op, l0); + return p; + } +}; diff --git a/test/verify/test_pad_lowest.cpp b/test/verify/test_pad_lowest.cpp new file mode 100755 index 0000000000000000000000000000000000000000..d14d72f463edda43910f14d180ecf32191df3d98 --- /dev/null +++ b/test/verify/test_pad_lowest.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad_lowest : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data0(4); + std::iota(data0.begin(), data0.end(), 0); + migraphx::shape s0{migraphx::shape::half_type, {2, 2}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data0}); + migraphx::op::pad op{}; + op.value = std::numeric_limits::lowest(); + op.pads = {0, 0, 1, 1}; + mm->add_instruction(op, l0); + return p; + } +}; diff --git a/test/verify/test_pad_transposed.cpp b/test/verify/test_pad_transposed.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc8c3affd66496f08ae8b8343127fc34aff4e35c --- /dev/null +++ b/test/verify/test_pad_transposed.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pad_transposed : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {1, 224, 224, 3}}; + auto x = mm->add_parameter("x", s); + auto t = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), x); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 2, 2, 0, 0, 3, 3}}}), t); + return p; + } +}; diff --git a/test/verify/test_pooling_autopad.cpp b/test/verify/test_pooling_autopad.cpp new file mode 100755 index 0000000000000000000000000000000000000000..976d717b96d8230d8f7c47bf4c8eb56e4f9290c7 --- /dev/null +++ b/test/verify/test_pooling_autopad.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pooling_autopad : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}}; + auto l0 = mm->add_parameter("x", s0); + migraphx::op::pooling op{migraphx::op::pooling_mode::max}; + op.lengths = {2, 2}; + op.stride = {2, 2}; + mm->add_instruction(op, l0); + return p; + } +}; diff --git a/test/verify/test_pow.cpp b/test/verify/test_pow.cpp new file mode 100644 index 0000000000000000000000000000000000000000..46cccc7df5c7caf143dce5e8aeb949f05b0e0987 --- /dev/null +++ b/test/verify/test_pow.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_pow : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + std::vector vec_e(s.elements(), 2.0f); + auto b = mm->add_parameter("x", s); + auto e = mm->add_literal(migraphx::literal(s, vec_e)); + mm->add_instruction(migraphx::make_op("pow"), b, e); + return p; + } +}; diff --git a/test/verify/test_prefix_scan_sum_2d.cpp b/test/verify/test_prefix_scan_sum_2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..707930bd6d35c476bf0b918a91f09ec9f67ec3d0 --- /dev/null +++ b/test/verify/test_prefix_scan_sum_2d.cpp @@ -0,0 +1,34 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_prefix_scan_sum_2d_small : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1}}; + auto x = mm->add_parameter("x", s); + auto xb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), xb); + return p; + } +}; + +struct test_prefix_scan_sum_2d_large : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 1000}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x); + return p; + } +}; diff --git a/test/verify/test_prefix_scan_sum_exclusive.cpp b/test/verify/test_prefix_scan_sum_exclusive.cpp new file mode 100644 index 0000000000000000000000000000000000000000..786da3e17141facdd300af7cdad362524d77b897 --- /dev/null +++ b/test/verify/test_prefix_scan_sum_exclusive.cpp @@ -0,0 +1,20 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_prefix_scan_sum_exclusive : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", + {{"axis", 2}, {"exclusive", true}, {"reverse", false}}), + x); + return p; + } +}; diff --git a/test/verify/test_prefix_scan_sum_exclusive_reverse.cpp b/test/verify/test_prefix_scan_sum_exclusive_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4f44fb062f6b36b975dec26fa1f593e326805fe --- /dev/null +++ b/test/verify/test_prefix_scan_sum_exclusive_reverse.cpp @@ -0,0 +1,21 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_prefix_scan_sum_exclusive_reverse + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", + {{"axis", 0}, {"exclusive", true}, {"reverse", true}}), + x); + return p; + } +}; diff --git a/test/verify/test_prefix_scan_sum_reverse.cpp b/test/verify/test_prefix_scan_sum_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea9c55d1475f0ca9a782359b53be5884210b0f41 --- /dev/null +++ b/test/verify/test_prefix_scan_sum_reverse.cpp @@ -0,0 +1,20 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_prefix_scan_sum_reverse : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 3, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction( + migraphx::make_op("prefix_scan_sum", + {{"axis", 1}, {"exclusive", false}, {"reverse", true}}), + x); + return p; + } +}; diff --git a/test/verify/test_prelu_brcst.cpp b/test/verify/test_prelu_brcst.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d65916b7cae31ca684aac0b226a770cba4b72b27 --- /dev/null +++ b/test/verify/test_prelu_brcst.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_prelu_brcst : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {6}}; + auto x = mm->add_parameter("x", s); + auto slp = mm->add_parameter("slp", s); + auto r = mm->add_instruction(migraphx::make_op("prelu"), x, slp); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_quantizelinear.cpp b/test/verify/test_quantizelinear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84ccced52f02a63fb83b917310dd1f4d5bd935a3 --- /dev/null +++ b/test/verify/test_quantizelinear.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_quantizelinear : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sx{migraphx::shape::float_type, {2, 2, 2}}; + migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}}; + migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}}; + auto input1 = mm->add_parameter("x", sx); + auto input2 = mm->add_parameter("y_scale", ss); + auto input3 = mm->add_parameter("y_zero_point", sz); + auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), input1, input2, input3); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_quantizelinear_int32.cpp b/test/verify/test_quantizelinear_int32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea3bd5fd1787d9eb8f665798d55373518b2dafe6 --- /dev/null +++ b/test/verify/test_quantizelinear_int32.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_quantizelinear_int32 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sx{migraphx::shape::int32_type, {2, 2, 2}}; + migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}}; + migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}}; + auto input1 = mm->add_parameter("x", sx); + auto input2 = mm->add_parameter("y_scale", ss); + auto input3 = mm->add_parameter("y_zero_point", sz); + auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), input1, input2, input3); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_recip.cpp b/test/verify/test_recip.cpp new file mode 100755 index 0000000000000000000000000000000000000000..125a9972cdd5403a36a3d3d54be5fefb04724778 --- /dev/null +++ b/test/verify/test_recip.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_recip : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("recip"), x); + return p; + } +}; diff --git a/test/verify/test_reduce_op_large.cpp b/test/verify/test_reduce_op_large.cpp new file mode 100755 index 0000000000000000000000000000000000000000..6ea43c9759def5172dd1760c50e57411d7413e9d --- /dev/null +++ b/test/verify/test_reduce_op_large.cpp @@ -0,0 +1,42 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include +#include +#include +#include + +template +struct test_reduce_op_large : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{T, {3, 1026, 4, 3}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(Op{{Axis}}, x); + return p; + }; +}; + +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; + +struct test_reduce_mean : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 384, 1024}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::op::reduce_mean{{1}}, x); + return p; + }; +}; diff --git a/test/verify/test_reduce_op_small.cpp b/test/verify/test_reduce_op_small.cpp new file mode 100755 index 0000000000000000000000000000000000000000..a78c829a9c4ecc86bd82d33a3c651e775827cee0 --- /dev/null +++ b/test/verify/test_reduce_op_small.cpp @@ -0,0 +1,35 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include +#include +#include +#include + +template +struct test_reduce_op_small : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{T, {3, 4, 2, 2}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(Op{{Axis}}, x); + return p; + }; +}; + +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; + +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4dacac1f8e3b63d81db72ca19186476e5e4d286c --- /dev/null +++ b/test/verify/test_relu_lrn.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_relu_lrn : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}}); + auto y = mm->add_instruction(migraphx::make_op("relu"), x); + mm->add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + y); + return p; + } +}; diff --git a/test/verify/test_reverse.cpp b/test/verify/test_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02b07fe2f31ae35fadccde6a445d1d0fd22634cb --- /dev/null +++ b/test/verify/test_reverse.cpp @@ -0,0 +1,18 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_reverse : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 16}}; + auto a0 = mm->add_parameter("data", s); + std::vector axis = {0}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0); + return p; + } +}; diff --git a/test/verify/test_reverse_multiaxis.cpp b/test/verify/test_reverse_multiaxis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27c286cb07db8f18c1a58dd911482f88f991f2e6 --- /dev/null +++ b/test/verify/test_reverse_multiaxis.cpp @@ -0,0 +1,18 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_reverse_multiaxis : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 16}}; + auto a0 = mm->add_parameter("data", s); + std::vector axes = {0, 1}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axes}}), a0); + return p; + } +}; diff --git a/test/verify/test_reverse_negaxis.cpp b/test/verify/test_reverse_negaxis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a13653c6062ed90150e0f033ab8d9ffa72342940 --- /dev/null +++ b/test/verify/test_reverse_negaxis.cpp @@ -0,0 +1,18 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_reverse_negaxis : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 16}}; + auto a0 = mm->add_parameter("data", s); + std::vector axis = {-1}; + mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0); + return p; + } +}; diff --git a/test/verify/test_rnn_3args.cpp b/test/verify/test_rnn_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12713281254c423200743d1e89d7c2789950147e --- /dev/null +++ b/test/verify/test_rnn_3args.cpp @@ -0,0 +1,48 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_4args.cpp b/test/verify/test_rnn_4args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10eb419891e49f5fdd7d027d51e584a1191583d6 --- /dev/null +++ b/test/verify/test_rnn_4args.cpp @@ -0,0 +1,51 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_4args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 5; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_5args.cpp b/test/verify/test_rnn_5args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b7f041f33befa6859720c396e9591b325c461ca --- /dev/null +++ b/test/verify/test_rnn_5args.cpp @@ -0,0 +1,54 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_5args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto output = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_bi_3args.cpp b/test/verify/test_rnn_bi_3args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..293988a3c41619c271e4d2e73cf1fce1ee50708c --- /dev/null +++ b/test/verify/test_rnn_bi_3args.cpp @@ -0,0 +1,50 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_bi_3args : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto output = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_bidirectional.cpp b/test/verify/test_rnn_bidirectional.cpp new file mode 100644 index 0000000000000000000000000000000000000000..857c06bb7c9e8d61ce49865bc39c27b173296667 --- /dev/null +++ b/test/verify/test_rnn_bidirectional.cpp @@ -0,0 +1,57 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_bidirectional : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto output = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_bidirectional10.cpp b/test/verify/test_rnn_bidirectional10.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b2bd6e8f3ae72da912875159d27c4444fafdc04 --- /dev/null +++ b/test/verify/test_rnn_bidirectional10.cpp @@ -0,0 +1,56 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_bidirectional10 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto output = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_forward.cpp b/test/verify/test_rnn_forward.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56aa6b41ef7d346bf4782f637d2eb8a6f03eb56d --- /dev/null +++ b/test/verify/test_rnn_forward.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_forward : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_forward10.cpp b/test/verify/test_rnn_forward10.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57edaa1365703026c5c355d72647f2e1aa3d8f0f --- /dev/null +++ b/test/verify/test_rnn_forward10.cpp @@ -0,0 +1,58 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_forward10 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + auto hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_reverse.cpp b/test/verify/test_rnn_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98744d5257c1d093626e07ccd08c7ecf7dc59ec2 --- /dev/null +++ b/test/verify/test_rnn_reverse.cpp @@ -0,0 +1,56 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_reverse : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 1; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_reverse2.cpp b/test/verify/test_rnn_reverse2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9cadcec5f1312d7b5b008e5421abe0f4323507fc --- /dev/null +++ b/test/verify/test_rnn_reverse2.cpp @@ -0,0 +1,56 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_reverse2 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 2; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_sql_1.cpp b/test/verify/test_rnn_sql_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57bc4735e6606f3dfb3c9c97e89e0907b4c1ea43 --- /dev/null +++ b/test/verify/test_rnn_sql_1.cpp @@ -0,0 +1,60 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_sql_1 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + std::vector sl_data{5, 7}; + auto sql = mm->add_literal(migraphx::literal{s_shape, sl_data}); + auto ih = mm->add_parameter("ih", ih_shape); + + auto hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, last_hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_rnn_sql_2.cpp b/test/verify/test_rnn_sql_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a91a1dfcf66c201289112de3ea6c55cd10903df --- /dev/null +++ b/test/verify/test_rnn_sql_2.cpp @@ -0,0 +1,65 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_rnn_sql_2 : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 10; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq_orig = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + migraphx::shape pad_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_s.elements(), 0.0f); + auto seq_pad = mm->add_literal(migraphx::literal{pad_s, pad_data}); + auto seq = + mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_pad); + std::vector sl_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(migraphx::literal{s_shape, sl_data}); + auto ih = mm->add_parameter("ih", ih_shape); + + auto hs = mm->add_instruction( + migraphx::make_op( + "rnn", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, last_hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3a47a597f711385ef9c945ab06a520c585b252b --- /dev/null +++ b/test/verify/test_roialign.cpp @@ -0,0 +1,35 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_roialign : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; + + migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; + std::vector ind_vec = {0, 2, 3, 4, 1}; + + auto x = mm->add_parameter("x", x_s); + auto roi = mm->add_parameter("roi", roi_s); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = mm->add_instruction(migraphx::make_op("roialign", + {{"spatial_scale", 1.0}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", 2}}), + x, + roi, + ind); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_roialign_nondefault.cpp b/test/verify/test_roialign_nondefault.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9283ea34b7e27c45996ea64571ad22f1db3200fe --- /dev/null +++ b/test/verify/test_roialign_nondefault.cpp @@ -0,0 +1,39 @@ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_roialign_nondefault : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; + + migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; + std::vector ind_vec = {0, 2, 3, 4, 1}; + + auto x = mm->add_parameter("x", x_s); + auto roi = mm->add_parameter("roi", roi_s); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = mm->add_instruction( + migraphx::make_op("roialign", + {{"coordinate_transformation_mode", "output_half_pixel"}, + {"mode", migraphx::op::pooling_mode::max}, + {"spatial_scale", 1.0}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", 2}}), + x, + roi, + ind); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_roialign_nonstandard.cpp b/test/verify/test_roialign_nonstandard.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45191487928cb4694fb01823879a74acb33330eb --- /dev/null +++ b/test/verify/test_roialign_nonstandard.cpp @@ -0,0 +1,36 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_roialign_nonstandard : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x_s = migraphx::shape::from_permutation( + migraphx::shape::float_type, {5, 4, 10, 10}, {0, 2, 3, 1}); + + migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; + + migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; + std::vector ind_vec = {0, 2, 3, 4, 1}; + + auto x = mm->add_parameter("x", x_s); + auto roi = mm->add_parameter("roi", roi_s); + auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); + auto r = mm->add_instruction(migraphx::make_op("roialign", + {{"spatial_scale", 1.0}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", 2}}), + x, + roi, + ind); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_round.cpp b/test/verify/test_round.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cdaffe2670e0bbedb96bcab53a5d232a92cadc8 --- /dev/null +++ b/test/verify/test_round.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_round : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("round"), param); + return p; + }; +}; diff --git a/test/verify/test_rsqrt.cpp b/test/verify/test_rsqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36dae6865b8591a016ce6533a35a8ad974cc0561 --- /dev/null +++ b/test/verify/test_rsqrt.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_rsqrt : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector input_lens{1, 3, 16, 16}; + migraphx::shape s{migraphx::shape::float_type, input_lens}; + auto x = mm->add_parameter("x", s); + auto min_val = mm->add_literal(1.0f); + auto max_val = mm->add_literal(std::numeric_limits::max()); + min_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val); + max_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val); + auto l0 = mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val); + mm->add_instruction(migraphx::make_op("rsqrt"), l0); + return p; + }; +}; diff --git a/test/verify/test_scale.cpp b/test/verify/test_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e01a0661c271f45788304690e9ae524814ac382d --- /dev/null +++ b/test/verify/test_scale.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_scale : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", migraphx::shape::float_type); + auto scale = + mm->add_instruction(migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), y); + mm->add_instruction(migraphx::make_op("mul"), x, scale); + return p; + } +}; diff --git a/test/verify/test_scatter0.cpp b/test/verify/test_scatter0.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b5f888a80ebc6a2c3be8191b96d3b1219cffde81 --- /dev/null +++ b/test/verify/test_scatter0.cpp @@ -0,0 +1,26 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_scatter0 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector vi = {1, 0, 2, 0, 2, 1}; + migraphx::shape su{migraphx::shape::float_type, {2, 3}}; + + auto pd = mm->add_parameter("data", sd); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto pu = mm->add_parameter("update", su); + auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -1}}), pd, li, pu); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_scatter1.cpp b/test/verify/test_scatter1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..611629cd8f5b043007318317f360697adf27f4d9 --- /dev/null +++ b/test/verify/test_scatter1.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_scatter1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector vi = {-2, 0, 2, 0, -1, 1}; + migraphx::shape su{migraphx::shape::float_type, {2, 3}}; + + auto pd = mm->add_parameter("data", sd); + auto li = mm->add_literal(migraphx::literal{si, vi}); + auto pu = mm->add_parameter("update", su); + auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -2}}), pd, li, pu); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_scatternd.cpp b/test/verify/test_scatternd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6ef26f40a3d89486f9bae3be2bc63ff4c569c695 --- /dev/null +++ b/test/verify/test_scatternd.cpp @@ -0,0 +1,30 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_scatternd : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {1}}; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {4}}; + std::vector ind_vec{4, 3, 1, 7}; + + auto ld = mm->add_literal(migraphx::literal{ds, {1}}); + auto data = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {8}}}), ld); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_parameter("update", us); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_none"), data, indices, updates); + mm->add_return({scatternd}); + + return p; + } +}; diff --git a/test/verify/test_scatternd_add.cpp b/test/verify/test_scatternd_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e77bfafe17406d015443b93e97335f54b84c9718 --- /dev/null +++ b/test/verify/test_scatternd_add.cpp @@ -0,0 +1,30 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_scatternd_add : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {1, 4}}; + migraphx::shape us{dtype, {4}}; + std::vector ind_vec{4, 3, 1, 7}; + + auto data = mm->add_parameter("data", ds); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto t_ind = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), indices); + auto updates = mm->add_parameter("update", us); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_add"), data, t_ind, updates); + mm->add_return({scatternd}); + + return p; + } +}; diff --git a/test/verify/test_scatternd_mul.cpp b/test/verify/test_scatternd_mul.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a6eb0d2dfe797cfab799e7089f2711319f911e67 --- /dev/null +++ b/test/verify/test_scatternd_mul.cpp @@ -0,0 +1,28 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_scatternd_mul : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto dtype = migraphx::shape::float_type; + auto itype = migraphx::shape::int64_type; + migraphx::shape ds{dtype, {8}}; + migraphx::shape is{itype, {4, 1}}; + migraphx::shape us{dtype, {4}}; + std::vector ind_vec{4, 3, 1, 7}; + + auto data = mm->add_parameter("data", ds); + auto indices = mm->add_literal(migraphx::literal{is, ind_vec}); + auto updates = mm->add_parameter("update", us); + auto scatternd = + mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates); + mm->add_return({scatternd}); + + return p; + } +}; diff --git a/test/verify/test_sigmoid.cpp b/test/verify/test_sigmoid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afd1d5246e132a4e47e1621268f021667ed465fd --- /dev/null +++ b/test/verify/test_sigmoid.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sigmoid : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("sigmoid"), x); + return p; + } +}; diff --git a/test/verify/test_sign.cpp b/test/verify/test_sign.cpp new file mode 100644 index 0000000000000000000000000000000000000000..18c2f77e42dbcd3c398c57e621648428940b5954 --- /dev/null +++ b/test/verify/test_sign.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sign : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("sign"), param); + return p; + } +}; diff --git a/test/verify/test_sin.cpp b/test/verify/test_sin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b01755ddec2a1a7f73089d32567d13e4a49276b --- /dev/null +++ b/test/verify/test_sin.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sin : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {10}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("sin"), x); + return p; + } +}; diff --git a/test/verify/test_sinh.cpp b/test/verify/test_sinh.cpp new file mode 100755 index 0000000000000000000000000000000000000000..845d6aaa065584eb8a9a60d9c42e998e885ee4dc --- /dev/null +++ b/test/verify/test_sinh.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sinh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("sinh"), x); + return p; + } +}; diff --git a/test/verify/test_slice.cpp b/test/verify/test_slice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4534d1037ad4b1874b89b2fe9e9f9b7f59e32bb9 --- /dev/null +++ b/test/verify/test_slice.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_slice : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {2, 2, 4}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", {migraphx::shape::int32_type, {2, 2, 2}}); + auto slice0 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), x); + mm->add_instruction(migraphx::make_op("add"), y, slice0); + + return p; + } +}; diff --git a/test/verify/test_slice_reverse.cpp b/test/verify/test_slice_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..703ef62b25180fd20ad288edcee71307b0967005 --- /dev/null +++ b/test/verify/test_slice_reverse.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_slice_reverse : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {3, 5}}; + auto x = mm->add_parameter("x", s); + auto slice_out = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, -1}}}), + x); + mm->add_instruction(migraphx::make_op("reverse", {{"axes", {0}}}), slice_out); + + return p; + } +}; diff --git a/test/verify/test_slice_reverse_step.cpp b/test/verify/test_slice_reverse_step.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fabadb2b3e3c2fa39c605be3a42664513d2f4c68 --- /dev/null +++ b/test/verify/test_slice_reverse_step.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_slice_reverse_step : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {7, 5}}; + auto x = mm->add_parameter("x", s); + auto slice_out = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, -1}}}), + x); + auto step_out = + mm->add_instruction(migraphx::make_op("reverse", {{"axes", {0, 1}}}), slice_out); + mm->add_instruction(migraphx::make_op("step", {{"axes", {0, 1}}, {"steps", {2, 2}}}), + step_out); + return p; + } +}; diff --git a/test/verify/test_slice_sin.cpp b/test/verify/test_slice_sin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a98b6948911ebec6378b371851676e83617479f --- /dev/null +++ b/test/verify/test_slice_sin.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_slice_sin : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto l = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto t = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), l); + mm->add_instruction(migraphx::make_op("sin"), t); + + return p; + } +}; diff --git a/test/verify/test_slice_step_reverse.cpp b/test/verify/test_slice_step_reverse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db9e88fc359b9728864eaacf5bf1194771602070 --- /dev/null +++ b/test/verify/test_slice_step_reverse.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_slice_step_reverse : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int32_type, {7, 5}}; + auto x = mm->add_parameter("x", s); + auto slice_out = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, -1}}}), + x); + auto step_out = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 1}}, {"steps", {2, 2}}}), slice_out); + mm->add_instruction(migraphx::make_op("reverse", {{"axes", {0}}}), step_out); + + return p; + } +}; diff --git a/test/verify/test_softmax.cpp b/test/verify/test_softmax.cpp new file mode 100755 index 0000000000000000000000000000000000000000..3ea2d4f0dbdae2466b087aaa04b110783becdc40 --- /dev/null +++ b/test/verify/test_softmax.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_softmax : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{T, {512, 4, 1067, 6}}; + auto param = mm->add_parameter("0", s); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", Axis}}), param); + + return p; + } +}; + +template struct test_softmax<0, migraphx::shape::float_type>; +template struct test_softmax<2, migraphx::shape::float_type>; +template struct test_softmax<0, migraphx::shape::half_type>; +template struct test_softmax<1, migraphx::shape::half_type>; +template struct test_softmax<2, migraphx::shape::half_type>; +template struct test_softmax<3, migraphx::shape::half_type>; diff --git a/test/verify/test_softmax1.cpp b/test/verify/test_softmax1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b13e551fc01d41d21e728b4d2b0d6103a7097c4 --- /dev/null +++ b/test/verify/test_softmax1.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_softmax1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}}); + mm->add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), x); + return p; + } +}; diff --git a/test/verify/test_softmax2.cpp b/test/verify/test_softmax2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..01ee5aa5941f7cecf1a5ac161b0e571b926ef94b --- /dev/null +++ b/test/verify/test_softmax2.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_softmax2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = + mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}}); + mm->add_instruction(migraphx::make_op("softmax"), x); + return p; + } +}; diff --git a/test/verify/test_softmax3.cpp b/test/verify/test_softmax3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ed7c76731379f83825d3668d7648f82252d9cd6 --- /dev/null +++ b/test/verify/test_softmax3.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_softmax3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}}); + auto sx = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 3}}, {"starts", {1, 1}}, {"ends", {5, 4}}}), + x); + auto r = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), sx); + mm->add_return({r}); + return p; + } +}; diff --git a/test/verify/test_sqrt.cpp b/test/verify/test_sqrt.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cead715b934dc89bc838f8c19e94d84296e70608 --- /dev/null +++ b/test/verify/test_sqrt.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sqrt : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + auto param = mm->add_parameter("x", s); + auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); + mm->add_instruction(migraphx::make_op("sqrt"), param_abs); + return p; + } +}; diff --git a/test/verify/test_sqrt_half1.cpp b/test/verify/test_sqrt_half1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9624323db6a7293788106648271be9a266b8209f --- /dev/null +++ b/test/verify/test_sqrt_half1.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +// math op on half-precision float with odd size tensor can't fit half2 packing +struct test_sqrt_half1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {5}}; + auto param = mm->add_parameter("x", s); + auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); + mm->add_instruction(migraphx::make_op("sqrt"), param_abs); + return p; + } +}; diff --git a/test/verify/test_sqrt_half2.cpp b/test/verify/test_sqrt_half2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d2bd3b3b311f9482d6297f5121c66a65aa20146 --- /dev/null +++ b/test/verify/test_sqrt_half2.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +// math op on half-precision float with tensor size that's divisible by 2, +// but not divisible by 4 +struct test_sqrt_half2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {6}}; + auto param = mm->add_parameter("x", s); + auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); + mm->add_instruction(migraphx::make_op("sqrt"), param_abs); + return p; + } +}; diff --git a/test/verify/test_sqrt_half4.cpp b/test/verify/test_sqrt_half4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..913e2726019412e60f63eb3da18aee140d087550 --- /dev/null +++ b/test/verify/test_sqrt_half4.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +// math op on half-precision float with tensor size that fits into half4 packing +struct test_sqrt_half4 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::half_type, {8}}; + auto param = mm->add_parameter("x", s); + auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); + mm->add_instruction(migraphx::make_op("sqrt"), param_abs); + return p; + } +}; diff --git a/test/verify/test_step.cpp b/test/verify/test_step.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55c087fa0f294c30c26b5657c2c3e2a8d451a185 --- /dev/null +++ b/test/verify/test_step.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_step : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; + auto l0 = mm->add_parameter("x", s1); + auto r = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 2, 3}}, {"steps", {2, 2, 3}}}), l0); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_step_broadcast_transpose.cpp b/test/verify/test_step_broadcast_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..51b2b53e887ac55f2fcd366623f2e7ddcbe0c7ec --- /dev/null +++ b/test/verify/test_step_broadcast_transpose.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_step_broadcast_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 6}}; + auto l0 = mm->add_parameter("x", s1); + auto ml = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 6}}}), l0); + auto tl = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ml); + auto r = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_step_transpose.cpp b/test/verify/test_step_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f8da7ec86e9c979182e65159bca26dddbba7992d --- /dev/null +++ b/test/verify/test_step_transpose.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_step_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s1{migraphx::shape::float_type, {2, 1, 4, 6}}; + auto l0 = mm->add_parameter("x", s1); + auto tl = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); + auto r = mm->add_instruction( + migraphx::make_op("step", {{"axes", {0, 1, 2}}, {"steps", {2, 2, 3}}}), tl); + mm->add_return({r}); + + return p; + } +}; diff --git a/test/verify/test_sub.cpp b/test/verify/test_sub.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab9fec85397fc97a9284b548e74fd3e13dbca6d9 --- /dev/null +++ b/test/verify/test_sub.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sub : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto diff = mm->add_instruction(migraphx::make_op("sub"), x, y); + mm->add_instruction(migraphx::make_op("sub"), diff, z); + return p; + } +}; diff --git a/test/verify/test_sub2.cpp b/test/verify/test_sub2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d3a8986afc7b3611c08f85616ff6b54f18dd5ba --- /dev/null +++ b/test/verify/test_sub2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_sub2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape b{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", b); + auto zb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z); + auto diff = mm->add_instruction(migraphx::make_op("sub"), x, y); + mm->add_instruction(migraphx::make_op("sub"), diff, zb); + return p; + } +}; diff --git a/test/verify/test_sub_int.cpp b/test/verify/test_sub_int.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0929f4cb8f45de65148a120db85372d9a1994b38 --- /dev/null +++ b/test/verify/test_sub_int.cpp @@ -0,0 +1,21 @@ +#include "verify_program.hpp" +#include +#include +#include + +struct test_sub_int : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::int16_type, {4, 5}}); + auto y = mm->add_parameter("y", {migraphx::shape::int16_type, {2, 3, 4, 5}}); + auto xb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), x); + auto diff = mm->add_instruction(migraphx::make_op("sub"), y, xb); + mm->add_return({diff}); + return p; + } +}; diff --git a/test/verify/test_tan.cpp b/test/verify/test_tan.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7216caffd9d8892d850714ddc75731a1803049c3 --- /dev/null +++ b/test/verify/test_tan.cpp @@ -0,0 +1,18 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_tan : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {16}}; + auto x = mm->add_parameter("x", s); + mm->add_instruction(migraphx::make_op("tan"), x); + return p; + } +}; diff --git a/test/verify/test_tanh.cpp b/test/verify/test_tanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b445f43534ccc2dfa044af55cf35d5bb2c94e6a --- /dev/null +++ b/test/verify/test_tanh.cpp @@ -0,0 +1,17 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_tanh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_instruction(migraphx::make_op("tanh"), x); + return p; + } +}; diff --git a/test/verify/test_topk_0.cpp b/test/verify/test_topk_0.cpp new file mode 100644 index 0000000000000000000000000000000000000000..af6106801f8646a084ff30a4e928cddacb7a7c8d --- /dev/null +++ b/test/verify/test_topk_0.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_topk_0 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + mm->add_return({r0}); + + return p; + } +}; diff --git a/test/verify/test_topk_1.cpp b/test/verify/test_topk_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..367bc6348fa23c7becb8e11f250424d5d0e809df --- /dev/null +++ b/test/verify/test_topk_1.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_topk_1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", -2}, {"k", 3}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + return p; + } +}; diff --git a/test/verify/test_topk_2.cpp b/test/verify/test_topk_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eb55f30fb2ce021ff00a9ded4abbef9f4bcf7301 --- /dev/null +++ b/test/verify/test_topk_2.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_topk_2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 0}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + mm->add_return({r0}); + + return p; + } +}; diff --git a/test/verify/test_topk_3.cpp b/test/verify/test_topk_3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3800e002be816867cc10c2c021d3d1b27cf8e9c8 --- /dev/null +++ b/test/verify/test_topk_3.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_topk_3 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", -2}, {"k", 3}, {"largest", 0}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + return p; + } +}; diff --git a/test/verify/test_trans_abs.cpp b/test/verify/test_trans_abs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea2c27c4626b2b9e9a2127125689e7960c74378f --- /dev/null +++ b/test/verify/test_trans_abs.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_trans_abs : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x); + auto absx = mm->add_instruction(migraphx::make_op("abs"), tx); + auto r = mm->add_instruction(migraphx::make_op("add"), absx, absx); + mm->add_instruction(migraphx::make_op("contiguous"), r); + + return p; + } +}; diff --git a/test/verify/test_trans_ret.cpp b/test/verify/test_trans_ret.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd43a6a2c4d267df10b3aad7f18a16d09910b5c0 --- /dev/null +++ b/test/verify/test_trans_ret.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_trans_ret : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x); + mm->add_return({tx}); + + return p; + } +}; diff --git a/test/verify/test_trans_tanh.cpp b/test/verify/test_trans_tanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f44c672ee56b3040dabdb50323eda3c37e7821f --- /dev/null +++ b/test/verify/test_trans_tanh.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_trans_tanh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x); + auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx); + auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx); + mm->add_instruction(migraphx::make_op("contiguous"), r); + + return p; + } +}; diff --git a/test/verify/test_trans_tanh1.cpp b/test/verify/test_trans_tanh1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88927d6936a90b4f4a2995488a0f6eb872988ef3 --- /dev/null +++ b/test/verify/test_trans_tanh1.cpp @@ -0,0 +1,22 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_trans_tanh1 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto tx = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), x); + auto tanhx = mm->add_instruction(migraphx::make_op("tanh"), tx); + auto r = mm->add_instruction(migraphx::make_op("add"), tanhx, tanhx); + mm->add_return({tx, r}); + + return p; + } +}; diff --git a/test/verify/test_transpose.cpp b/test/verify/test_transpose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2be74aab95f8fa50fc76c842706364423bf4b824 --- /dev/null +++ b/test/verify/test_transpose.cpp @@ -0,0 +1,20 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_transpose : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {4, 3, 4, 4}}; + auto x = mm->add_parameter("x", s); + std::vector perm = {0, 2, 3, 1}; + auto l = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), x); + mm->add_instruction(migraphx::make_op("contiguous"), l); + return p; + } +}; diff --git a/test/verify/test_triadd.cpp b/test/verify/test_triadd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..848c813f304ffa5696ddce34bc62929987ad05af --- /dev/null +++ b/test/verify/test_triadd.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_triadd : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("add"), sum, z); + return p; + } +}; diff --git a/test/verify/test_triadd2.cpp b/test/verify/test_triadd2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24fdd518ce07c9a14e51c3b863607c76c54545cb --- /dev/null +++ b/test/verify/test_triadd2.cpp @@ -0,0 +1,24 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_triadd2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::shape b{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", b); + auto zb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), z); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + mm->add_instruction(migraphx::make_op("add"), sum, zb); + return p; + } +}; diff --git a/test/verify/test_triadd_broadcast.cpp b/test/verify/test_triadd_broadcast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a28bfc73a737ee2794dd5136b088c2e749dc1286 --- /dev/null +++ b/test/verify/test_triadd_broadcast.cpp @@ -0,0 +1,25 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +struct test_triadd_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3}}; + auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 2, 3}}); + auto y = mm->add_parameter("y", {migraphx::shape::float_type, {2, 2}}); + auto z = mm->add_parameter("z", {migraphx::shape::float_type, {2, 2, 3}}); + auto by = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", x->get_shape().lens()}}), y); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, by); + mm->add_instruction(migraphx::make_op("add"), sum, z); + return p; + } +}; diff --git a/test/verify/test_triadd_relu.cpp b/test/verify/test_triadd_relu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..66525a3e41bd7774991d9f6dbd2a7b4126e66a09 --- /dev/null +++ b/test/verify/test_triadd_relu.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_triadd_relu : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z); + mm->add_instruction(migraphx::make_op("relu"), triadd); + return p; + } +}; diff --git a/test/verify/test_triadd_sigmoid.cpp b/test/verify/test_triadd_sigmoid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c8279b343ddfd18fc8adb63f641c29df4ff864a --- /dev/null +++ b/test/verify/test_triadd_sigmoid.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_triadd_sigmoid : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z); + mm->add_instruction(migraphx::make_op("sigmoid"), triadd); + return p; + } +}; diff --git a/test/verify/test_triadd_tanh.cpp b/test/verify/test_triadd_tanh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf4bd50cac330a0396d759f8956cad9c4af9cd54 --- /dev/null +++ b/test/verify/test_triadd_tanh.cpp @@ -0,0 +1,21 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_triadd_tanh : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto sum = mm->add_instruction(migraphx::make_op("add"), x, y); + auto triadd = mm->add_instruction(migraphx::make_op("add"), sum, z); + mm->add_instruction(migraphx::make_op("tanh"), triadd); + return p; + } +}; diff --git a/test/verify/test_var_sl_gru_bidirct.cpp b/test/verify/test_var_sl_gru_bidirct.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95da5321f05d15372038adecb0c7ca371bc9dac3 --- /dev/null +++ b/test/verify/test_var_sl_gru_bidirct.cpp @@ -0,0 +1,62 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_var_sl_gru_bidirct : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 3; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + std::vector sl_data{2, 1, 3}; + auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); + + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({hs, lho}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_var_sl_gru_forward.cpp b/test/verify/test_var_sl_gru_forward.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23d389fcafca9d849e3baebd67655d6a3b5a6354 --- /dev/null +++ b/test/verify/test_var_sl_gru_forward.cpp @@ -0,0 +1,62 @@ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_var_sl_gru_forward : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 3; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 3 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + std::vector sl_data{3, 2, 1}; + auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); + + auto hs = mm->add_instruction( + migraphx::make_op( + "gru", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + sql, + ih); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_return({lho, hs}); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_where.cpp b/test/verify/test_where.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6af7234fa8b25d962bc5230e55f06fdb7b07390 --- /dev/null +++ b/test/verify/test_where.cpp @@ -0,0 +1,23 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_where : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}}; + migraphx::shape sx{migraphx::shape::float_type, {1, 3, 4, 5}}; + auto b = mm->add_parameter("b", sb); + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sx); + auto r = mm->add_instruction(migraphx::make_op("where"), b, x, y); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_where2.cpp b/test/verify/test_where2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..774d3218bf741f79edbc2e834b386fe584513ae2 --- /dev/null +++ b/test/verify/test_where2.cpp @@ -0,0 +1,27 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_where2 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape sb{migraphx::shape::bool_type, {1, 3, 4, 5}}; + migraphx::shape sx{migraphx::shape::float_type, {1}}; + auto b = mm->add_parameter("b", sb); + auto x = mm->add_parameter("x", sx); + auto y = mm->add_parameter("y", sx); + auto mbx = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), x); + auto mby = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 3, 4, 5}}}), y); + auto r = mm->add_instruction(migraphx::make_op("where"), b, mbx, mby); + mm->add_return({r}); + return p; + }; +}; diff --git a/test/verify/test_xor.cpp b/test/verify/test_xor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd84233a991c64f9c8378b604c4c54679a0118f6 --- /dev/null +++ b/test/verify/test_xor.cpp @@ -0,0 +1,19 @@ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_xor : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::bool_type, {3}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + mm->add_instruction(migraphx::make_op("logical_xor"), x, y); + return p; + } +}; diff --git a/test/verify/verify_program.cpp b/test/verify/verify_program.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f77c9b467c199cbdab28e09f1daffd408f39376c --- /dev/null +++ b/test/verify/verify_program.cpp @@ -0,0 +1,10 @@ +#include "verify_program.hpp" + +std::vector& get_programs_vector() +{ + static std::vector result; // NOLINT + return result; +} + +void register_program_info(const program_info& pi) { get_programs_vector().push_back(pi); } +const std::vector& get_programs() { return get_programs_vector(); } diff --git a/test/verify/verify_program.hpp b/test/verify/verify_program.hpp new file mode 100644 index 0000000000000000000000000000000000000000..341700e5a56f1efd66f3cf3a0d2ab4df12b01ab9 --- /dev/null +++ b/test/verify/verify_program.hpp @@ -0,0 +1,41 @@ +#ifndef MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP +#define MIGRAPHX_GUARD_AUTO_REGISTER_VERIFY_PROGRAM_HPP + +#include +#include +#include + +struct program_info +{ + std::string name; + std::string section; + std::function get_program; +}; + +void register_program_info(const program_info& pi); +const std::vector& get_programs(); + +struct register_verify_program_action +{ + template + static void apply() + { + T x; + program_info pi; + pi.name = migraphx::get_type_name(); + pi.section = x.section(); + pi.get_program = [x] { return x.create_program(); }; + register_program_info(pi); + } +}; + +template +using auto_register_verify_program = migraphx::auto_register; + +template +struct verify_program : auto_register_verify_program +{ + std::string section() const { return "general"; }; +}; + +#endif diff --git a/tools/api.py b/tools/api.py new file mode 100755 index 0000000000000000000000000000000000000000..aba9da87cd104c44736207c9ee9bd79b53a505a9 --- /dev/null +++ b/tools/api.py @@ -0,0 +1,1212 @@ +import string, sys, re, runpy +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +type_map: Dict[str, Callable[['Parameter'], None]] = {} +cpp_type_map: Dict[str, str] = {} +functions: List['Function'] = [] +cpp_classes: List['CPPClass'] = [] +error_type = '' +success_type = '' +try_wrap = '' + +c_header_preamble: List[str] = [] +c_api_body_preamble: List[str] = [] +cpp_header_preamble: List[str] = [] + + +def bad_param_error(msg: str): + return 'throw std::runtime_error("{}")'.format(msg) + + +class Template(string.Template): + idpattern = '[_a-zA-Z0-9@]+' + + +class Type: + def __init__(self, name: str) -> None: + self.name = name.strip() + + def is_pointer(self) -> bool: + return self.name.endswith('*') + + def is_reference(self) -> bool: + return self.name.endswith('&') + + def is_const(self) -> bool: + return self.name.startswith('const ') + + def is_variadic(self): + return self.name.startswith('...') + + def add_pointer(self) -> 'Type': + return Type(self.name + '*') + + def add_reference(self): + return Type(self.name + '&') + + def add_const(self) -> 'Type': + return Type('const ' + self.name) + + def inner_type(self) -> Optional['Type']: + i = self.name.find('<') + j = self.name.rfind('>') + if i > 0 and j > 0: + return Type(self.name[i + 1:j]) + else: + return None + + def remove_generic(self) -> 'Type': + i = self.name.find('<') + j = self.name.rfind('>') + if i > 0 and j > 0: + return Type(self.name[0:i] + self.name[j + 1:]) + else: + return self + + def remove_pointer(self) -> 'Type': + if self.is_pointer(): + return Type(self.name[0:-1]) + return self + + def remove_reference(self) -> 'Type': + if self.is_reference(): + return Type(self.name[0:-1]) + return self + + def remove_const(self) -> 'Type': + if self.is_const(): + return Type(self.name[6:]) + return self + + def basic(self) -> 'Type': + return self.remove_pointer().remove_const().remove_reference() + + def decay(self) -> 'Type': + t = self.remove_reference() + if t.is_pointer(): + return t + else: + return t.remove_const() + + def const_compatible(self, t: 'Type'): + if t.is_const(): + return self.add_const() + return self + + def str(self) -> str: + return self.name + + +header_function = Template(''' +${error_type} ${name}(${params}); +''') + +function_pointer_typedef = Template(''' +typedef ${error_type} (*${fname})(${params}); +''') + +c_api_impl = Template(''' +extern "C" ${error_type} ${name}(${params}) +{ + ${va_start}auto api_error_result = ${try_wrap}([&] { + ${body}; + }); + ${va_end}return api_error_result; +} +''') + + +class CFunction: + def __init__(self, name: str) -> None: + self.name = name + self.params: List[str] = [] + self.body: List[str] = [] + self.va_start: List[str] = [] + self.va_end: List[str] = [] + + def add_param(self, type: str, pname: str) -> None: + self.params.append('{} {}'.format(type, pname)) + + def add_statement(self, stmt: str) -> None: + self.body.append(stmt) + + def add_vlist(self, name: str) -> None: + last_param = self.params[-1].split()[-1] + self.va_start = [ + 'va_list {};'.format(name), + 'va_start({}, {});'.format(name, last_param) + ] + self.va_end = ['va_end({});'.format(name)] + self.add_param('...', '') + + def substitute(self, form: Template, **kwargs) -> str: + return form.substitute(error_type=error_type, + try_wrap=try_wrap, + name=self.name, + params=', '.join(self.params), + body=";\n ".join(self.body), + va_start="\n ".join(self.va_start), + va_end="\n ".join(self.va_end), + **kwargs) + + def generate_header(self) -> str: + return self.substitute(header_function) + + def generate_function_pointer(self, name: Optional[str] = None) -> str: + return self.substitute(function_pointer_typedef, + fname=name or self.name) + + def generate_body(self) -> str: + return self.substitute(c_api_impl) + + +class BadParam: + def __init__(self, cond: str, msg: str) -> None: + self.cond = cond + self.msg = msg + + +class Parameter: + def __init__(self, + name: str, + type: str, + optional: bool = False, + returns: bool = False, + virtual: bool = False, + this: bool = False) -> None: + self.name = name + self.type = Type(type) + self.optional = optional + self.cparams: List[Tuple[str, str]] = [] + self.size_cparam = -1 + self.size_name = '' + self.read = '${name}' + self.write = ['*${name} = ${result}'] + self.cpp_read = '${name}' + self.cpp_write = '${name}' + self.returns = returns + self.virtual = virtual + self.this = this + self.bad_param_check: Optional[BadParam] = None + self.virtual_read: Optional[List[str]] = None + self.virtual_write: Optional[str] = None + + def get_name(self, prefix: Optional[str] = None) -> str: + if prefix: + return prefix + self.name + else: + return self.name + + def get_cpp_type(self) -> str: + if self.type.str() in cpp_type_map: + return cpp_type_map[self.type.basic().str()] + elif self.type.basic().str() in cpp_type_map: + return cpp_type_map[self.type.basic().str()] + elif self.returns: + return self.type.decay().str() + else: + return self.type.str() + + def substitute(self, + s: str, + prefix: Optional[str] = None, + result: Optional[str] = None) -> str: + ctype = None + if len(self.cparams) > 0: + ctype = Type(self.cparams[0][0]).basic().str() + return Template(s).safe_substitute(name=self.get_name(prefix), + type=self.type.str(), + ctype=ctype or '', + cpptype=self.get_cpp_type(), + size=self.size_name, + result=result or '') + + def add_param(self, t: Union[str, Type], + name: Optional[str] = None) -> None: + if not isinstance(t, str): + t = t.str() + self.cparams.append((t, name or self.name)) + + def add_size_param(self, name: Optional[str] = None) -> None: + self.size_cparam = len(self.cparams) + self.size_name = name or self.name + '_size' + if self.returns: + self.add_param('size_t *', self.size_name) + else: + self.add_param('size_t', self.size_name) + + def bad_param(self, cond: str, msg: str) -> None: + self.bad_param_check = BadParam(cond, msg) + + def remove_size_param(self, name): + p = None + if self.size_cparam >= 0: + p = self.cparams[self.size_cparam] + del self.cparams[self.size_cparam] + self.size_name = name + return p + + def update(self) -> None: + t = self.type.basic().str() + g = self.type.remove_generic().basic().str() + if t in type_map: + type_map[t](self) + elif g in type_map: + type_map[g](self) + else: + if self.returns: + self.add_param(self.type.remove_reference().add_pointer()) + else: + self.add_param(self.type.remove_reference()) + if isinstance(self.write, str): + raise ValueError("Error for {}: write cannot be a string".format( + self.type.str())) + + def virtual_arg(self, prefix: Optional[str] = None) -> List[str]: + read = self.virtual_read + if not read and len(self.write) >= len(self.cparams): + read = [ + Template(w.partition('=')[2]).safe_substitute(result='${name}') + for w in self.write + ] + if not read: + raise ValueError("No virtual_read parameter provided for: " + + self.type.str()) + if isinstance(read, str): + raise ValueError( + "Error for {}: virtual_read cannot be a string".format( + self.type.str())) + return [self.substitute(r, prefix=prefix) for r in read] + + def virtual_param(self, prefix: Optional[str] = None) -> str: + return self.substitute('${type} ${name}', prefix=prefix) + + def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]: + return [ + '&{prefix}{n}'.format(prefix=prefix or '', n=n) + for t, n in self.cparams + ] + + def virtual_output_declarations(self, + prefix: Optional[str] = None) -> List[str]: + return [ + 'std::remove_pointer_t<{type}> {prefix}{n};'.format( + type=Type(t).str(), prefix=prefix or '', n=n) + for t, n in self.cparams + ] + + def virtual_output(self, prefix: Optional[str] = None) -> str: + write = self.virtual_write + if not write: + if '*' in self.read or '->' in self.read: + write = Template(self.read).safe_substitute(name='(&${name})') + else: + write = self.read + return self.substitute(write, prefix=prefix) + + def cpp_param(self, prefix: Optional[str] = None) -> str: + return self.substitute('${cpptype} ${name}', prefix=prefix) + + def cpp_arg(self, prefix: Optional[str] = None) -> str: + return self.substitute(self.cpp_read, prefix=prefix) + + def cpp_output_args(self, prefix: Optional[str] = None) -> List[str]: + return [ + '&{prefix}{n}'.format(prefix=prefix, n=n) for t, n in self.cparams + ] + + def output_declarations(self, prefix: Optional[str] = None) -> List[str]: + return [ + '{type} {prefix}{n};'.format(type=Type(t).remove_pointer().str(), + prefix=prefix, + n=n) for t, n in self.cparams + ] + + def output_args(self, prefix=None): + return [ + '&{prefix}{n};'.format(prefix=prefix, n=n) for t, n in self.cparams + ] + + def cpp_output(self, prefix: Optional[str] = None) -> str: + return self.substitute(self.cpp_write, prefix=prefix) + + def input(self, prefix: Optional[str] = None) -> str: + return '(' + self.substitute(self.read, prefix=prefix) + ')' + + def outputs(self, result: Optional[str] = None) -> List[str]: + return [self.substitute(w, result=result) for w in self.write] + + def add_to_cfunction(self, cfunction: CFunction) -> None: + for t, name in self.cparams: + if t.startswith('...'): + cfunction.add_vlist(name) + else: + cfunction.add_param(self.substitute(t), self.substitute(name)) + if self.bad_param_check: + msg = 'Bad parameter {name}: {msg}'.format( + name=self.name, msg=self.bad_param_check.msg) + cfunction.add_statement('if ({cond}) {body}'.format( + cond=self.substitute(self.bad_param_check.cond), + body=bad_param_error(msg))) + + +def template_var(s: str) -> str: + return '${' + s + '}' + + +def to_template_vars(params: List[Union[Any, Parameter]]) -> str: + return ', '.join([template_var(p.name) for p in params]) + + +class Function: + def __init__(self, + name: str, + params: Optional[List[Parameter]] = None, + shared_size: bool = False, + returns: Optional[str] = None, + invoke: Optional[str] = None, + fname: Optional[str] = None, + return_name: Optional[str] = None, + virtual: bool = False, + **kwargs) -> None: + self.name = name + self.params = params or [] + self.shared_size = False + self.cfunction: Optional[CFunction] = None + self.fname = fname + self.invoke = invoke or '${__fname__}($@)' + self.return_name = return_name or 'out' + self.returns = Parameter(self.return_name, returns, + returns=True) if returns else None + for p in self.params: + p.virtual = virtual + if self.returns: + self.returns.virtual = virtual + + def share_params(self) -> None: + if self.shared_size == True: + size_param_name = 'size' + size_type = Type('size_t') + for param in self.params: + p = param.remove_size_param(size_param_name) + if p: + size_type = Type(p[0]) + self.params.append(Parameter(size_param_name, size_type.str())) + + def update(self) -> None: + self.share_params() + for param in self.params: + param.update() + if self.returns: + self.returns.update() + self.create_cfunction() + + def inputs(self) -> str: + return ', '.join([p.input() for p in self.params]) + + # TODO: Shoule we remove Optional? + def input_map(self) -> Dict[str, Optional[str]]: + m: Dict[str, Optional[str]] = {} + for p in self.params: + m[p.name] = p.input() + m['return'] = self.return_name + m['@'] = self.inputs() + m['__fname__'] = self.fname + return m + + def get_invoke(self) -> str: + return Template(self.invoke).safe_substitute(self.input_map()) + + def write_to_tmp_var(self) -> bool: + if not self.returns: + return False + return len(self.returns.write) > 1 or self.returns.write[0].count( + '${result}') > 1 + + def get_cfunction(self) -> CFunction: + if self.cfunction: + return self.cfunction + raise Exception( + "self.cfunction is None: self.update() needs to be called.") + + def create_cfunction(self) -> None: + self.cfunction = CFunction(self.name) + # Add the return as a parameter + if self.returns: + self.returns.add_to_cfunction(self.cfunction) + # Add the input parameters + for param in self.params: + param.add_to_cfunction(self.cfunction) + f: Optional[str] = self.get_invoke() + # Write the assignments + assigns = [] + if self.returns: + result = f + if self.write_to_tmp_var() and f: + f = 'auto&& api_result = ' + f + result = 'api_result' + else: + f = None + assigns = self.returns.outputs(result) + if f: + self.cfunction.add_statement(f) + for assign in assigns: + self.cfunction.add_statement(assign) + + +cpp_class_template = Template(''' + +struct ${name} : handle_base<${ctype}, decltype(&${destroy}), ${destroy}> +{ + ${name}(${ctype} p, bool own = true) + : m_handle(nullptr) + { + this->set_handle(p, own); + } + ${constructors} + + ${methods} +}; +''') + +cpp_class_method_template = Template(''' + ${return_type} ${name}(${params}) const + { + ${outputs} + this->call_handle(${args}); + return ${result}; + } +''') + +cpp_class_void_method_template = Template(''' + void ${name}(${params}) const + { + this->call_handle(${args}); + } +''') + +cpp_class_constructor_template = Template(''' + ${name}(${params}) + : m_handle(nullptr) + { + m_handle = this->make_handle(${args}); + } +''') + + +class CPPMember: + def __init__(self, + name: str, + function: Function, + prefix: str, + method: bool = True) -> None: + self.name = name + self.function = function + self.prefix = prefix + self.method = method + + def get_function_params(self) -> List[Union[Any, Parameter]]: + if self.method: + return self.function.params[1:] + else: + return self.function.params + + def get_args(self) -> str: + output_args = [] + if self.function.returns: + output_args = self.function.returns.cpp_output_args(self.prefix) + if not self.function.cfunction: + raise Exception('self.function.update() must be called') + return ', '.join( + ['&{}'.format(self.function.cfunction.name)] + output_args + + [p.cpp_arg(self.prefix) for p in self.get_function_params()]) + + def get_params(self) -> str: + return ', '.join( + [p.cpp_param(self.prefix) for p in self.get_function_params()]) + + def get_return_declarations(self) -> str: + if self.function.returns: + return '\n '.join([ + d + for d in self.function.returns.output_declarations(self.prefix) + ]) + else: + return '' + + def get_result(self): + return self.function.returns.input(self.prefix) + + def generate_method(self) -> str: + if not self.function.cfunction: + raise Exception('self.function.update() must be called') + if self.function.returns: + return_type = self.function.returns.get_cpp_type() + return cpp_class_method_template.safe_substitute( + return_type=return_type, + name=self.name, + cfunction=self.function.cfunction.name, + result=self.function.returns.cpp_output(self.prefix), + params=self.get_params(), + outputs=self.get_return_declarations(), + args=self.get_args(), + success=success_type) + else: + return cpp_class_void_method_template.safe_substitute( + name=self.name, + cfunction=self.function.cfunction.name, + params=self.get_params(), + args=self.get_args(), + success=success_type) + + def generate_constructor(self, name: str) -> str: + if not self.function.cfunction: + raise Exception('self.function.update() must be called') + return cpp_class_constructor_template.safe_substitute( + name=name, + cfunction=self.function.cfunction.name, + params=self.get_params(), + args=self.get_args(), + success=success_type) + + +class CPPClass: + def __init__(self, name: str, ctype: str) -> None: + self.name = name + self.ctype = ctype + self.constructors: List[CPPMember] = [] + self.methods: List[CPPMember] = [] + self.prefix = 'p' + + def add_method(self, name: str, f: Function) -> None: + self.methods.append(CPPMember(name, f, self.prefix, method=True)) + + def add_constructor(self, name: str, f: Function) -> None: + self.constructors.append(CPPMember(name, f, self.prefix, method=True)) + + def generate_methods(self) -> str: + return '\n '.join([m.generate_method() for m in self.methods]) + + def generate_constructors(self) -> str: + return '\n '.join( + [m.generate_constructor(self.name) for m in self.constructors]) + + def substitute(self, s: Union[string.Template, str], **kwargs) -> str: + t = string.Template(s) if isinstance(s, str) else s + destroy = self.ctype + '_destroy' + return t.safe_substitute(name=self.name, + ctype=self.ctype, + destroy=destroy, + **kwargs) + + def generate(self) -> str: + return self.substitute( + cpp_class_template, + constructors=self.substitute(self.generate_constructors()), + methods=self.substitute(self.generate_methods())) + + +def params(virtual: Optional[Dict[str, str]] = None, + **kwargs) -> List[Parameter]: + result = [] + v: Dict[str, str] = virtual or {} + for name in v: + result.append(Parameter(name, v[name])) + for name in kwargs: + result.append(Parameter(name, kwargs[name])) + return result + + +gparams = params + + +def add_function(name: str, *args, **kwargs) -> Function: + f = Function(name, *args, **kwargs) + functions.append(f) + return f + + +def once(f: Callable) -> Any: + @wraps(f) + def decorated(*args, **kwargs): + if not decorated.has_run: + decorated.has_run = True + return f(*args, **kwargs) + + d: Any = decorated + d.has_run = False + return d + + +@once +def process_functions() -> None: + for f in functions: + f.update() + + +def generate_lines(p: List[str]) -> str: + return '\n'.join(p) + + +def generate_c_header() -> str: + process_functions() + return generate_lines( + c_header_preamble + + [f.get_cfunction().generate_header() for f in functions]) + + +def generate_c_api_body() -> str: + process_functions() + return generate_lines( + c_api_body_preamble + + [f.get_cfunction().generate_body() for f in functions]) + + +def generate_cpp_header() -> str: + process_functions() + return generate_lines(cpp_header_preamble + + [c.generate() for c in cpp_classes]) + + +def cwrap(name: str) -> Callable: + def with_cwrap(f): + type_map[name] = f + + @wraps(f) + def decorated(*args, **kwargs): + return f(*args, **kwargs) + + return decorated + + return with_cwrap + + +handle_typedef = Template(''' +typedef struct ${ctype} * ${ctype}_t; +typedef const struct ${ctype} * const_${ctype}_t; +''') + +handle_definition = Template(''' +extern "C" struct ${ctype}; +struct ${ctype} { + template + ${ctype}(Ts&&... xs) + : object(std::forward(xs)...) // NOLINT(readability-redundant-member-init) + {} + ${cpptype} object; +}; +''') + +handle_preamble = ''' +template> +Target* object_cast(U* x) +{ + return reinterpret_cast(x); +} +template> +const Target* object_cast(const U* x) +{ + return reinterpret_cast(x); +} + +template> +Target* allocate(Ts&&... xs) +{ + return new Target(std::forward(xs)...); // NOLINT +} + +template +void destroy(T* x) +{ + delete x; // NOLINT +} +// TODO: Move to interface preamble +template +struct manage_generic_ptr +{ + manage_generic_ptr() = default; + + manage_generic_ptr(std::nullptr_t) + { + } + + manage_generic_ptr(void* pdata, C pcopier, D pdeleter) + : data(nullptr), copier(pcopier), deleter(pdeleter) + { + copier(&data, pdata); + } + + manage_generic_ptr(const manage_generic_ptr& rhs) + : data(nullptr), copier(rhs.copier), deleter(rhs.deleter) + { + if(copier) + copier(&data, rhs.data); + } + + manage_generic_ptr(manage_generic_ptr&& other) noexcept + : data(other.data), copier(other.copier), deleter(other.deleter) + { + other.data = nullptr; + other.copier = nullptr; + other.deleter = nullptr; + } + + manage_generic_ptr& operator=(manage_generic_ptr rhs) + { + std::swap(data, rhs.data); + std::swap(copier, rhs.copier); + std::swap(deleter, rhs.deleter); + return *this; + } + + ~manage_generic_ptr() + { + if(data != nullptr) + deleter(data); + } + + void* data = nullptr; + C copier = nullptr; + D deleter = nullptr; +}; +''' + +cpp_handle_preamble = ''' +template +struct handle_base +{ + + template + void make_handle(F f, Ts&&... xs) + { + T* result = nullptr; + auto e = F(&result, std::forward(xs)...); + if (e != ${success}) + throw std::runtime_error("Failed to call function"); + set_handle(result); + } + + template + void call_handle(F f, Ts&&... xs) + { + auto e = F(this->get_handle_ptr(), std::forward(xs)...); + if (e != ${success}) + throw std::runtime_error("Failed to call function"); + } + + const std::shared_ptr& get_handle() const + { + return m_handle; + } + + T* get_handle_ptr() const + { + assert(m_handle != nullptr); + return get_handle().get(); + } + + void set_handle(T* ptr, bool own = true) + { + if (own) + m_handle = std::shared_ptr{ptr, deleter}; + else + m_handle = std::shared_ptr{ptr, [](T*) {}}; + } + +protected: + std::shared_ptr m_handle; +}; + +''' + + +@once +def add_handle_preamble() -> None: + c_api_body_preamble.append(handle_preamble) + cpp_header_preamble.append( + string.Template(cpp_handle_preamble).substitute(success=success_type)) + + +def add_handle(name: str, + ctype: str, + cpptype: str, + destroy: Optional[str] = None, + ref=False, + skip_def=False) -> None: + opaque_type = ctype + '_t' + const_opaque_type = 'const_' + opaque_type + + def handle_wrap(p: Parameter): + t = Type(opaque_type) + if p.type.is_const(): + t = Type('const_' + opaque_type) + # p.read = 'object_cast<${ctype}>(&(${name}))' + if p.virtual: + p.add_param(t) + elif p.returns: + p.add_param(t.add_pointer()) + else: + p.add_param(t) + p.bad_param('${name} == nullptr', 'Null pointer') + if p.type.is_reference(): + p.virtual_read = ['object_cast<${ctype}>(&(${name}))'] + p.cpp_write = '${cpptype}(${name}, false)' + p.write = ['*${name} = object_cast<${ctype}>(&(${result}))'] + elif p.type.is_pointer(): + p.virtual_read = ['object_cast<${ctype}>(${result})'] + p.cpp_write = '${cpptype}(${name}, false)' + p.write = ['*${name} = object_cast<${ctype}>(${result})'] + else: + p.virtual_read = ['object_cast<${ctype}>(&(${name}))'] + p.cpp_write = '${cpptype}(${name})' + p.write = ['*${name} = allocate<${ctype}>(${result})'] + if skip_def: + p.read = '*${name}' + else: + p.read = '${name}->object' + p.cpp_read = '${name}.get_handle_ptr()' + + type_map[cpptype] = handle_wrap + if not ref: + add_function(destroy or ctype + '_' + 'destroy', + params({name: opaque_type}), + fname='destroy') + add_function(ctype + '_' + 'assign_to', + params(output=opaque_type, input=const_opaque_type), + invoke='*output = *input') + add_handle_preamble() + c_header_preamble.append(handle_typedef.substitute(locals())) + if not skip_def: + c_api_body_preamble.append(handle_definition.substitute(locals())) + + +@cwrap('std::vector') +def vector_c_wrap(p: Parameter) -> None: + inner = p.type.inner_type() + # Not a generic type + if not inner: + return + t = inner.add_pointer() + if p.type.is_reference(): + if p.type.is_const(): + t = t.add_const() + if p.returns: + if p.type.is_reference(): + p.add_param(t.add_pointer()) + p.add_size_param() + p.bad_param('${name} == nullptr or ${size} == nullptr', + 'Null pointer') + else: + p.add_param(t) + p.bad_param('${name} == nullptr', 'Null pointer') + else: + p.add_param(t) + p.add_size_param() + p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer') + + p.read = '${type}(${name}, ${name}+${size})' + p.cpp_write = '${type}(${name}, ${name}+${size})' + p.virtual_read = ['${name}.data()', '${name}.size()'] + if p.type.is_reference(): + p.write = [ + '*${name} = ${result}.data()', '*${size} = ${result}.size()' + ] + else: + p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})'] + + +@cwrap('std::string') +def string_c_wrap(p: Parameter) -> None: + t = Type('char*') + if p.returns: + if p.type.is_reference(): + p.add_param(t.add_pointer()) + p.bad_param('${name} == nullptr', 'Null pointer') + else: + p.add_param(t) + p.add_param('size_t', p.name + '_size') + p.bad_param('${name} == nullptr', 'Null pointer') + else: + p.add_param(t) + p.bad_param('${name} == nullptr', 'Null pointer') + + p.read = '${type}(${name})' + p.cpp_write = '${type}(${name})' + p.virtual_read = ['${name}.c_str()'] + if p.type.is_reference(): + p.write = ['*${name} = ${result}.c_str()'] + else: + p.write = [ + 'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});' + '*it = \'\\0\'' + ] + + +class Handle: + def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None: + self.name = name + self.ctype = ctype + self.cpptype = cpptype + self.opaque_type = self.ctype + '_t' + self.cpp_class = CPPClass(name, ctype) + add_handle(name, ctype, cpptype, **kwargs) + cpp_type_map[cpptype] = name + + def cname(self, name: str) -> str: + return self.ctype + '_' + name + + def substitute(self, s: str, **kwargs) -> str: + return Template(s).safe_substitute(name=self.name, + ctype=self.ctype, + cpptype=self.cpptype, + opaque_type=self.opaque_type, + **kwargs) + + def constructor(self, + name: str, + params: Optional[List[Parameter]] = None, + fname: Optional[str] = None, + invoke: Optional[str] = None, + **kwargs) -> 'Handle': + create = self.substitute('allocate<${cpptype}>($@)') + if fname: + create = self.substitute('allocate<${cpptype}>(${fname}($@))', + fname=fname) + + f = add_function(self.cname(name), + params=params, + invoke=invoke or create, + returns=self.cpptype + '*', + return_name=self.name, + **kwargs) + self.cpp_class.add_constructor(name, f) + return self + + def method(self, + name: str, + params: Optional[List[Parameter]] = None, + fname: Optional[str] = None, + invoke: Optional[str] = None, + cpp_name: Optional[str] = None, + const: Optional[bool] = None, + **kwargs) -> 'Handle': + cpptype = self.cpptype + if const: + cpptype = Type(cpptype).add_const().str() + p = Parameter(self.name, cpptype) + args = to_template_vars(params or []) + f = add_function(self.cname(name), + params=[p] + (params or []), + invoke=invoke + or self.substitute('${var}.${fname}(${args})', + var=template_var(self.name), + fname=fname or name, + args=args), + **kwargs) + self.cpp_class.add_method(cpp_name or name, f) + return self + + def function(self, name, params=None, **kwargs): + add_function(self.cname(name), params=params, **kwargs) + return self + + def add_cpp_class(self) -> None: + cpp_classes.append(self.cpp_class) + + +interface_handle_definition = Template(''' +extern "C" struct ${ctype}; +struct ${ctype} { + template + ${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs) + : object_ptr(p, c, d), xobject(std::forward(xs)...) + {} + manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr; + ${cpptype} xobject; + ${functions} +}; +''') + +c_api_virtual_impl = Template(''' +${return_type} ${name}(${params}) const +{ + ${output_decls} + if (${fname} == nullptr) + throw std::runtime_error("${name} function is missing."); + auto api_error_result = ${fname}(${args}); + if (api_error_result != ${success}) + throw std::runtime_error("Error in ${name}."); + return ${output}; +} +''') + + +def generate_virtual_impl(f: Function, fname: str) -> str: + success = success_type + name = f.name + return_type = 'void' + output_decls = '' + output = '' + largs = [] + lparams = [] + if f.returns: + return_type = f.returns.type.str() + output_decls = '\n'.join(f.returns.virtual_output_declarations()) + largs += f.returns.virtual_output_args() + output = f.returns.virtual_output() + largs += [arg for p in f.params for arg in p.virtual_arg()] + lparams += [p.virtual_param() for p in f.params if not p.this] + args = ', '.join(largs) + params = ', '.join(lparams) + return c_api_virtual_impl.substitute(locals()) + + +class Interface(Handle): + def __init__(self, name: str, ctype: str, cpptype: str) -> None: + super().__init__(name, ctype, cpptype, skip_def=True) + self.ifunctions: List[Function] = [] + self.members: List[str] = [] + + def mname(self, name: str) -> str: + return name + "_f" + + def constructor( # type: ignore + self, + name: str, + params: Optional[List[Parameter]] = None, + **kwargs) -> 'Interface': + create = self.substitute('allocate<${opaque_type}>($@)') + + initial_params = gparams(obj='void*', + c=self.cname('copy'), + d=self.cname('delete')) + + add_function(self.cname(name), + params=initial_params + (params or []), + invoke=create, + returns=self.opaque_type, + return_name=self.name, + **kwargs) + return self + + def method(self, *args, **kwargs) -> 'Interface': + super().method(*args, **kwargs) + return self + + def virtual(self, + name: str, + params: Optional[List[Parameter]] = None, + const: Optional[bool] = None, + **kwargs) -> 'Interface': + + # Add this parameter to the function + this = Parameter('obj', 'void*', this=True) + this.virtual_read = ['object_ptr.data'] + f = Function(name, + params=[this] + (params or []), + virtual=True, + **kwargs) + self.ifunctions.append(f) + + add_function(self.cname('set_' + name), + params=gparams(obj=self.opaque_type, + input=self.cname(name)), + invoke='${{obj}}->{name} = ${{input}}'.format( + name=self.mname(name))) + return self + + def generate_function(self, f: Function): + cname = self.cname(f.name) + mname = self.mname(f.name) + function = generate_virtual_impl(f, fname=mname) + return f"{cname} {mname} = nullptr;{function}" + + def generate(self): + required_functions = [ + Function('copy', + params=gparams(out='void**', input='void*'), + virtual=True), + Function('delete', params=gparams(input='void*'), virtual=True) + ] + for f in self.ifunctions + required_functions: + f.update() + c_header_preamble.extend([ + f.get_cfunction().generate_function_pointer(self.cname(f.name)) + for f in self.ifunctions + required_functions + ]) + function_list = [self.generate_function(f) for f in self.ifunctions] + ctype = self.ctype + cpptype = self.cpptype + copier = self.cname('copy') + deleter = self.cname('delete') + functions = '\n'.join(function_list) + + c_api_body_preamble.append( + interface_handle_definition.substitute(locals())) + + +def handle(ctype: str, + cpptype: str, + name: Optional[str] = None, + ref: Optional[bool] = None) -> Callable: + def with_handle(f): + n = name or f.__name__ + h = Handle(n, ctype, cpptype, ref=ref) + f(h) + h.add_cpp_class() + + @wraps(f) + def decorated(*args, **kwargs): + return f(*args, **kwargs) + + return decorated + + return with_handle + + +def interface(ctype: str, cpptype: str, + name: Optional[str] = None) -> Callable: + def with_interface(f): + n = name or f.__name__ + h = Interface(n, ctype, cpptype) + f(h) + h.generate() + + @wraps(f) + def decorated(*args, **kwargs): + return f(*args, **kwargs) + + return decorated + + return with_interface + + +def template_eval(template, **kwargs): + start = '<%' + end = '%>' + escaped = (re.escape(start), re.escape(end)) + mark = re.compile('%s(.*?)%s' % escaped, re.DOTALL) + for key in kwargs: + exec('%s = %s' % (key, kwargs[key])) + for item in mark.findall(template): + e = eval(item.strip()) + template = template.replace(start + item + end, str(e)) + return template + + +def run(args: List[str]) -> None: + runpy.run_path(args[0]) + if len(args) > 1: + f = open(args[1]).read() + r = template_eval(f) + sys.stdout.write(r) + else: + sys.stdout.write(generate_c_header()) + sys.stdout.write(generate_c_api_body()) + # sys.stdout.write(generate_cpp_header()) + + +if __name__ == "__main__": + sys.modules['api'] = sys.modules['__main__'] + run(sys.argv[1:]) diff --git a/tools/api/api.cpp b/tools/api/api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dceb465e842539aa06f95d0e0177d407f9128f8 --- /dev/null +++ b/tools/api/api.cpp @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { + +template +migraphx_status try_(F f, bool output = true) // NOLINT +{ + try + { + f(); + } + catch(const migraphx::exception& ex) + { + if(output) + std::cerr << "MIGraphX Error: " << ex.what() << std::endl; + if(ex.error > 0) + return migraphx_status(ex.error); + else + return migraphx_status_unknown_error; + } + catch(const std::exception& ex) + { + if(output) + std::cerr << "MIGraphX Error: " << ex.what() << std::endl; + return migraphx_status_unknown_error; + } + catch(...) + { + return migraphx_status_unknown_error; + } + return migraphx_status_success; +} + +shape::type_t to_shape_type(migraphx_shape_datatype_t t) +{ + switch(t) + { + case migraphx_shape_tuple_type: return shape::tuple_type; +#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ + case migraphx_shape_##x: return shape::x; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) +#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT + } + MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); +} + +migraphx_shape_datatype_t to_shape_type(shape::type_t t) +{ + switch(t) + { + case shape::tuple_type: return migraphx_shape_tuple_type; +#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ + case shape::x: return migraphx_shape_##x; + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) +#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT + } + MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); +} + +template +auto to_obj_vector(const T* x, std::size_t n) +{ + std::vectorobject)> result; + std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; }); + return result; +} + +template +auto to_objptr_vector(const U* x, std::size_t n) +{ + std::vector result; + std::transform( + x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); }); + return result; +} + +target get_target(const std::string& name) { return make_target(name); } + +void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; } + +void set_fast_math(compile_options& options, bool value) { options.fast_math = value; } + +void set_file_format(file_options& options, const char* format) { options.format = format; } + +void set_default_dim_value(onnx_options& options, size_t value) +{ + options.default_dim_value = value; +} + +void set_default_loop_iterations(onnx_options& options, int64_t value) +{ + options.max_loop_iterations = value; +} + +void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; } + +void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; } + +void set_input_parameter_shape(onnx_options& options, + const char* name, + std::vector dims) +{ + options.map_input_dims[std::string(name)] = std::move(dims); +} + +void set_input_parameter_shape(tf_options& options, const char* name, std::vector dims) +{ + options.map_input_dims[std::string(name)] = std::move(dims); +} + +void set_output_names(tf_options& options, std::vector names) +{ + options.output_node_names = std::vector(names.begin(), names.end()); +} + +template +std::vector get_names(const std::unordered_map& m) +{ + std::vector result; + std::transform( + m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.first.c_str(); }); + return result; +} + +void quantize_fp16_with_op_names(program& prog, std::vector& names) +{ + if(names.empty()) + { + names = {"all"}; + } + + migraphx::quantize_fp16(prog, names); +} + +struct quantize_int8_options +{ + std::vector calibration = {}; + std::vector op_names = {}; +}; + +void add_op_name(quantize_int8_options& options, const char* name) +{ + options.op_names.push_back(name); +} + +void add_calibration_data(quantize_int8_options& options, parameter_map& data) +{ + options.calibration.push_back(data); +} + +void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& options) +{ + if(options.op_names.empty()) + { + options.op_names = {"dot", "convolution"}; + } + + migraphx::quantize_int8(prog, t, options.calibration, options.op_names); +} + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + +operation create_op(const char* name, const char* attributes, va_list vlist) +{ + std::string sattributes = attributes == nullptr ? "" : attributes; + std::vector buffer(sattributes.size() * 2); + std::vsnprintf(buffer.data(), buffer.size(), sattributes.c_str(), vlist); + value v = value::object{}; + if(attributes != nullptr) + { + v = from_json_string(convert_to_json(std::string(buffer.data()))); + } + auto op = make_op(name, v); + + return op; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +template +bool equal(const T& x, const T& y) +{ + return x == y; +} + +std::vector run(program& p, const parameter_map& params) { return p.eval(params); } + +std::vector get_output_shapes(program& p) { return p.get_output_shapes(); } + +void print_program(const program& p) { std::cout << p << std::endl; } + +void print_module(const module& m) { std::cout << m << std::endl; } + +struct experimental_custom_op +{ + std::string name; + experimental_custom_op() = default; + + experimental_custom_op(std::string pname) : name(std::move(pname)) {} +}; + +template +struct custom_operation +{ + template + static auto reflect(Self&, F) + { + return pack(); + } + CustomOp op; + std::string name() const { return op.xobject.name; } + + shape compute_shape(std::vector inputs) const + { + return op.compute_shape(std::move(inputs)); + } + + argument compute(const std::vector&) const { MIGRAPHX_THROW("Not computable"); } +}; + +template +void register_custom_op(const CustomOp& op) +{ + register_op(custom_operation{op}); +} + +migraphx::context get_context(const program& p) { return p.get_context(); } + +} // namespace migraphx + +<% generate_c_api_body() %> diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h new file mode 100644 index 0000000000000000000000000000000000000000..f41d2b91ef18887ae99da526498bae719dbc086a --- /dev/null +++ b/tools/api/migraphx.h @@ -0,0 +1,52 @@ +#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H +#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H + +#include + +// Add new types here +// clang-format off +#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ + m(bool_type, bool) \ + m(half_type, half) \ + m(float_type, float) \ + m(double_type, double) \ + m(uint8_type, uint8_t) \ + m(int8_type, int8_t) \ + m(uint16_type, uint16_t) \ + m(int16_type, int16_t) \ + m(int32_type, int32_t) \ + m(int64_type, int64_t) \ + m(uint32_type, uint32_t) \ + m(uint64_type, uint64_t) +// clang-format on + +#ifdef __cplusplus +extern "C" { +#endif + +// return code, more to be added later +typedef enum +{ + migraphx_status_success = 0, + migraphx_status_bad_param = 1, + migraphx_status_unknown_target = 3, + migraphx_status_unknown_error = 4, + +} migraphx_status; + +#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x, +/// An enum to represent the different data type inputs +typedef enum +{ + migraphx_shape_tuple_type, + MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) +} migraphx_shape_datatype_t; +#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES + +<% generate_c_header() %> + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/tools/build_and_test_onnxrt.sh b/tools/build_and_test_onnxrt.sh new file mode 100755 index 0000000000000000000000000000000000000000..faacb082c483fdb740016d058bb844bd09465f7f --- /dev/null +++ b/tools/build_and_test_onnxrt.sh @@ -0,0 +1,7 @@ +cd /onnxruntime +pip3 install -r requirements.txt +# Add newer cmake to the path +export PATH="/opt/cmake/bin:$PATH" +export CXXFLAGS="-D__HIP_PLATFORM_HCC__=1 -w" +./build.sh --config Release --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --test --use_migraphx +# pip3 install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/tools/download_models.sh b/tools/download_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..27c473a8d6df81e4b301b85cc5cc98c71fc0543c --- /dev/null +++ b/tools/download_models.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +if [ -z "$ONNX_HOME" ] +then + ONNX_HOME=$HOME +fi + +model_dir=$ONNX_HOME/.onnx/models +tmp_dir=$ONNX_HOME/tmp/ +mkdir -p $model_dir +mkdir -p $tmp_dir +models="bvlc_alexnet \ + densenet121 \ + inception_v2 \ + shufflenet \ + vgg19 \ + zfnet512" + +for name in $models +do +curl https://s3.amazonaws.com/download.onnx/models/opset_9/$name.tar.gz --output $tmp_dir/$name.tar.gz +tar -xzvf $tmp_dir/$name.tar.gz --directory $model_dir && rm $tmp_dir/$name.tar.gz +done + diff --git a/tools/generate.sh b/tools/generate.sh index 6c864018e6ba18a0a973519dff4e507e5f8d8601..8b5b4a6bd79271f9a2bd0131a090e58e286a3980 100755 --- a/tools/generate.sh +++ b/tools/generate.sh @@ -1,2 +1,19 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/migraphx/{}" +SRC_DIR=$DIR/../src +PYTHON=python3 +if type -p python3.6 > /dev/null ; then + PYTHON=python3.6 +fi +if type -p python3.8 > /dev/null ; then + PYTHON=python3.8 +fi +ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | clang-format-10 -style=file > $SRC_DIR/include/migraphx/{}" + +function api { + $PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | clang-format-10 -style=file > $2 +} + +api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h +echo "Finished generating header migraphx.h" +api $DIR/api/api.cpp $SRC_DIR/api/api.cpp +echo "Finished generating source api.cpp " diff --git a/tools/include/allocation_model.hpp b/tools/include/allocation_model.hpp new file mode 100755 index 0000000000000000000000000000000000000000..a0973f18805cc866a7c5825ef7c6f5871f64b668 --- /dev/null +++ b/tools/include/allocation_model.hpp @@ -0,0 +1,52 @@ +#ifndef MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP +#define MIGRAPHX_GUARD_ALLOCATION_MODEL_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// An interface for target-dependent allocation +struct allocation_model +{ + /// A name of the target-dependent allocate operator + std::string name() const; + /// A name of the target-dependent copy operator + std::string copy() const; + /// Create an allocation operator for the given shape + operation allocate(const shape& s) const; + /// Create a preallocated operator for the given shape + operation preallocate(const shape& s, const std::string& id) const; + /// Check if outputs are to be inserted + bool needs_out_params() const; +}; + +#else + +<% +interface('allocation_model', + virtual('name', returns='std::string', const=True), + virtual('copy', returns='std::string', const=True), + virtual('allocate', s='const shape&', returns='operation', const=True), + virtual('preallocate', s='const shape&', id='std::string', returns='operation', const=True), + virtual('needs_out_params', returns='bool', const=True) +) +%> + +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/tools/include/concat_opt.hpp b/tools/include/concat_opt.hpp old mode 100644 new mode 100755 index 2cccd675a2d78ca746a2e021d676c5b6cc36d22f..f136e679239881e91aa07f84e61c8aa14389f7be --- a/tools/include/concat_opt.hpp +++ b/tools/include/concat_opt.hpp @@ -15,8 +15,6 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; - #ifdef DOXYGEN /// An interface for target-dependent optimization for the concat instruction diff --git a/tools/include/context.hpp b/tools/include/context.hpp index 51647979e47eadaa42901a0833ce30649c75c225..34846582a3eecde5d312305096cbb83d0205d046 100644 --- a/tools/include/context.hpp +++ b/tools/include/context.hpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -25,11 +27,35 @@ struct context #else +template +value to_value_context(const T&) +{ + return value{}; +} + +template +void from_value_context(T&, const value&) +{ +} + +template +any_ptr get_queue_context(T&) +{ + return {}; +} + <% -interface('context', - virtual('finish', returns='void', const=True) -) -%> + interface('context', + virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), + virtual('from_value', v = 'const value&', default = 'from_value_context'), + virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), + virtual('finish', returns = 'void', const = True)) %> + + inline void migraphx_to_value(value& v, const context& ctx) +{ + v = ctx.to_value(); +} +inline void migraphx_from_value(const value& v, context& ctx) { ctx.from_value(v); } #endif diff --git a/tools/include/marker.hpp b/tools/include/marker.hpp new file mode 100755 index 0000000000000000000000000000000000000000..49088fac40b7f2262de54a52cc5961b861da62e5 --- /dev/null +++ b/tools/include/marker.hpp @@ -0,0 +1,35 @@ +#ifndef MIGRAPHX_GUARD_MARKER_HPP +#define MIGRAPHX_GUARD_MARKER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// Marker is an interface to general marking functions, such as rocTX markers. + +#else + +<% +interface('marker', + virtual('mark_start', ins_ref = 'instruction_ref', returns = 'void'), + virtual('mark_start', prog = 'const program&', returns = 'void'), + virtual('mark_stop', ins = 'instruction_ref', returns = 'void'), + virtual('mark_stop', prog = 'const program&', returns = 'void') + ) %> +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/tools/include/operation.hpp b/tools/include/operation.hpp index 1b1ca05a10c1877f8fcb05254ce3b82e006ee883..7891eee55ab1af34651b6be6d34db03ecb9f3ba8 100644 --- a/tools/include/operation.hpp +++ b/tools/include/operation.hpp @@ -7,10 +7,15 @@ #include #include #include +#include #include #include +#include #include +#include +#include #include +#include #include namespace migraphx { @@ -57,6 +62,8 @@ struct operation /// Returns true if operation does not require a context to run compute bool is_context_free(const operation& x); +/// Returns true if operation needs normalization before running compute +bool need_normalization(const operation& x); /// Returns true if the operation has a finalize method bool has_finalize(const operation& x); @@ -96,7 +103,73 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) } // namespace operation_operators template -auto compute_op(rank<2>, +auto compute_shape_op(rank<3>, const T& x, const std::vector& inputs) + -> decltype(x.compute_shape(inputs)) +{ + return x.compute_shape(inputs); +} + +template +auto compute_shape_op(rank<2>, const T& x, const std::vector& inputs) + -> decltype(x.normalize_compute_shape(inputs)) +{ + dependent_type y = x; + normalize_attributes(y, inputs[0].lens()); + return any_cast(y).normalize_compute_shape(inputs); +} + +template +auto compute_shape_op(rank<1>, const T& x, const std::vector& inputs) + -> decltype(x.compute_shape(inputs, {})) +{ + return x.compute_shape(inputs, {}); +} + +template +shape compute_shape_op(rank<0>, const T& x, const std::vector&) +{ + std::string name = x.name(); + MIGRAPHX_THROW("Shape not computable: " + name); +} + +template +shape compute_shape_op(const T& x, const std::vector& inputs) +{ + return compute_shape_op(rank<3>{}, x, inputs); +} + +template +auto mod_compute_shape_op(rank<1>, + const T& x, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(x.compute_shape(inputs, mod_args)) +{ + return x.compute_shape(inputs, mod_args); +} + +template +shape mod_compute_shape_op(rank<0>, + const T& x, + const std::vector& inputs, + const std::vector& mod_args) +{ + if(mod_args.empty()) + return compute_shape_op(x, inputs); + std::string name = x.name(); + MIGRAPHX_THROW("Shape not computable: " + name); +} + +template +shape mod_compute_shape_op(const T& x, + const std::vector& inputs, + const std::vector& mod_args) +{ + return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args); +} + +template +auto compute_op(rank<1>, const T& x, context& ctx, const shape& output_shape, @@ -106,14 +179,6 @@ auto compute_op(rank<2>, return x.compute(auto_any_cast(ctx), output_shape, input); } -template -auto compute_op( - rank<1>, const T& x, context&, const shape& output_shape, const std::vector& input) - -> decltype(x.compute(output_shape, input)) -{ - return x.compute(output_shape, input); -} - template argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector&) { @@ -125,35 +190,132 @@ template argument compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector& input) { - return compute_op(rank<2>{}, x, ctx, output_shape, input); + return compute_op(rank<1>{}, x, ctx, output_shape, input); } template -auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector& input) +auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector& input) -> decltype(x.compute(output_shape, input)) { return x.compute(output_shape, input); } template -auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector& input) - -> decltype(x.compute(auto_any_cast(std::declval()), output_shape, input)) +argument compute_op(rank<0>, const T& x, const shape&, const std::vector&) { std::string name = x.name(); - MIGRAPHX_THROW("Not computable without a context: " + name); + MIGRAPHX_THROW("Not computable: " + name); } template -argument compute_op(rank<0>, const T& x, const shape&, const std::vector&) +argument compute_op(const T& x, const shape& output_shape, const std::vector& input) +{ + return compute_op(rank<1>{}, x, output_shape, input); +} + +template +auto compute_op(rank<1>, + const T& x, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(output, inputs, module_args, f)) +{ + return x.compute(output, inputs, module_args, f); +} + +template +argument compute_op(rank<0>, + const T& x, + const shape&, + const std::vector&, + const std::vector&, + F) { std::string name = x.name(); MIGRAPHX_THROW("Not computable: " + name); } -template -argument compute_op(const T& x, const shape& output_shape, const std::vector& input) +template +argument compute_op(const T& x, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) +{ + return compute_op(rank<1>{}, x, output, inputs, module_args, f); +} + +template +auto compute_op(rank<4>, + const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f)) +{ + return x.compute(auto_any_cast(ctx), output, inputs, module_args, f); +} + +template +auto compute_op(rank<3>, + const T& x, + context&, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) -> decltype(x.compute(output, inputs, module_args, f)) +{ + return x.compute(output, inputs, module_args, f); +} + +template +auto compute_op(rank<2>, + const T& x, + context&, + const shape& output, + const std::vector& inputs, + const std::vector&, + F) -> decltype(x.compute(output, inputs)) +{ + return x.compute(output, inputs); +} + +template +auto compute_op(rank<1>, + const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector&, + F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs)) +{ + return x.compute(auto_any_cast(ctx), output, inputs); +} + +template +argument compute_op(rank<0>, + const T& x, + context&, + const shape&, + const std::vector&, + const std::vector&, + F) +{ + std::string name = x.name(); + MIGRAPHX_THROW("Not computable: " + name); +} + +template +argument compute_op(const T& x, + context& ctx, + const shape& output, + const std::vector& inputs, + const std::vector& module_args, + F f) { - return compute_op(rank<2>{}, x, output_shape, input); + return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f); } template @@ -174,6 +336,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( return {}; } +template +auto need_normalization_op(rank<1>, const T& x, const std::vector& inputs) + -> decltype(x.normalize_compute_shape(inputs), std::true_type{}); + +template +auto need_normalization_op(rank<0>, const T&, const std::vector&) -> std::false_type; + +template +auto need_normalization_op(const T& x) + -> decltype(need_normalization_op(rank<1>{}, x, std::declval>())) +{ + return {}; +} + template std::ptrdiff_t output_alias_op(const T&, const std::vector&) { @@ -218,6 +394,55 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, return {}; } +template +auto compile_op( + rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector& input) + -> decltype(x.compile(auto_any_cast(ctx), output_shape, input)) +{ + return x.compile(auto_any_cast(ctx), output_shape, input); +} + +template +value compile_op(rank<0>, T&, context&, const shape&, const std::vector&) +{ + return value::object{}; +} + +template +value compile_op(const T& x, + context& ctx, + const shape& output_shape, + const std::vector& input) +{ + return compile_op(rank<1>{}, x, ctx, output_shape, input); +} + +template +value attributes_op(const T&) +{ + return value::object{}; +} + +template +value to_value_op(const T& x) +{ + return migraphx::to_value(x); +} + +template +void from_value_op(T& x, const value& v) +{ + if(not(v.is_object() or (v.empty() and v.is_array()))) + MIGRAPHX_THROW("Value is not an object"); + return migraphx::from_value(v, x); +} + +template +lifetime get_lifetime_op(const T&) +{ + return lifetime::local; +} + } // namespace detail <% @@ -226,18 +451,40 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, virtual('name', returns = 'std::string', const = True), virtual( 'is_context_free', returns = 'bool', const = True, default = 'detail::is_context_free_op'), + virtual('need_normalization', + returns = 'bool', + const = True, + default = 'detail::need_normalization_op'), virtual('has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'), + virtual( + 'get_lifetime', returns = 'lifetime', const = True, default = 'detail::get_lifetime_op'), virtual('output_alias', returns = 'std::ptrdiff_t', input = 'const std::vector&', const = True, default = 'detail::output_alias_op'), + virtual('compile', + returns = 'value', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + default = 'detail::compile_op'), virtual('finalize', ctx = 'context&', output = 'const shape&', input = 'const std::vector&', default = 'detail::finalize_op'), - virtual('compute_shape', returns = 'shape', input = 'const std::vector&', const = True), + virtual('compute_shape', + returns = 'shape', + input = 'const std::vector&', + const = True, + default = 'detail::compute_shape_op'), + virtual('compute_shape', + returns = 'shape', + inputs = 'const std::vector&', + mod_args = 'const std::vector&', + const = True, + default = 'detail::mod_compute_shape_op'), virtual('compute', returns = 'argument', ctx = 'context&', @@ -251,6 +498,30 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, input = 'const std::vector&', const = True, default = 'detail::compute_op'), + virtual( + 'compute', + returns = 'argument', + output = 'const shape&', + input = 'const std::vector&', + module_args = 'const std::vector&', + run = + 'std::function(module_ref&, const std::unordered_map&)>', + const = True, + default = 'detail::compute_op'), + virtual( + 'compute', + returns = 'argument', + ctx = 'context&', + output = 'const shape&', + input = 'const std::vector&', + module_args = 'const std::vector&', + run = + 'std::function(module_ref&, const std::unordered_map&)>', + const = True, + default = 'detail::compute_op'), + virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'), + virtual('from_value', v = 'const value&', default = 'detail::from_value_op'), + virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'), friend('operator<<', returns = 'std::ostream &', os = 'std::ostream &', @@ -267,6 +538,68 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{}, return !(x == y); } +inline value +compile(operation& op, context& ctx, const shape& output_shape, const std::vector& input) +{ + return op.compile(ctx, output_shape, input); +} +template +inline value +compile(operation& op, Context& ctx, const shape& output_shape, const std::vector& input) +{ + dependent_type ctx2 = std::ref(ctx); + return compile(op, ctx2, output_shape, input); +} +template +inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector& input) + -> decltype(op.compile(ctx, ctx, output_shape, input)) +{ + return op.compile(ctx, ctx, output_shape, input); +} +inline shape compute_shape(const operation& op, const std::vector& inputs) +{ + return op.compute_shape(inputs); +} + +template +inline auto compute_shape(const T& op, const std::vector& inputs) + -> decltype(op.compute_shape(inputs)) +{ + return op.compute_shape(inputs); +} + +template +inline auto compute_shape(const T& op, const std::vector& inputs) + -> decltype(op.normalize_compute_shape(inputs)) +{ + return detail::compute_shape_op(op, inputs); +} + +inline shape compute_shape(const operation& op, + const std::vector& inputs, + const std::vector& mod_args) +{ + return op.compute_shape(inputs, mod_args); +} + +template +inline auto compute_shape(const T& op, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(op.compute_shape(inputs, mod_args)) +{ + return op.compute_shape(inputs, mod_args); +} + +template +inline auto compute_shape(const T& op, + const std::vector& inputs, + const std::vector& mod_args) + -> decltype(op.normalize_compute_shape(inputs, mod_args)) +{ + return detail::compute_shape_op(op, inputs, mod_args); +} + inline bool is_context_free(const operation& op) { return op.is_context_free(); } template @@ -275,6 +608,14 @@ bool is_context_free(const T& x) return detail::is_context_free_op(x); } +inline bool need_normalization(const operation& op) { return op.need_normalization(); } + +template +bool need_normalization(const T& x) +{ + return detail::need_normalization_op(x); +} + inline bool has_finalize(const operation& op) { return op.has_finalize(); } template @@ -283,6 +624,9 @@ bool has_finalize(const T& x) return detail::has_finalize_op(x); } +void migraphx_to_value(value& v, const operation& op); +void migraphx_from_value(const value& v, operation& op); + #endif } // namespace MIGRAPHX_INLINE_NS diff --git a/tools/include/pass.hpp b/tools/include/pass.hpp index a6a982b3b11b08c1abd4f0a9eeff8d5a209d78e9..63b5a9954136249e4ec602216ceb76d1b22d32c3 100644 --- a/tools/include/pass.hpp +++ b/tools/include/pass.hpp @@ -3,16 +3,19 @@ #include #include -#include #include #include #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { struct program; +struct module; +struct module_pass_manager; #ifdef DOXYGEN @@ -22,16 +25,44 @@ struct pass { /// A unique name used to identify the pass std::string name() const; + /// Run the pass on the module + void apply(module_pass_manager& mpm) const; + void apply(module& m) const; /// Run the pass on the program void apply(program& p) const; }; #else +module& get_module(module_pass_manager& mpm); + +namespace detail { + +template +auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm) + -> decltype(x.apply(get_module(mpm))) +{ + return x.apply(get_module(mpm)); +} + +template +void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&) +{ +} + +template +void module_pass_manager_apply(const T& x, module_pass_manager& mpm) +{ + module_pass_manager_apply(rank<1>{}, x, mpm); +} + +} // namespace detail + <% interface('pass', virtual('name', returns='std::string', const=True), - virtual('apply', returns='void', p='program &', const=True) + virtual('apply', returns='void', mpm='module_pass_manager &', const=True, default='migraphx::detail::module_pass_manager_apply'), + virtual('apply', returns='void', p='program &', const=True, default='migraphx::nop') ) %> diff --git a/tools/include/schedule_model.hpp b/tools/include/schedule_model.hpp index e43bf6322ba998f68aacaa48971f3eb0c092ba0c..6a1935767a28727075c55ef50c58698ef7155cbf 100644 --- a/tools/include/schedule_model.hpp +++ b/tools/include/schedule_model.hpp @@ -15,7 +15,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -struct program; +struct module; struct operation; #ifdef DOXYGEN @@ -26,11 +26,11 @@ struct schedule_model /// Get the number of concurrent instruction allowed std::size_t concurrency() const; /// Schedule a concurrent instruction - void sched(program& p, instruction_ref ins, std::size_t n) const; + void sched(module& m, instruction_ref ins, std::size_t n) const; // Insert necessary waits before an instruction - void wait(program& p, instruction_ref ins, std::size_t wait_id) const; + void wait(module& m, instruction_ref ins, std::size_t wait_id) const; // Insert necessary records after an instruction - void record(program& p, instruction_ref ins, std::size_t wait_id) const; + void record(module& m, instruction_ref ins, std::size_t wait_id) const; /// Compute weights for an operation std::size_t weight(const operation& op) const; }; @@ -40,9 +40,9 @@ struct schedule_model <% interface('schedule_model', virtual('concurrency', returns='std::size_t', const=True), - virtual('sched', p='program&', ins='instruction_ref', n='std::size_t', const=True), - virtual('wait', p='program&', ins='instruction_ref', wait_id='std::size_t', const=True), - virtual('record', p='program&', ins='instruction_ref', wait_id='std::size_t', const=True), + virtual('sched', m='module&', ins='instruction_ref', n='std::size_t', const=True), + virtual('wait', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True), + virtual('record', m='module&', ins='instruction_ref', wait_id='std::size_t', const=True), virtual('weight', returns='std::size_t', op='const operation&', const=True) ) %> diff --git a/tools/include/stream_model.hpp b/tools/include/stream_model.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56fe8739dc1e1a1e731fdde894254f73a0c34105 --- /dev/null +++ b/tools/include/stream_model.hpp @@ -0,0 +1,55 @@ +#ifndef MIGRAPHX_GUARD_STREAM_MODEL_HPP +#define MIGRAPHX_GUARD_STREAM_MODEL_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +#ifdef DOXYGEN + +/// An interface for target-dependent model for the scheduler +struct stream_model +{ + /// Get the number of streams used in the program + std::size_t get_nstream() const; + /// Get stream for instruction + std::size_t get_stream(instruction_ref ins) const; + /// Get unique event id for instruction + std::size_t get_event_id(instruction_ref ins) const; + /// Returns true if instruction has a stream assignment + bool has_stream(instruction_ref ins) const; + /// Returns true if the instruction records the event + bool is_record(instruction_ref ins) const; + /// Returns true if the instruction wait on the event + bool is_wait(instruction_ref ins) const; +}; + +#else + +<% +interface('stream_model', + virtual('get_nstream', returns='std::size_t', const=True), + virtual('get_stream', ins='instruction_ref', returns='std::size_t', const=True), + virtual('get_event_id', ins='instruction_ref', returns='std::size_t', const=True), + virtual('has_stream', ins='instruction_ref', returns='bool', const=True), + virtual('is_record', ins='instruction_ref', returns='bool', const=True), + virtual('is_wait', ins='instruction_ref', returns='bool', const=True) +) +%> + +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/tools/install_prereqs.sh b/tools/install_prereqs.sh new file mode 100755 index 0000000000000000000000000000000000000000..65b5f74c2af047fe8f3f83e65e26f1292322c312 --- /dev/null +++ b/tools/install_prereqs.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# +# Build MIGraphX prerequisites for docker container + +set -e + +export LC_ALL=C.UTF-8 +export LANG=C.UTF-8 + + +# Need pip3 and Python headers to build dependencies +apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras + +# Needed for cmake to build various pip packages +pip3 install setuptools wheel + +# install rbuild to build dependencies +pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz + + +PREFIX=/usr/local +REQ_FILE_DIR="" +if [ "$#" -ge 2 ]; then + PREFIX=$1 + cd $2 +elif [ "$#" -eq 1 ]; then + PREFIX=$1 +fi + +echo "Dependencies are installed at $PREFIX" + +# Install deps with rbuild +rbuild prepare -d $PREFIX -s develop + +# install onnx package for unit tests +pip3 install onnx==1.8.1 numpy==1.18.5 typing==3.7.4 pytest==6.0.1 packaging==16.8 + +# pin version of protobuf in Python for onnx runtime unit tests +pip3 install protobuf==3.20.0 diff --git a/tools/roctx.py b/tools/roctx.py new file mode 100755 index 0000000000000000000000000000000000000000..e7e0d9f7d4bbaeef3556f83e13887171034ae784 --- /dev/null +++ b/tools/roctx.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 + +import json +import argparse +import os +from sys import argv as sysargs +from sys import version_info as python_version +from sys import exit as sys_exit +import pandas as pd +from datetime import datetime +import venv +import shutil + +if (python_version[0] < 3) or (python_version[0] < 3 + and python_version[1] < 6): + raise Exception("Please utilize Python version 3.6 and above. Exiting...") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parser for MIGraphX ROCTX Markers") + parser.add_argument('--json-path', + type=str, + metavar='json-path', + help='Path to json file') + parser.add_argument('--out', + type=str, + metavar='out', + help='Output directory for run.') + parser.add_argument( + '--study-name', + type=str, + metavar='study-name', + help='Study-name is used for naming the output CSV file.') + parser.add_argument('--repeat', + type=int, + metavar='repeat', + help='Defines number of runs.', + default=2) + parser.add_argument('--parse', + default=False, + action='store_true', + help='Parses given JSON file.') + parser.add_argument('--clean', + default=False, + action='store_true', + help='Removes temporary paths') + parser.add_argument('--run', + type=str, + metavar='run', + help='Enables run and fetches run configs.') + parser.add_argument('--debug', default=False, action='store_true') + + args = parser.parse_args() + return args + + +args = parse_args() +if not len(sysargs) > 1: + raise Exception("No arg is passed. Exiting...") + + +def parse(file): + with open(file, "r") as read_file: + data = json.load(read_file) + + #Get marker names and first marker's time + list_names = [] + first_marker = True + first_marker_time = 0 + for i in data: + if (i): + if ("Marker start:" in i['name']) and ( + i['name'] not in list_names): + list_names.append(i['name']) + if first_marker: + first_marker_time = i['ts'] + first_marker = False + + if (args.debug): + print(f"FIRST MARKER TIME DETERMINED: {first_marker_time}") + + if (first_marker_time == 0): + raise ("FIRST MARKER TIME IS ZERO. EXITING...") + + kernel_launch_info = [] #kernel description + kernel_launch_list = [] #kernel launch details + kernel_launch_time = [] #kernel execution time + for i in data: + if (i and i.get('args')): + try: + if (("KernelExecution" in i['args']['desc']) + and (i['ts'] >= first_marker_time)): + kernel_launch_info.append(i['args']['desc']) + kernel_launch_list.append(i) + kernel_launch_time.append(int(i['dur'])) + except: + continue + + max_index = kernel_launch_time.index(max(kernel_launch_time)) + max_kernel_info = kernel_launch_list[max_index] + + if (args.debug): + with open('rocTX_kernel_launch_list.txt', 'w') as f: + for i in kernel_launch_list: + f.write(f'{i}') + + # Get timing information for each marker name + list_times_per_names = [] + for name in list_names: + temp_list = [] + for entry in data: + if (entry) and ( + name == entry['name'] + ): # name can match on gpu or cpu side, for gpu, we need data from gpu markers. + if (("gpu::" in name) + and ("UserMarker frame:" in entry['args']['desc']) + ): #gpu side information + temp_list.append(int(entry.get('dur'))) + elif (("gpu::" not in name) + and ("Marker start:" in entry['args']['desc']) + ): #cpu side information + temp_list.append(int(entry.get('dur'))) + list_times_per_names.append(temp_list) + + if (args.debug): + print(list_times_per_names) + + sum_per_name = [] #TODO: refactor stat collection + for list in list_times_per_names: + sum_per_name.append(sum(list)) + + count_per_name = [] + for list in list_times_per_names: + try: + count_per_name.append(len(list)) + except: + count_per_name.append(0) + + max_per_name = [] + for list in list_times_per_names: + try: + max_per_name.append(max(list)) + except: + max_per_name.append(0) + + min_per_name = [] + for list in list_times_per_names: + try: + min_per_name.append(min(list)) + except: + min_per_name.append(0) + + max_index_per_name = [] + for list in list_times_per_names: + try: + max_index_per_name.append(list.index(max(list))) + except: + max_index_per_name.append(0) + + max_occur_per_name = [] + for list in list_times_per_names: + try: + max_occur_per_name.append(list.count(max(list))) + except: + max_occur_per_name.append(0) + + total_time = sum(sum_per_name) + + d = { + 'SUM': sum_per_name, + 'MIN': min_per_name, + 'MAX': max_per_name, + 'COUNT': count_per_name, + 'MAX_INDEX': max_index_per_name, + 'MAX_OCCUR': max_occur_per_name + } + df2 = pd.DataFrame(d) + df2.index = list_names + df2.sort_values(by=['SUM'], inplace=True, ascending=False) + + if (args.debug): + print(df2) + print(f"\nTOTAL TIME: {total_time} us") + return df2, total_time, max_kernel_info + + +def run(): + repeat_count = args.repeat + if (repeat_count == 0 or repeat_count == float('inf') or not repeat_count): + raise Exception("REPEAT COUNT CANNOT BE ZERO/INFINITY/NULL") + run_args = args.run + #configurations + configs = '--hip-trace --roctx-trace --flush-rate 10ms --timestamp on' + output_dir = f"-d {args.out}" + executable = f"/opt/rocm/bin/migraphx-driver roctx {run_args}" + process_args = configs + ' ' + output_dir + ' ' + executable + for i in range(repeat_count): + os.system('rocprof ' + process_args) + print("RUN COMPLETE.") + + +def clean(): + shutil.rmtree('/tmp/rocm-profile-data/', ignore_errors=False) + + +def main(): + + if (args.clean): + clean() + sys_exit() + + print("Initiating virtual environment...") + builder = venv.EnvBuilder(clear=True, with_pip=True) + builder.create('/tmp/rocm-profile-data/py/') + python_bin = '/tmp/rocm-profile-data/py' + '/bin/python' + file = args.json_path + + if (args.study_name): + filename = args.study_name + ".csv" + else: + filename = "output" + datetime.now().strftime( + "%Y_%m_%d-%I:%M:%S_%p") + ".csv" + + with open(filename, 'a') as f: + f.write(f"{args.run}\n") + + if (args.run): + curr = os.path.abspath(os.getcwd()) + rpd_path = '/tmp/rocm-profile-data/rocmProfileData/' + if not os.path.exists(rpd_path): + print("rocmProfileData DOES NOT EXIST. CLONING...") + os.system( + f"git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData.git {rpd_path}" + ) + os.chdir(rpd_path + "rocpd_python/") + os.system(python_bin + ' -m pip install --upgrade pip') + os.system(python_bin + ' setup.py install') + os.chdir(curr) + run() + os.chdir(curr + f"/{args.out}/") + out_path = os.popen(f"ls -td $PWD/*/*/ | head -{args.repeat}").read() + print(f"\nFOLLOWING PATHS WILL BE PARSED:\n{out_path}") + out_path = out_path.splitlines() + df_tot = pd.DataFrame() + tot_time = [] + max_kernel_info_list = [] + for path in out_path: + path = path.strip('\n') + print("\nPARSING OUTPUT PATH: " + path) + os.chdir(path) + os.system( + f"{python_bin} -m rocpd.rocprofiler_import --ops_input_file hcc_ops_trace.txt --api_input_file hip_api_trace.txt --roctx_input_file roctx_trace.txt trace.rpd" + ) + os.system( + f"{python_bin} {rpd_path}/rpd2tracing.py trace.rpd trace.json") + os.chdir(curr) + df, total_time, path_max_kernel_info = parse(path + "trace.json") + max_kernel_info_list.append(path_max_kernel_info) + tot_time.append(total_time) + df_tot = pd.merge(df_tot, + df, + how='outer', + left_index=True, + right_index=True) + if (args.debug): + print("JSON FILE PATH: " + path + "trace.json") + + df_tot.to_csv("rocTX_runs_dataframe.csv") + if (args.debug): + print(df_tot) + + tmp_sum = df_tot.loc[:, df_tot.columns.str.contains('SUM')].astype(int) + tmp_min = df_tot.loc[:, df_tot.columns.str.contains('MIN')].astype(int) + tmp_max = df_tot.loc[:, df_tot.columns.str.match("^MAX_.$")].astype( + int) + tmp_count = df_tot.loc[:, df_tot.columns.str.match("COUNT")].astype( + int) + + tmp_sum['SUM_avg'] = tmp_sum.mean(axis=1).astype(int) + tmp_min['MIN_avg'] = tmp_min.mean(axis=1).astype(int) + tmp_max['MAX_avg'] = tmp_max.mean(axis=1).astype(int) + + df2 = tmp_sum['SUM_avg'].copy() + df2 = pd.merge(df2, + tmp_min['MIN_avg'], + how='outer', + left_index=True, + right_index=True) + df2 = pd.merge(df2, + tmp_max['MAX_avg'], + how='outer', + left_index=True, + right_index=True) + df2 = pd.merge(df2, + tmp_count['COUNT_x'], + how='outer', + left_index=True, + right_index=True) + df2.rename(columns={'COUNT_x': 'COUNT'}, inplace=True) + df2 = df2.loc[:, ~df2.columns.duplicated( + )] #there will be many COUNT_x in df2 + df2.sort_values(by=['SUM_avg'], inplace=True, ascending=False) + + if (args.debug): + pd.set_option('display.max_columns', None) + print(df_tot) #all data from all runs + + print("\n*** RESULTS ***") + print(df2) + out_time = sum(tot_time) / len(tot_time) + print(f"\nAVG TOTAL TIME: {out_time} us\n") + + df2.to_csv(filename, mode='a') + with open(filename, 'a') as f: + f.write(f"AVG TOTAL TIME: {out_time} us\n") + print(f"OUTPUT CSV FILE:\t{filename}") + + if (args.debug): + #kernels that took the longest time printed + for item in max_kernel_info_list: + print(f"KERNEL NAME: {item['name']}\t\t{item['dur']}") + + with open('rocTX_kernel_timing_details.txt', 'w') as f: + f.write( + "MOST TIME CONSUMING KERNELS IN EACH ITERATION (EXPECTED TO BE SAME KERNEL):\n" + ) + for i in max_kernel_info_list: + f.write(f"KERNEL NAME: {i['name']}\t\t{i['dur']}\n") + print("KERNEL TIMING DETAILS:\trocTX_kernel_timing_details.txt") + print("ALL DATA FROM ALL RUNS:\trocTX_runs_dataframe.csv") + + elif (args.parse): + if not (file): + raise Exception("JSON PATH IS NOT PROVIDED FOR PARSING.") + parse(file) + else: + raise Exception("PLEASE PROVIDE A COMMAND: RUN, PARSE, CLEAN") + + +if __name__ == "__main__": + main() diff --git a/tools/te.py b/tools/te.py old mode 100644 new mode 100755 index 4c51075b1b690e44919d1ae70e14f2bf2a822a02..f06612baa97b1b77ecc888ec24b8c9d80212f3f7 --- a/tools/te.py +++ b/tools/te.py @@ -1,4 +1,4 @@ -import string, sys, re, os +import string, sys, re trivial = ['std::size_t', 'instruction_ref'] @@ -12,16 +12,15 @@ headers = ''' ''' form = string.Template(''' +#ifdef TYPE_ERASED_DECLARATION -/* -* Type-erased interface for: -* -* struct ${struct_name} -* { -${comment_members} -* }; -* -*/ +// Type-erased interface for: +struct ${struct_name} +{ +${decl_members} +}; + +#else struct ${struct_name} { @@ -41,10 +40,17 @@ struct ${struct_name} template ${struct_name} & operator= (PrivateDetailTypeErasedT value) { - if (private_detail_te_handle_mem_var.unique()) - *private_detail_te_handle_mem_var = std::forward(value); - else if (!private_detail_te_handle_mem_var) - private_detail_te_handle_mem_var = std::make_shared(std::forward(value)); + using std::swap; + auto * derived = this->any_cast(); + if(derived and private_detail_te_handle_mem_var.unique()) + { + *derived = std::forward(value); + } + else + { + ${struct_name} rhs(value); + swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var); + } return *this; } @@ -52,7 +58,7 @@ struct ${struct_name} template PrivateDetailTypeErasedT * any_cast() { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) ? + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type> &>(private_detail_te_get_handle()).private_detail_te_value) : nullptr; } @@ -60,7 +66,7 @@ struct ${struct_name} template const typename std::remove_cv::type * any_cast() const { - return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) ? + return this->type_id() == typeid(PrivateDetailTypeErasedT) ? std::addressof(static_cast::type> &>(private_detail_te_get_handle()).private_detail_te_value) : nullptr; } @@ -182,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x) if (y == nullptr) throw std::bad_cast(); return *y; } +#endif ''') nonvirtual_member = string.Template(''' @@ -207,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override comment_member = string.Template( '''* ${friend} ${return_type} ${name}(${params}) ${const};''') +decl_member = string.Template(''' ${comment} + ${friend} ${return_type} ${name}(${params}) ${const}; +''') + default_member = string.Template(''' template static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params}) @@ -272,7 +283,8 @@ def convert_member(d, struct_name): 'this': '(*this)', 'using': '', 'brief': '', - 'return_': '' + 'return_': '', + 'comment': '// ' } args = [] params = [] @@ -299,6 +311,7 @@ def convert_member(d, struct_name): member['friend'] = 'friend' elif x == 'default': member['default'] = t + member['comment'] = member['comment'] + '(optional)' elif x == 'using': member['using'] = 'using {};'.format(d[name]['using']) elif x == '__brief__': @@ -340,18 +353,21 @@ def generate_form(name, members): virtual_members = [] comment_members = [] default_members = [] + decl_members = [] for member in members: m = convert_member(member, name) nonvirtual_members.append(nonvirtual_member.substitute(m)) pure_virtual_members.append(pure_virtual_member.substitute(m)) virtual_members.append(virtual_member.substitute(m)) comment_members.append(comment_member.substitute(m)) + decl_members.append(decl_member.substitute(m)) if 'default' in m: default_members.append(default_member.substitute(m)) return form.substitute(nonvirtual_members=''.join(nonvirtual_members), pure_virtual_members=''.join(pure_virtual_members), virtual_members=''.join(virtual_members), default_members=''.join(default_members), + decl_members=''.join(decl_members), comment_members='\n'.join(comment_members), struct_name=name) @@ -379,7 +395,7 @@ def template_eval(template, **kwargs): escaped = (re.escape(start), re.escape(end)) mark = re.compile('%s(.*?)%s' % escaped, re.DOTALL) for key in kwargs: - exec ('%s = %s' % (key, kwargs[key])) + exec('%s = %s' % (key, kwargs[key])) for item in mark.findall(template): template = template.replace(start + item + end, str(eval(item.strip()))) diff --git a/tools/test_runner.py b/tools/test_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6170a6f1199a6f228268e8ba1bc7358a6d9bd087 --- /dev/null +++ b/tools/test_runner.py @@ -0,0 +1,275 @@ +import os, sys +import numpy as np +import argparse +import onnx +from onnx import numpy_helper +import migraphx + + +def parse_args(): + parser = argparse.ArgumentParser(description="MIGraphX test runner") + parser.add_argument('test_dir', + type=str, + metavar='test_loc', + help='folder where the test is stored') + parser.add_argument('--target', + type=str, + default='gpu', + help='Specify where the tests execute (ref, gpu)') + args = parser.parse_args() + + return args + + +def get_sub_folders(dir_name): + dir_contents = os.listdir(dir_name) + folders = [] + for item in dir_contents: + tmp_item = dir_name + '/' + item + if os.path.isdir(tmp_item): + folders.append(item) + folders.sort() + + return folders + + +def get_test_cases(dir_name): + return get_sub_folders(dir_name) + + +def get_model_name(dir_name): + dir_contents = os.listdir(dir_name) + for item in dir_contents: + file_name = dir_name + '/' + item + if os.path.isfile(file_name) and file_name.endswith('.onnx'): + return item + + return '' + + +def read_pb_file(filename): + with open(filename, 'rb') as pfile: + data_str = pfile.read() + tensor = onnx.TensorProto() + tensor.ParseFromString(data_str) + np_array = numpy_helper.to_array(tensor) + + return tensor.name, np_array + + +def wrapup_inputs(io_folder, param_names): + param_map = {} + data_array = [] + name_array = [] + for i in range(len(param_names)): + file_name = io_folder + '/input_' + str(i) + '.pb' + name, data = read_pb_file(file_name) + param_map[name] = data + data_array.append(data) + if name: + name_array.append(name) + + if len(name_array) < len(data_array): + param_map = {} + for i in range(len(param_names)): + param_map[param_names[i]] = data_array[i] + + return param_map + + for name in param_names: + if not name in param_map.keys(): + print("Input {} does not exist!".format(name)) + sys.exit() + + return param_map + + +def read_outputs(io_folder, out_names): + outputs = [] + data_array = [] + name_array = [] + for i in range(len(out_names)): + file_name = io_folder + '/output_' + str(i) + '.pb' + name, data = read_pb_file(file_name) + data_array.append(data) + if name: + name_array.append(name) + + if len(name_array) < len(data_array): + return data_array + + for name in out_names: + index = name_array.index(name) + outputs.append(data_array[index]) + + return outputs + + +def model_parameter_names(model_file_name): + with open(model_file_name, 'rb') as pfile: + data_str = pfile.read() + model_proto = onnx.ModelProto() + model_proto.ParseFromString(data_str) + init_names = set([(i.name) for i in model_proto.graph.initializer]) + param_names = [ + input.name for input in model_proto.graph.input + if input.name not in init_names + ] + + return param_names + + +def model_output_names(model_file_name): + with open(model_file_name, 'rb') as pfile: + data_str = pfile.read() + model_proto = onnx.ModelProto() + model_proto.ParseFromString(data_str) + output_names = [out.name for out in model_proto.graph.output] + + return output_names + + +def get_input_shapes(sample_case, param_names): + param_shape_map = {} + name_array = [] + shape_array = [] + for i in range(len(param_names)): + file_name = sample_case + '/input_' + str(i) + '.pb' + name, data = read_pb_file(file_name) + param_shape_map[name] = data.shape + shape_array.append(data.shape) + if name: + name_array.append(name) + + if len(name_array) < len(shape_array): + param_shape_map = {} + for i in range(len(param_names)): + param_shape_map[param_names[i]] = shape_array[i] + + return param_shape_map + + for name in param_names: + if not name in param_shape_map: + print("Input {} does not exist!".format(name)) + sys.exit() + + return param_shape_map + + +def run_one_case(model, param_map): + # convert np array to model argument + pp = {} + for key, val in param_map.items(): + pp[key] = migraphx.argument(val) + + # run the model + model_outputs = model.run(param_map) + + # convert argument to np array + outputs = [] + for output in model_outputs: + outputs.append(np.array(output)) + + return outputs + + +def check_correctness(gold_outputs, outputs, rtol=1e-3, atol=1e-3): + if len(gold_outputs) != len(outputs): + print("Number of outputs {} is not equal to expected number {}".format( + len(outputs), len(gold_outputs))) + return False + + out_num = len(gold_outputs) + ret = True + for i in range(out_num): + if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): + print("\nOutput {} is incorrect ...".format(i)) + print("Expected value: \n{}".format(gold_outputs[i])) + print("......") + print("Actual value: \n{}\n".format(outputs[i])) + ret = False + + return ret + + +def tune_input_shape(model, input_data): + param_shapes = model.get_parameter_shapes() + input_shapes = {} + for name, s in param_shapes.items(): + assert name in input_data + data_shape = list(input_data[name].shape) + if not np.array_equal(data_shape, s.lens()): + input_shapes[name] = data_shape + + return input_shapes + + +def main(): + args = parse_args() + test_loc = args.test_dir + target = args.target + + test_name = os.path.basename(os.path.normpath(test_loc)) + + print("Running test \"{}\" on target \"{}\" ...\n".format( + test_name, target)) + + # get model full path + model_name = get_model_name(test_loc) + model_path_name = test_loc + '/' + model_name + + # get param names + param_names = model_parameter_names(model_path_name) + + # get output names + output_names = model_output_names(model_path_name) + + # get test cases + cases = get_test_cases(test_loc) + sample_case = test_loc + '/' + cases[0] + param_shapes = get_input_shapes(sample_case, param_names) + for name, dims in param_shapes.items(): + print("Input: {}, shape: {}".format(name, dims)) + print() + + # read and compile model + model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes) + model.compile(migraphx.get_target(target)) + + # get test cases + case_num = len(cases) + correct_num = 0 + for case_name in cases: + io_folder = test_loc + '/' + case_name + input_data = wrapup_inputs(io_folder, param_names) + gold_outputs = read_outputs(io_folder, output_names) + + # if input shape is different from model shape, reload and recompile + # model + input_shapes = tune_input_shape(model, input_data) + if not len(input_shapes) == 0: + model = migraphx.parse_onnx(model_path_name, + map_input_dims=input_shapes) + model.compile(migraphx.get_target(target)) + + # run the model and return outputs + output_data = run_one_case(model, input_data) + + # check output correctness + ret = check_correctness(gold_outputs, output_data) + if ret: + correct_num += 1 + + output_str = "PASSED" if ret else "FAILED" + print("\tCase {}: {}".format(case_name, output_str)) + + print("\nTest \"{}\" has {} cases:".format(test_name, case_num)) + print("\t Passed: {}".format(correct_num)) + print("\t Failed: {}".format(case_num - correct_num)) + if case_num > correct_num: + error_num = case_num - correct_num + raise ValueError(str(error_num) + " cases failed!") + + +if __name__ == "__main__": + main()