Commit b878f78f authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into rewrite-fast-gelu

parents 3b414cc2 55cb7d3a
...@@ -46,7 +46,7 @@ else ...@@ -46,7 +46,7 @@ else
fi fi
# do the formatting # do the formatting
for file in $(git diff-index --cached --name-only $against | grep -E '\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$|\.py$') for file in $(git diff-index --cached --name-only $against | grep -E '\.h$|\.hpp$|\.cpp$|\.cl$|\.c$|\.h\.in$|\.hpp\.in$|\.cpp\.in$|\.py$')
do do
if [ -e "$file" ] if [ -e "$file" ]
then then
......
name: MiGraphX Benchmark
on:
workflow_dispatch:
jobs:
benchmark:
uses: ROCmSoftwarePlatform/actions/.github/workflows/benchmarks.yml@main
with:
rocm_version: 5.2
script_repo: migraphx-benchmark/benchmark-utils
result_path: /usr/share/migraphx/test-results
result_repo: ROCmSoftwarePlatform/comparison-results
secrets:
gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }}
...@@ -133,6 +133,7 @@ jobs: ...@@ -133,6 +133,7 @@ jobs:
-o -iname '*.hpp.in' \ -o -iname '*.hpp.in' \
-o -iname '*.cpp.in' \ -o -iname '*.cpp.in' \
-o -iname '*.cl' \ -o -iname '*.cl' \
-o -iname '*.c' \
| grep -v 'build/' \ | grep -v 'build/' \
| xargs -n 1 -P 1 -I{} -t sh -c 'clang-format-10 -style=file {} | diff - {}' | xargs -n 1 -P 1 -I{} -t sh -c 'clang-format-10 -style=file {} | diff - {}'
find . -iname '*.py' \ find . -iname '*.py' \
...@@ -269,4 +270,98 @@ jobs: ...@@ -269,4 +270,98 @@ jobs:
curl -s https://codecov.io/bash | bash curl -s https://codecov.io/bash | bash
echo "Uploaded" echo "Uploaded"
linux-fpga:
continue-on-error: true
runs-on: ${{ matrix.os }}
env:
CCACHE_COMPRESSLEVEL: 10
CCACHE_DIR: ${{github.workspace}}/ccache
CCACHE_NOHASHDIR: true
CCACHE_BASEDIR: ${{github.workspace}}
CCACHE_MAXSIZE: 1
strategy:
matrix:
os:
- ubuntu-18.04
- ubuntu-20.04
configuration:
- debug
#- release Uncomment when ready to test release builds
#- codecov Uncomment when ready for codecov
steps:
- name: Free space
run: sudo rm -rf /usr/local/android /usr/share/dotnet /usr/local/share/boost /opt/ghc /usr/local/share/chrom* /usr/share/swift /usr/local/julia* /usr/local/lib/android
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.6
- name: Cache dependencies
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
uses: actions/cache@v2
with:
# This path is specific to Ubuntu
path: ${{ github.workspace }}/cget
# Look to see if there is a cache hit for the corresponding requirements file
key:
${{ matrix.os }}-cget-4-${{ hashFiles('requirements.txt', 'dev-requirements.txt') }}
${{ matrix.os }}-cget-4-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
rbuild prepare -d cget -s gh
- name: Prepare timestamp
id: cache_timestamp
shell: cmake -P {0}
run: |
string(TIMESTAMP current_date "%Y-%m-%d-%H;%M;%S" UTC)
message("::set-output name=timestamp::${current_date}")
- name: Cache files for ccache
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
uses: pat-s/always-upload-cache@v2.1.3
with:
path: ccache
key: ${{ matrix.os }}-${{ matrix.configuration }}-ccache-${{ steps.cache_timestamp.outputs.timestamp }}
restore-keys: |
${{ matrix.os }}-${{ matrix.configuration }}-ccache-${{ steps.cache_timestamp.outputs.timestamp }}
${{ matrix.os }}-${{ matrix.configuration }}-ccache-
- name: Build and test
env:
CMAKE_PREFIX_PATH: ${{ github.workspace }}/cget
CCACHE_LOGFILE: /tmp/ccache.log
CXXFLAGS: -Werror -pthread --param ggc-min-expand=5 --param ggc-min-heapsize=8192
run: |
echo "leak:dnnl::impl::malloc" > suppressions.txt
export LSAN_OPTIONS="suppressions=$(pwd)/suppressions.txt"
rbuild build -d cget -s gh -T check \
-DCMAKE_BUILD_TYPE=${{matrix.configuration}} \
-DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \
-DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined" \
-DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \
-DCMAKE_EXE_LINKER_FLAGS='-fuse-ld=gold' \
-DCMAKE_SHARED_LINKER_FLAGS='-fuse-ld=gold' \
-DMIGRAPHX_ENABLE_FPGA=On
${{ github.workspace }}/cget/bin/ccache -s
#- name: Upload code coverage
# if: "matrix.configuration == 'codecov'"
# env:
# CODECOV_TOKEN: "8545af1c-f90b-4345-92a5-0d075503ca56"
# run: |
# sudo apt-get install -y lcov
# cd build
# lcov --directory . --capture --output-file $(pwd)/coverage.info
# lcov --remove $(pwd)/coverage.info '/usr/*' --output-file $(pwd)/coverage.info
# lcov --list $(pwd)/coverage.info
# curl -s https://codecov.io/bash | bash
# echo "Uploaded"
\ No newline at end of file
name: MIGraphX Performance Tests
on:
push:
branches: [develop]
pull_request:
branches: [develop]
schedule:
- cron: "0 5 * * 1-6"
workflow_dispatch:
inputs:
rocm_release:
description: ROCm Version
required: true
default: '5.2'
performance_reports_repo:
description: Result repository
required: true
default: 'ROCmSoftwarePlatform/migraphx-reports'
result_number:
description: Last N results
required: true
default: '10'
flags:
description: -m for Max value; -s for Std dev; -r for Threshold file
required: true
default: '-s'
concurrency: benchmark
jobs:
release:
uses: rocmsoftwareplatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
with:
rocm_release: ${{ github.event.inputs.rocm_release || '5.2' }}
result_number: ${{ github.event.inputs.result_number || '10' }}
flags: ${{ github.event.inputs.flags || '-s' }}
performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || 'ROCmSoftwarePlatform/migraphx-reports' }}
secrets:
gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }}
mail_user: ${{ secrets.MAIL_USERNAME }}
mail_pass: ${{ secrets.MAIL_PASSWORD }}
name: ROCM Docker image build
on:
workflow_dispatch:
inputs:
rocm_release:
description: ROCm release version
required: true
jobs:
release:
uses: ROCmSoftwarePlatform/actions/.github/workflows/rocm-release.yml@main
with:
rocm_release: ${{ github.event.inputs.rocm_release }}
secrets:
gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }}
...@@ -61,8 +61,6 @@ check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX) ...@@ -61,8 +61,6 @@ check_type_size("half_float::detail::expr" HALF_EXPR LANGUAGE CXX)
set(CMAKE_REQUIRED_INCLUDES) set(CMAKE_REQUIRED_INCLUDES)
set(CMAKE_EXTRA_INCLUDE_FILES) set(CMAKE_EXTRA_INCLUDE_FILES)
find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.3) rocm_setup_version(VERSION 2.3)
...@@ -82,8 +80,11 @@ endif() ...@@ -82,8 +80,11 @@ endif()
# Disable cpu backend by default # Disable cpu backend by default
set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "") set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "")
# Disable fpga backend by default
set(MIGRAPHX_ENABLE_FPGA Off CACHE BOOL "")
set(CMAKE_CXX_STANDARD_DEFAULT "") set(CMAKE_CXX_STANDARD_DEFAULT "")
add_compile_options(-std=c++17) add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-std=c++17>)
if(${CMAKE_VERSION} VERSION_LESS "3.12.0") if(${CMAKE_VERSION} VERSION_LESS "3.12.0")
set(CONFIGURE_DEPENDS) set(CONFIGURE_DEPENDS)
...@@ -253,14 +254,6 @@ rocm_enable_cppcheck( ...@@ -253,14 +254,6 @@ rocm_enable_cppcheck(
enable_testing() enable_testing()
include(ROCMCreatePackage) include(ROCMCreatePackage)
rocm_create_package(
NAME MIGraphX
DESCRIPTION "AMD's graph optimizer"
MAINTAINER "Paul Fultz II <paul.fultz@amd.com>"
LDCONFIG
PTH
DEPENDS miopen-hip rocblas hip-rocclr hip-base half
)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
...@@ -277,3 +270,12 @@ foreach(py_file ${backend_files}) ...@@ -277,3 +270,12 @@ foreach(py_file ${backend_files})
configure_file(${py_file} ${DEST_DIR}/lib/onnx_migraphx/. COPYONLY) configure_file(${py_file} ${DEST_DIR}/lib/onnx_migraphx/. COPYONLY)
endforeach(py_file) endforeach(py_file)
configure_file(${CMAKE_SOURCE_DIR}/test/py/onnx_backend_test.py ${DEST_DIR}/onnx_backend_test.py COPYONLY) configure_file(${CMAKE_SOURCE_DIR}/test/py/onnx_backend_test.py ${DEST_DIR}/onnx_backend_test.py COPYONLY)
rocm_create_package(
NAME MIGraphX
DESCRIPTION "AMD's graph optimizer"
MAINTAINER "AMDMIGraphX Maintainer <migraphx-lib.support@amd.com>"
LDCONFIG
PTH
DEPENDS miopen-hip rocblas hip-rocclr hip-base half
)
...@@ -77,7 +77,7 @@ RUN cget -p $PREFIX install ccache@v4.1 ...@@ -77,7 +77,7 @@ RUN cget -p $PREFIX install ccache@v4.1
RUN cget -p /opt/cmake install kitware/cmake@v3.13.4 RUN cget -p /opt/cmake install kitware/cmake@v3.13.4
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master ARG ONNXRUNTIME_BRANCH=main
ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc ARG ONNXRUNTIME_COMMIT=24f1bd6156cf5968bbc76dfb0e801a9b9c56b9fc
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \
cd onnxruntime && \ cd onnxruntime && \
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@26a4b3cfc0a1a15181490f24ae461608fef1b04e -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@d2cb9e580550e92ab75a0a417e7a4abd02a24edf -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -93,7 +93,7 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -93,7 +93,7 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
stage('Hip Clang Debug') { stage('Hip Clang Debug') {
def sanitizers = "undefined" def sanitizers = "undefined"
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_release: rocmnode('vega') { cmake_build -> }, clang_release: rocmnode('vega') { cmake_build ->
stage('Hip Clang Release') { stage('Hip Clang Release') {
...@@ -104,13 +104,13 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -104,13 +104,13 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
stage('MLIR Debug') { stage('MLIR Debug') {
def sanitizers = "undefined" def sanitizers = "undefined"
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_asan: rocmnode('nogpu') { cmake_build -> }, clang_asan: rocmnode('nogpu') { cmake_build ->
stage('Clang ASAN') { stage('Clang ASAN') {
def sanitizers = "undefined,address" def sanitizers = "undefined,address"
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'")
} }
}//, clang_release_navi: rocmnode('navi21') { cmake_build -> }//, clang_release_navi: rocmnode('navi21') { cmake_build ->
// stage('HIP Clang Release Navi') { // stage('HIP Clang Release Navi') {
......
...@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source. ...@@ -46,6 +46,7 @@ The following is a list of prerequisites required to build MIGraphX source.
* [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings * [pybind11](https://pybind11.readthedocs.io/en/stable/) - for python bindings
* [JSON](https://github.com/nlohmann/json) - for model serialization to json string format * [JSON](https://github.com/nlohmann/json) - for model serialization to json string format
* [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format * [MessagePack](https://msgpack.org/index.html) - for model serialization to binary format
* [SQLite3](https://www.sqlite.org/index.html) - to create database of kernels' tuning information or execute queries on existing database
#### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild). #### Use the ROCm build tool [rbuild](https://github.com/RadeonOpenCompute/rbuild).
......
File mode changed from 100755 to 100644
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API #include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric> #include <numeric>
...@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base ...@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
migraphx::arguments args) const override migraphx::arguments args) const override
{ {
// create rocblas stream handle // create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx); auto rb_handle = create_rocblas_handle_ptr(ctx);
MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
rocblas_int n = args[1].get_shape().lengths()[0]; rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data()); float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data()); float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1)); MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
return args[1]; return args[1];
} }
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
##################################################################################### #####################################################################################
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -DHEADER_DIR=blaze live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969 half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -65,6 +65,7 @@ add_library(migraphx ...@@ -65,6 +65,7 @@ add_library(migraphx
operation.cpp operation.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
pad_calc.cpp
pass_manager.cpp pass_manager.cpp
permutation.cpp permutation.cpp
preallocate_param.cpp preallocate_param.cpp
...@@ -79,6 +80,7 @@ add_library(migraphx ...@@ -79,6 +80,7 @@ add_library(migraphx
register_target.cpp register_target.cpp
replace_allocate.cpp replace_allocate.cpp
simplify_qdq.cpp simplify_qdq.cpp
sqlite.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_gelu.cpp rewrite_gelu.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
...@@ -135,6 +137,7 @@ register_migraphx_ops( ...@@ -135,6 +137,7 @@ register_migraphx_ops(
exp exp
flatten flatten
floor floor
fmod
gather gather
gathernd gathernd
get_tuple_elem get_tuple_elem
...@@ -157,6 +160,7 @@ register_migraphx_ops( ...@@ -157,6 +160,7 @@ register_migraphx_ops(
lstm lstm
max max
min min
mod
mul mul
multibroadcast multibroadcast
multinomial multinomial
...@@ -240,6 +244,13 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU ...@@ -240,6 +244,13 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
find_package(Threads) find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads) target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(msgpack REQUIRED) find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
...@@ -268,6 +279,11 @@ add_subdirectory(targets/gpu) ...@@ -268,6 +279,11 @@ add_subdirectory(targets/gpu)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu) target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu)
target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU) target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU)
endif() endif()
if(MIGRAPHX_ENABLE_FPGA)
add_subdirectory(targets/fpga)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_fpga)
target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_FPGA)
endif()
if(HAVE_HALF_EXPR) if(HAVE_HALF_EXPR)
target_compile_definitions(migraphx PUBLIC -DHAS_HALF_V1) target_compile_definitions(migraphx PUBLIC -DHAS_HALF_V1)
......
...@@ -39,12 +39,24 @@ ...@@ -39,12 +39,24 @@
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
if(disable_exception_catch)
{
f();
}
else
{
try try
{ {
f(); f();
...@@ -68,6 +80,7 @@ migraphx_status try_(F f, bool output = true) // NOLINT ...@@ -68,6 +80,7 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
}
return migraphx_status_success; return migraphx_status_success;
} }
...@@ -305,6 +318,7 @@ void destroy(T* x) ...@@ -305,6 +318,7 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble // TODO: Move to interface preamble
template <class C, class D> template <class C, class D>
struct manage_generic_ptr struct manage_generic_ptr
...@@ -313,23 +327,27 @@ struct manage_generic_ptr ...@@ -313,23 +327,27 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {} manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter) manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter) : data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{ {
copier(&data, pdata); copier(&data, pdata);
} }
manage_generic_ptr(const manage_generic_ptr& rhs) manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter) : data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{ {
if(copier) if(copier)
copier(&data, rhs.data); copier(&data, rhs.data);
} }
manage_generic_ptr(manage_generic_ptr&& other) noexcept manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter) : data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{ {
other.data = nullptr; other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr; other.copier = nullptr;
other.deleter = nullptr; other.deleter = nullptr;
} }
...@@ -337,6 +355,7 @@ struct manage_generic_ptr ...@@ -337,6 +355,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs) manage_generic_ptr& operator=(manage_generic_ptr rhs)
{ {
std::swap(data, rhs.data); std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier); std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter); std::swap(deleter, rhs.deleter);
return *this; return *this;
...@@ -349,6 +368,7 @@ struct manage_generic_ptr ...@@ -349,6 +368,7 @@ struct manage_generic_ptr
} }
void* data = nullptr; void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr; C copier = nullptr;
D deleter = nullptr; D deleter = nullptr;
}; };
...@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op ...@@ -580,8 +600,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p, migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs) Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...) : object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{ {
} }
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete> manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
...@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op ...@@ -595,13 +616,21 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_argument_t> out; std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr) if(compute_f == nullptr)
throw std::runtime_error("compute function is missing."); throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out, auto api_error_result = compute_f(&out,
object_ptr.data, object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)), object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)), object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs))); object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success) if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute."); {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object; return (&out)->object;
} }
...@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op ...@@ -611,10 +640,19 @@ struct migraphx_experimental_custom_op
std::remove_pointer_t<migraphx_shape_t> out; std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr) if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing."); throw std::runtime_error("compute_shape function is missing.");
auto api_error_result = std::array<char, 256> exception_msg;
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs))); exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success) if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape."); {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object; return (&out)->object;
} }
}; };
...@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -743,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
auto api_error_result = migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
...@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -1806,11 +1854,12 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name) const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*experimental_custom_op = *experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name)); allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
}); });
return api_error_result; return api_error_result;
} }
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H #define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
...@@ -131,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta ...@@ -131,12 +132,16 @@ typedef const struct migraphx_experimental_custom_op* const_migraphx_experimenta
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj, void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx, migraphx_context_t ctx,
migraphx_shape_t output, migraphx_shape_t output,
migraphx_arguments_t inputs); migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj, void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
...@@ -175,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); ...@@ -175,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
...@@ -485,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -485,6 +492,7 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name); const char* name);
migraphx_status migraphx_status
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h" #include "migraphx.h"
#include <cstring>
#include <initializer_list> #include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
...@@ -58,6 +59,42 @@ struct rank<0> ...@@ -58,6 +59,42 @@ struct rank<0>
{ {
}; };
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts> template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
...@@ -310,13 +347,22 @@ struct interface_base : Base ...@@ -310,13 +347,22 @@ struct interface_base : Base
protected: protected:
template <class F> template <class F>
static migraphx_status try_(F f) // NOLINT static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{ {
try try
{ {
f(); f();
return migraphx_status_success; return migraphx_status_success;
} }
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...) catch(...)
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
...@@ -349,8 +395,12 @@ struct interface_base : Base ...@@ -349,8 +395,12 @@ struct interface_base : Base
{ {
static F f = pf; static F f = pf;
(void)f; // avoid warning on gcc (void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status { call(setter,
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); }); this->get_handle_ptr(),
[](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
}); });
} }
...@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -524,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py) friend bool operator==(const shape& px, const shape& py)
{ {
bool pout; bool pout;
...@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1206,7 +1263,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T> template <class T>
experimental_custom_op(T& obj) experimental_custom_op(T& obj)
{ {
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str()); this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
} }
......
...@@ -121,6 +121,7 @@ def shape(h): ...@@ -121,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True)
@auto_handle() @auto_handle()
...@@ -439,7 +440,8 @@ def context(h): ...@@ -439,7 +440,8 @@ def context(h):
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op') 'migraphx::experimental_custom_op')
def experimental_custom_op(h): def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*')) h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute', h.virtual('compute',
api.params(ctx='migraphx::context', api.params(ctx='migraphx::context',
output='migraphx::shape', output='migraphx::shape',
......
...@@ -63,7 +63,7 @@ void auto_contiguous::apply(module& m) const ...@@ -63,7 +63,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.dynamic() and not s.standard() and s.elements() != 0)
{ {
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
......
...@@ -48,9 +48,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -48,9 +48,10 @@ void dead_code_elimination::apply(module& m) const
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin, undefined, identity, or // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// allocate // identity, allocate]
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name())) not contains({"undefined", "identity", "allocate"}, i->name()))
continue; continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
......
...@@ -27,11 +27,13 @@ ...@@ -27,11 +27,13 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <list>
#include <set> #include <set>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -39,9 +41,16 @@ ...@@ -39,9 +41,16 @@
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#ifndef _WIN32
#include <unistd.h>
#endif
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -74,6 +83,65 @@ template <class T> ...@@ -74,6 +83,65 @@ template <class T>
using is_multi_value = using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>; std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
inline std::string colorize(color c, const std::string& s)
{
std::stringstream ss;
ss << c << s << color::reset;
return ss.str();
}
template <class T>
struct type_name
{
static const std::string& apply() { return migraphx::get_type_name<T>(); }
};
template <>
struct type_name<std::string>
{
static const std::string& apply()
{
static const std::string name = "std::string";
return name;
}
};
template <class T>
struct type_name<std::vector<T>>
{
static const std::string& apply()
{
static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
return name;
}
};
template <class T> template <class T>
struct value_parser struct value_parser
{ {
...@@ -85,7 +153,7 @@ struct value_parser ...@@ -85,7 +153,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> result; ss >> result;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return result; return result;
} }
...@@ -97,7 +165,7 @@ struct value_parser ...@@ -97,7 +165,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> i; ss >> i;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return static_cast<T>(i); return static_cast<T>(i);
} }
...@@ -115,13 +183,42 @@ struct argument_parser ...@@ -115,13 +183,42 @@ struct argument_parser
{ {
struct argument struct argument
{ {
using action_function =
std::function<bool(argument_parser&, const std::vector<std::string>&)>;
using validate_function =
std::function<void(const argument_parser&, const std::vector<std::string>&)>;
std::vector<std::string> flags; std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{}; action_function action{};
std::string type = ""; std::string type = "";
std::string help = ""; std::string help = "";
std::string metavar = ""; std::string metavar = "";
std::string default_value = ""; std::string default_value = "";
std::string group = "";
unsigned nargs = 1; unsigned nargs = 1;
bool required = false;
std::vector<validate_function> validations{};
std::string usage(const std::string& flag) const
{
std::stringstream ss;
if(flag.empty())
{
ss << metavar;
}
else
{
ss << flag;
if(not type.empty())
ss << " [" << type << "]";
}
return ss.str();
}
std::string usage() const
{
if(flags.empty())
return usage("");
return usage(flags.front());
}
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})> template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
...@@ -154,12 +251,14 @@ struct argument_parser ...@@ -154,12 +251,14 @@ struct argument_parser
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) { arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("Flag with no value."); throw std::runtime_error("Flag with no value.");
if(not is_multi_value<T>{} and params.size() > 1)
throw std::runtime_error("Too many arguments passed.");
x = value_parser<T>::apply(params.back()); x = value_parser<T>::apply(params.back());
return false; return false;
}}); }});
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = type_name<T>::apply();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0) if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x); arg.default_value = as_string_value(x);
...@@ -181,6 +280,11 @@ struct argument_parser ...@@ -181,6 +280,11 @@ struct argument_parser
return [=](auto&&, auto& arg) { arg.nargs = n; }; return [=](auto&&, auto& arg) { arg.nargs = n; };
} }
MIGRAPHX_DRIVER_STATIC auto required()
{
return [=](auto&&, auto& arg) { arg.required = true; };
}
template <class F> template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f) MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{ {
...@@ -215,13 +319,141 @@ struct argument_parser ...@@ -215,13 +319,141 @@ struct argument_parser
}); });
} }
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "") template <class F>
MIGRAPHX_DRIVER_STATIC auto validate(F f)
{
return [=](const auto& x, auto& arg) {
arg.validations.push_back(
[&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
};
}
MIGRAPHX_DRIVER_STATIC auto file_exist()
{
return validate([](auto&, auto&, auto& params) {
if(params.empty())
throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
});
}
template <class F>
argument* find_argument(F f)
{
auto it = std::find_if(arguments.begin(), arguments.end(), f);
if(it == arguments.end())
return nullptr;
return std::addressof(*it);
}
template <class F>
bool has_argument(F f)
{
return find_argument(f) != nullptr;
}
template <class F>
std::vector<argument*> find_arguments(F f)
{
std::vector<argument*> result;
for(auto& arg : arguments)
{
if(not f(arg))
continue;
result.push_back(&arg);
}
return result;
}
std::vector<argument*> get_group_arguments(const std::string& group)
{
return find_arguments([&](const auto& arg) { return arg.group == group; });
}
std::vector<argument*> get_required_arguments()
{
return find_arguments([&](const auto& arg) { return arg.required; });
}
template <class SequenceContainer>
std::vector<std::string> get_argument_usages(SequenceContainer args)
{
std::vector<std::string> usage_flags;
std::unordered_set<std::string> found_groups;
// Remove arguments that belong to a group
auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
if(arg->group.empty())
return false;
found_groups.insert(arg->group);
return true;
});
args.erase(it, args.end());
transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
std::vector<std::string> either_flags;
transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
return arg->usage();
});
return "(" + join_strings(either_flags, "|") + ")";
});
transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
return usage_flags;
}
auto show_help(const std::string& msg = "")
{ {
return do_action([=](auto& self) { return do_action([=](auto& self) {
argument* input_argument =
self.find_argument([](const auto& arg) { return arg.flags.empty(); });
auto required_usages = get_argument_usages(get_required_arguments());
if(required_usages.empty() && input_argument)
required_usages.push_back(input_argument->metavar);
required_usages.insert(required_usages.begin(), "<options>");
print_usage(required_usages);
std::cout << std::endl;
if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
{
std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
std::cout << std::endl;
for(auto&& arg : self.arguments) for(auto&& arg : self.arguments)
{
if(arg.nargs != 0)
continue;
const int col_align = 35;
std::string prefix = " ";
int len = 0;
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
len += prefix.length() + a.length();
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset;
int spaces = col_align - len;
if(spaces < 0)
{ {
std::cout << std::endl; std::cout << std::endl;
}
else
{
for(int i = 0; i < spaces; i++)
std::cout << " ";
}
std::cout << arg.help << std::endl;
}
std::cout << std::endl;
}
if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
{
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : self.arguments)
{
if(arg.nargs == 0)
continue;
std::cout << std::endl;
std::string prefix = " "; std::string prefix = " ";
std::cout << color::fg_green;
if(arg.flags.empty()) if(arg.flags.empty())
{ {
std::cout << prefix; std::cout << prefix;
...@@ -233,9 +465,10 @@ struct argument_parser ...@@ -233,9 +465,10 @@ struct argument_parser
std::cout << a; std::cout << a;
prefix = ", "; prefix = ", ";
} }
std::cout << color::reset;
if(not arg.type.empty()) if(not arg.type.empty())
{ {
std::cout << " [" << arg.type << "]"; std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
if(not arg.default_value.empty()) if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")"; std::cout << " (Default: " << arg.default_value << ")";
} }
...@@ -243,6 +476,7 @@ struct argument_parser ...@@ -243,6 +476,7 @@ struct argument_parser
std::cout << " " << arg.help << std::endl; std::cout << " " << arg.help << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
}
if(not msg.empty()) if(not msg.empty())
std::cout << msg << std::endl; std::cout << msg << std::endl;
}); });
...@@ -263,6 +497,11 @@ struct argument_parser ...@@ -263,6 +497,11 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.type = type; }; return [=](auto&, auto& arg) { arg.type = type; };
} }
MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
{
return [=](auto&, auto& arg) { arg.group = group; };
}
template <class T> template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value) MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{ {
...@@ -276,6 +515,109 @@ struct argument_parser ...@@ -276,6 +515,109 @@ struct argument_parser
}; };
} }
template <class T>
void set_exe_name_to(T& x)
{
actions.push_back([&](const auto& self) { x = self.exe_name; });
}
void print_try_help()
{
if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
{
std::cout << std::endl;
std::cout << "For more information try '" << color::fg_green << "--help" << color::reset
<< "'" << std::endl;
}
}
void print_usage(const std::vector<std::string>& flags) const
{
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " ";
std::cout << join_strings(flags, " ") << std::endl;
}
auto spellcheck(const std::vector<std::string>& inputs)
{
struct result_t
{
const argument* arg = nullptr;
std::string correct = "";
std::string incorrect = "";
std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
};
result_t result;
for(const auto& input : inputs)
{
if(input.empty())
continue;
if(input[0] != '-')
continue;
for(const auto& arg : arguments)
{
for(const auto& flag : arg.flags)
{
if(flag.empty())
continue;
if(flag[0] != '-')
continue;
auto d =
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance)
result = result_t{&arg, flag, input, d};
}
}
}
return result;
}
bool
run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
{
std::string msg = "";
try
{
for(const auto& v : arg.validations)
v(*this, inputs);
return arg.action(*this, inputs);
}
catch(const std::exception& e)
{
msg = e.what();
}
catch(...)
{
msg = "unknown exception";
}
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto sc = spellcheck(inputs);
if(sc.distance < 5)
{
std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
<< "'"
<< " which wasn't expected, or isn't valid in this context" << std::endl;
std::cout << " "
<< "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
<< std::endl;
std::cout << std::endl;
print_usage({sc.arg->usage(sc.correct)});
}
else
{
const auto& flag_name = flag.empty() ? arg.metavar : flag;
std::cout << "Invalid input to '" << color::fg_yellow;
std::cout << arg.usage(flag_name);
std::cout << color::reset << "'" << std::endl;
std::cout << " " << msg << std::endl;
std::cout << std::endl;
print_usage({arg.usage()});
}
std::cout << std::endl;
print_try_help();
return true;
}
bool parse(std::vector<std::string> args) bool parse(std::vector<std::string> args)
{ {
std::unordered_map<std::string, unsigned> keywords; std::unordered_map<std::string, unsigned> keywords;
...@@ -286,8 +628,11 @@ struct argument_parser ...@@ -286,8 +628,11 @@ struct argument_parser
} }
auto arg_map = auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; }); generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
std::list<const argument*> missing_arguments;
std::unordered_set<std::string> groups_used;
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
bool used = false;
auto flags = arg.flags; auto flags = arg.flags;
if(flags.empty()) if(flags.empty())
flags = {""}; flags = {""};
...@@ -295,14 +640,41 @@ struct argument_parser ...@@ -295,14 +640,41 @@ struct argument_parser
{ {
if(arg_map.count(flag) > 0) if(arg_map.count(flag) > 0)
{ {
if(arg.action(*this, arg_map[flag])) if(run_action(arg, flag, arg_map[flag]))
return true; return true;
used = true;
} }
} }
if(used and not arg.group.empty())
groups_used.insert(arg.group);
if(arg.required and not used)
missing_arguments.push_back(&arg);
} }
// Remove arguments from a group that is being used
missing_arguments.remove_if(
[&](const argument* arg) { return groups_used.count(arg->group); });
if(not missing_arguments.empty())
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
std::cout << "The following required arguments were not provided:" << std::endl;
std::cout << " " << color::fg_red
<< join_strings(get_argument_usages(std::move(missing_arguments)), " ")
<< color::reset << std::endl;
std::cout << std::endl;
auto required_usages = get_argument_usages(get_required_arguments());
print_usage(required_usages);
print_try_help();
return true;
}
for(auto&& action : actions)
action(*this);
return false; return false;
} }
void set_exe_name(const std::string& s) { exe_name = s; }
const std::string& get_exe_name() const { return exe_name; }
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword> template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword) static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
...@@ -337,7 +709,9 @@ struct argument_parser ...@@ -337,7 +709,9 @@ struct argument_parser
} }
private: private:
std::vector<argument> arguments; std::list<argument> arguments;
std::string exe_name = "";
std::vector<std::function<void(argument_parser&)>> actions;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment