diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5ad181a5e210976483e8f875966ebb6ab4f8dfae --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +**/.idea +*~ +*.swp +*.o +*.so +*.pyc +build +build2 +dist +*.egg-info/ +docs/release/ +docs/docs/web/ +docs/docs/chm/ +docs/docs/cache/ +docs/docs/git-logs.xml +docs/docs/python/classes.txt +docs/docs/python/functions.txt +docs/docs/python/constants.txt +**/.vscode diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..da92454aabfacc784d269452e918cbd43c2bf6ff --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,37 @@ +cmake_minimum_required(VERSION 3.8.0) + +project(dlib_project) + + + +############################################################################# +# # +# READ examples/CMakeLists.txt TO SEE HOW TO USE DLIB FROM C++ WITH CMAKE # +# # +############################################################################# + + + + + +get_directory_property(has_parent PARENT_DIRECTORY) +if(NOT has_parent) + # When you call add_subdirectory(dlib) from a parent CMake project dlib's + # CMake scripts will assume you want to statically compile dlib into + # whatever you are building rather than create a standalone copy of dlib. + # This means CMake will build dlib as a static library, disable dlib's + # install targets so they don't clutter your project, and adjust a few other + # minor things that are convenient when statically building dlib as part of + # your own projects. + # + # On the other hand, if there is no parent CMake project or if + # DLIB_IN_PROJECT_BUILD is set to false, CMake will compile dlib as a normal + # standalone library (either shared or static, based on the state of CMake's + # BUILD_SHARED_LIBS flag), and include the usual install targets so you can + # install dlib on your computer via `make install`. Since the only reason + # to build this CMakeLists.txt (the one you are reading right now) by itself + # is if you want to install dlib, we indicate as such by setting + # DLIB_IN_PROJECT_BUILD to false. + set(DLIB_IN_PROJECT_BUILD false) +endif() +add_subdirectory(dlib) diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 9f358a4addefcab294b83e4282bfef1f9625a249..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1 +0,0 @@ -123456 diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..127a5bc39ba030c7cb99cc0aedc4f280ffe27310 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,23 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..2ba72dc45b2e613cf6b2d40afb52883c66dcf5f0 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,19 @@ +# +# MANIFEST.in +# +# Manifest template for creating the dlib source distribution. + +include MANIFEST.in +include setup.py +include README.md + +# sources +recursive-include dlib ** +recursive-include python_examples *.txt *.py +recursive-include tools/python ** + +prune tools/python/build* +prune dlib/cmake_utils/*/build* +prune dlib/test + +global-exclude *.pyc diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..da1ec8700855263c2fb74ceb363c6de1e125aeea --- /dev/null +++ b/README.md @@ -0,0 +1,67 @@ +# dlib C++ library [![GitHub Actions C++ Status](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml) [![GitHub Actions Python Status](https://github.com/davisking/dlib/actions/workflows/build_python.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_python.yml) + +Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real world problems. See [http://dlib.net](http://dlib.net) for the main project documentation and API reference. + + + +## Compiling dlib C++ example programs + +Go into the examples folder and type: + +```bash +mkdir build; cd build; cmake .. ; cmake --build . +``` + +That will build all the examples. +If you have a CPU that supports AVX instructions then turn them on like this: + +```bash +mkdir build; cd build; cmake .. -DUSE_AVX_INSTRUCTIONS=1; cmake --build . +``` + +Doing so will make some things run faster. + +Finally, Visual Studio users should usually do everything in 64bit mode. By default Visual Studio is 32bit, both in its outputs and its own execution, so you have to explicitly tell it to use 64bits. Since it's not the 1990s anymore you probably want to use 64bits. Do that with a cmake invocation like this: +```bash +cmake .. -G "Visual Studio 14 2015 Win64" -T host=x64 +``` + +## Compiling your own C++ programs that use dlib + +The examples folder has a [CMake tutorial](https://github.com/davisking/dlib/blob/master/examples/CMakeLists.txt) that tells you what to do. There are also additional instructions on the [dlib web site](http://dlib.net/compile.html). + +Alternatively, if you are using the [vcpkg](https://github.com/Microsoft/vcpkg/) dependency manager you can download and install dlib with CMake integration in a single command: +```bash +vcpkg install dlib +``` + +## Compiling dlib Python API + +Before you can run the Python example programs you must compile dlib. Type: + +```bash +python setup.py install +``` + + +## Running the unit test suite + +Type the following to compile and run the dlib unit test suite: + +```bash +cd dlib/test +mkdir build +cd build +cmake .. +cmake --build . --config Release +./dtest --runall +``` + +Note that on windows your compiler might put the test executable in a subfolder called `Release`. If that's the case then you have to go to that folder before running the test. + +This library is licensed under the Boost Software License, which can be found in [dlib/LICENSE.txt](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt). The long and short of the license is that you can use dlib however you like, even in closed source commercial software. + +## dlib sponsors + +This research is based in part upon work supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA) under contract number 2014-14071600010. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of ODNI, IARPA, or the U.S. Government. + diff --git a/dlib/CMakeLists.txt b/dlib/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7c3159d0d2f644ca436e33acae883e47a2eff69b --- /dev/null +++ b/dlib/CMakeLists.txt @@ -0,0 +1,1043 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + + +cmake_minimum_required(VERSION 3.8.0) + +set(CMAKE_DISABLE_SOURCE_CHANGES ON) +set(CMAKE_DISABLE_IN_SOURCE_BUILD ON) + +if (POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif () + +project(dlib) + +set(CPACK_PACKAGE_NAME "dlib") +set(CPACK_PACKAGE_VERSION_MAJOR "19") +set(CPACK_PACKAGE_VERSION_MINOR "24") +set(CPACK_PACKAGE_VERSION_PATCH "99") +set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH}) +# Only print these messages once, even if dlib is added multiple times via add_subdirectory() +if (NOT TARGET dlib) + message(STATUS "Using CMake version: ${CMAKE_VERSION}") + message(STATUS "Compiling dlib version: ${VERSION}") +endif () + + +include(cmake_utils/set_compiler_specific_options.cmake) + + +# Adhere to GNU filesystem layout conventions +include(GNUInstallDirs) + +if (POLICY CMP0075) + cmake_policy(SET CMP0075 NEW) +endif () + +# default to a Release build (except if CMAKE_BUILD_TYPE is set) +include(cmake_utils/release_build_by_default) + +# Set DLIB_VERSION in the including CMake file so they can use it to do whatever they want. +get_directory_property(has_parent PARENT_DIRECTORY) +if (has_parent) + set(DLIB_VERSION ${VERSION} PARENT_SCOPE) + if (NOT DEFINED DLIB_IN_PROJECT_BUILD) + set(DLIB_IN_PROJECT_BUILD true) + endif () +endif () + + +if (COMMAND pybind11_add_module AND MSVC) + # True when building a python extension module using Visual Studio. We care + # about this because a huge number of windows users have broken systems, and + # in particular, they have broken or incompatibly installed copies of things + # like libjpeg or libpng. So if we detect we are in this mode we will never + # ever link to those libraries. Instead, we link to the copy included with + # dlib. + set(BUILDING_PYTHON_IN_MSVC true) +else () + set(BUILDING_PYTHON_IN_MSVC false) +endif () + +if (DLIB_IN_PROJECT_BUILD) + + # Check if we are being built as part of a pybind11 module. + if (COMMAND pybind11_add_module) + set(CMAKE_POSITION_INDEPENDENT_CODE True) + if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + endif () + # Make DLIB_ASSERT statements not abort the python interpreter, but just return an error. + list(APPEND active_preprocessor_switches "-DDLIB_NO_ABORT_ON_2ND_FATAL_ERROR") + endif () + + # DLIB_IN_PROJECT_BUILD==true means you are using dlib by invoking + # add_subdirectory(dlib) in the parent project. In this case, we always want + # to build dlib as a static library so the parent project doesn't need to + # deal with some random dlib shared library file. It is much better to + # statically compile dlib into the parent project. So the following bit of + # CMake ensures that happens. However, we have to take care to compile dlib + # with position independent code if appropriate (i.e. if the parent project + # is a shared library). + if (BUILD_SHARED_LIBS) + if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + endif () + set(CMAKE_POSITION_INDEPENDENT_CODE true) + endif () + + # Tell cmake to build dlib as a static library + set(BUILD_SHARED_LIBS false) + +elseif (BUILD_SHARED_LIBS) + if (MSVC) + message(FATAL_ERROR "Building dlib as a standalone dll is not supported when using Visual Studio. You are highly encouraged to use static linking instead. See https://github.com/davisking/dlib/issues/1483 for a discussion.") + endif () +endif () + + +if (CMAKE_VERSION VERSION_LESS "3.9.0") + # Set only because there are old target_link_libraries() statements in the + # FindCUDA.cmake file that comes with CMake that error out if the new behavior + # is used. In newer versions of CMake we can instead set CUDA_LINK_LIBRARIES_KEYWORD which fixes this issue. + cmake_policy(SET CMP0023 OLD) +else () + set(CUDA_LINK_LIBRARIES_KEYWORD PUBLIC) +endif () + + +macro(enable_preprocessor_switch option_name) + list(APPEND active_preprocessor_switches "-D${option_name}") +endmacro() + +macro(disable_preprocessor_switch option_name) + if (active_preprocessor_switches) + list(REMOVE_ITEM active_preprocessor_switches "-D${option_name}") + endif () +endmacro() + +macro(toggle_preprocessor_switch option_name) + if (${option_name}) + enable_preprocessor_switch(${option_name}) + else () + disable_preprocessor_switch(${option_name}) + endif () +endmacro() + + +# Suppress superfluous randlib warnings about libdlib.a having no symbols on MacOSX. +if (APPLE) + set(CMAKE_C_ARCHIVE_CREATE " Scr ") + set(CMAKE_CXX_ARCHIVE_CREATE " Scr ") + set(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") + set(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") +endif () + +# Don't try to call add_library(dlib) and setup dlib's stuff if it has already +# been done by some other part of the current cmake project. We do this +# because it avoids getting warnings/errors about cmake policy CMP0002. This +# happens when a project tries to call add_subdirectory() on dlib more than +# once. This most often happens when the top level of a project depends on two +# or more other things which both depend on dlib. +if (NOT TARGET dlib) + + set(DLIB_ISO_CPP_ONLY_STR + "Enable this if you don't want to compile any non-ISO C++ code (i.e. you don't use any of the API Wrappers)") + set(DLIB_NO_GUI_SUPPORT_STR + "Enable this if you don't want to compile any of the dlib GUI code") + set(DLIB_ENABLE_STACK_TRACE_STR + "Enable this if you want to turn on the DLIB_STACK_TRACE macros") + set(DLIB_USE_BLAS_STR + "Disable this if you don't want to use a BLAS library") + set(DLIB_USE_LAPACK_STR + "Disable this if you don't want to use a LAPACK library") + set(DLIB_USE_CUDA_STR + "Disable this if you don't want to use NVIDIA CUDA") + set(DLIB_USE_CUDA_COMPUTE_CAPABILITIES_STR + "Set this to a comma-separated list of CUDA compute capabilities") + set(DLIB_USE_MKL_SEQUENTIAL_STR + "Enable this if you have MKL installed and want to use the sequential version instead of the multi-core version.") + set(DLIB_USE_MKL_WITH_TBB_STR + "Enable this if you have MKL installed and want to use the tbb version instead of the openmp version.") + set(DLIB_PNG_SUPPORT_STR + "Disable this if you don't want to link against libpng") + set(DLIB_GIF_SUPPORT_STR + "Disable this if you don't want to link against libgif") + set(DLIB_JPEG_SUPPORT_STR + "Disable this if you don't want to link against libjpeg") + set(DLIB_WEBP_SUPPORT_STR + "Disable this if you don't want to link against libwebp") + set(DLIB_LINK_WITH_SQLITE3_STR + "Disable this if you don't want to link against sqlite3") + #set (DLIB_USE_FFTW_STR "Disable this if you don't want to link against fftw" ) + set(DLIB_USE_MKL_FFT_STR + "Disable this is you don't want to use the MKL DFTI FFT implementation") + set(DLIB_ENABLE_ASSERTS_STR + "Enable this if you want to turn on the DLIB_ASSERT macro") + set(DLIB_USE_FFMPEG_STR + "Disable this if you don't want to use the FFMPEG library") + + option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF) + option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF) + toggle_preprocessor_switch(DLIB_ISO_CPP_ONLY) + option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF) + toggle_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF) + toggle_preprocessor_switch(DLIB_ENABLE_STACK_TRACE) + option(DLIB_USE_MKL_SEQUENTIAL ${DLIB_USE_MKL_SEQUENTIAL_STR} OFF) + option(DLIB_USE_MKL_WITH_TBB ${DLIB_USE_MKL_WITH_TBB_STR} OFF) + + if (DLIB_ENABLE_ASSERTS) + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set(DLIB_DISABLE_ASSERTS false) + set(ENABLE_ASSERTS true) + enable_preprocessor_switch(ENABLE_ASSERTS) + disable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + else () + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set(DLIB_DISABLE_ASSERTS true) + set(ENABLE_ASSERTS false) + disable_preprocessor_switch(ENABLE_ASSERTS) + # Never force the asserts off when doing an in project build. The only + # time this matters is when using visual studio. The visual studio IDE + # has a drop down that lets the user select either release or debug + # builds. The DLIB_ASSERT macro is setup to enable/disable automatically + # based on this drop down (via preprocessor magic). However, if + # DLIB_DISABLE_ASSERTS is defined it permanently disables asserts no + # matter what, which would defeat the visual studio drop down. So here + # we make a point to not do that kind of severe disabling when in a + # project build. It should also be pointed out that DLIB_DISABLE_ASSERTS + # is only needed when building and installing dlib as a separately + # installed library. It doesn't matter when doing an in project build. + if (NOT DLIB_IN_PROJECT_BUILD) + enable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + endif () + endif () + + if (DLIB_ISO_CPP_ONLY) + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} OFF) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} OFF) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} OFF) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} OFF) + option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} OFF) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} OFF) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} OFF) + option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} OFF) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} OFF) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} OFF) + option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} OFF) + else () + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} ON) + option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} ON) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON) + option(DLIB_USE_CUDA ${DLIB_USE_CUDA_STR} ON) + set(DLIB_USE_CUDA_COMPUTE_CAPABILITIES 50 CACHE STRING ${DLIB_USE_CUDA_COMPUTE_CAPABILITIES_STR}) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} ON) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} ON) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} ON) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} ON) + option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} ON) + endif () + toggle_preprocessor_switch(DLIB_JPEG_SUPPORT) + toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) + toggle_preprocessor_switch(DLIB_USE_BLAS) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + toggle_preprocessor_switch(DLIB_USE_CUDA) + toggle_preprocessor_switch(DLIB_PNG_SUPPORT) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + #toggle_preprocessor_switch(DLIB_USE_FFTW) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + toggle_preprocessor_switch(DLIB_USE_FFMPEG) + + + set(source_files + base64/base64_kernel_1.cpp + bigint/bigint_kernel_1.cpp + bigint/bigint_kernel_2.cpp + bit_stream/bit_stream_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_2.cpp + entropy_encoder/entropy_encoder_kernel_1.cpp + entropy_encoder/entropy_encoder_kernel_2.cpp + md5/md5_kernel_1.cpp + tokenizer/tokenizer_kernel_1.cpp + unicode/unicode.cpp + test_for_odr_violations.cpp + ) + + set(dlib_needed_public_libraries) + set(dlib_needed_public_includes) + set(dlib_needed_public_cflags) + set(dlib_needed_public_ldflags) + set(dlib_needed_private_libraries) + set(dlib_needed_private_includes) + + if (DLIB_ISO_CPP_ONLY) + add_library(dlib ${source_files}) + else () + + set(source_files ${source_files} + sockets/sockets_kernel_1.cpp + bsp/bsp.cpp + dir_nav/dir_nav_kernel_1.cpp + dir_nav/dir_nav_kernel_2.cpp + dir_nav/dir_nav_extensions.cpp + gui_widgets/fonts.cpp + linker/linker_kernel_1.cpp + logger/extra_logger_headers.cpp + logger/logger_kernel_1.cpp + logger/logger_config_file.cpp + misc_api/misc_api_kernel_1.cpp + misc_api/misc_api_kernel_2.cpp + sockets/sockets_extensions.cpp + sockets/sockets_kernel_2.cpp + sockstreambuf/sockstreambuf.cpp + sockstreambuf/sockstreambuf_unbuffered.cpp + server/server_kernel.cpp + server/server_iostream.cpp + server/server_http.cpp + threads/multithreaded_object_extension.cpp + threads/threaded_object_extension.cpp + threads/threads_kernel_1.cpp + threads/threads_kernel_2.cpp + threads/threads_kernel_shared.cpp + threads/thread_pool_extension.cpp + threads/async.cpp + timer/timer.cpp + stack_trace.cpp + cuda/cpu_dlib.cpp + cuda/tensor_tools.cpp + data_io/image_dataset_metadata.cpp + data_io/mnist.cpp + data_io/cifar.cpp + global_optimization/global_function_search.cpp + filtering/kalman_filter.cpp + svm/auto.cpp + ) + + if (UNIX) + set(CMAKE_THREAD_PREFER_PTHREAD ON) + find_package(Threads REQUIRED) + list(APPEND dlib_needed_private_libraries ${CMAKE_THREAD_LIBS_INIT}) + endif () + + # we want to link to the right stuff depending on our platform. + if (WIN32 AND NOT CYGWIN) ############################################################################### + if (DLIB_NO_GUI_SUPPORT) + list(APPEND dlib_needed_private_libraries ws2_32 winmm) + else () + list(APPEND dlib_needed_private_libraries ws2_32 winmm comctl32 gdi32 imm32) + endif () + elseif (APPLE) ############################################################################ + set(CMAKE_MACOSX_RPATH 1) + if (NOT DLIB_NO_GUI_SUPPORT) + find_package(X11 QUIET) + if (X11_FOUND) + # If both X11 and anaconda are installed, it's possible for the + # anaconda path to appear before /opt/X11, so we remove anaconda. + foreach (ITR ${X11_INCLUDE_DIR}) + if ("${ITR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)") + list(REMOVE_ITEM X11_INCLUDE_DIR ${ITR}) + endif () + endforeach (ITR) + list(APPEND dlib_needed_public_includes ${X11_INCLUDE_DIR}) + list(APPEND dlib_needed_public_libraries ${X11_LIBRARIES}) + else () + find_library(xlib X11) + # Make sure X11 is in the include path. Note that we look for + # Xlocale.h rather than Xlib.h because it avoids finding a partial + # copy of the X11 headers on systems with anaconda installed. + find_path(xlib_path Xlocale.h + PATHS + /Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include + /opt/local/include + PATH_SUFFIXES X11 + ) + if (xlib AND xlib_path) + get_filename_component(x11_path ${xlib_path} PATH CACHE) + list(APPEND dlib_needed_public_includes ${x11_path}) + list(APPEND dlib_needed_public_libraries ${xlib}) + set(X11_FOUND 1) + endif () + endif () + if (NOT X11_FOUND) + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure XQuartz is installed if you want GUI support. ***") + message(" *** You can download XQuartz from: https://www.xquartz.org/ ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif () + endif () + + mark_as_advanced(xlib xlib_path x11_path) + else () ################################################################################## + # link to the socket library if it exists. this is something you need on solaris + find_library(socketlib socket) + if (socketlib) + list(APPEND dlib_needed_private_libraries ${socketlib}) + endif () + + if (NOT DLIB_NO_GUI_SUPPORT) + include(FindX11) + if (X11_FOUND) + list(APPEND dlib_needed_private_includes ${X11_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${X11_LIBRARIES}) + else () + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure libx11-dev is installed if you want GUI support. ***") + message(" *** On Ubuntu run: sudo apt-get install libx11-dev ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif () + endif () + + mark_as_advanced(nsllib socketlib) + endif () ################################################################################## + + if (NOT DLIB_NO_GUI_SUPPORT) + set(source_files ${source_files} + gui_widgets/widgets.cpp + gui_widgets/drawable.cpp + gui_widgets/canvas_drawing.cpp + gui_widgets/style.cpp + gui_widgets/base_widgets.cpp + gui_core/gui_core_kernel_1.cpp + gui_core/gui_core_kernel_2.cpp + ) + endif () + + INCLUDE(CheckFunctionExists) + + if (DLIB_GIF_SUPPORT) + find_package(GIF QUIET) + if (GIF_FOUND) + list(APPEND dlib_needed_public_includes ${GIF_INCLUDE_DIR}) + list(APPEND dlib_needed_public_libraries ${GIF_LIBRARY}) + else () + set(DLIB_GIF_SUPPORT OFF CACHE STRING ${DLIB_GIF_SUPPORT_STR} FORCE) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + endif () + endif () + + if (DLIB_PNG_SUPPORT) + include(cmake_utils/find_libpng.cmake) + if (PNG_FOUND) + list(APPEND dlib_needed_private_includes ${PNG_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${PNG_LIBRARIES}) + else () + # If we can't find libpng then statically compile it in. + include_directories(external/libpng external/zlib) + set(source_files ${source_files} + external/libpng/arm/arm_init.c + external/libpng/arm/filter_neon_intrinsics.c + external/libpng/arm/palette_neon_intrinsics.c + external/libpng/png.c + external/libpng/pngerror.c + external/libpng/pngget.c + external/libpng/pngmem.c + external/libpng/pngpread.c + external/libpng/pngread.c + external/libpng/pngrio.c + external/libpng/pngrtran.c + external/libpng/pngrutil.c + external/libpng/pngset.c + external/libpng/pngtrans.c + external/libpng/pngwio.c + external/libpng/pngwrite.c + external/libpng/pngwtran.c + external/libpng/pngwutil.c + external/zlib/adler32.c + external/zlib/compress.c + external/zlib/crc32.c + external/zlib/deflate.c + external/zlib/gzclose.c + external/zlib/gzlib.c + external/zlib/gzread.c + external/zlib/gzwrite.c + external/zlib/infback.c + external/zlib/inffast.c + external/zlib/inflate.c + external/zlib/inftrees.c + external/zlib/trees.c + external/zlib/uncompr.c + external/zlib/zutil.c + ) + + include(cmake_utils/check_if_neon_available.cmake) + if (ARM_NEON_IS_AVAILABLE) + message(STATUS "NEON instructions will be used for libpng.") + enable_language(ASM) + set(source_files ${source_files} + external/libpng/arm/arm_init.c + external/libpng/arm/filter_neon_intrinsics.c + external/libpng/arm/filter_neon.S + ) + set_source_files_properties(external/libpng/arm/filter_neon.S PROPERTIES COMPILE_FLAGS "${CMAKE_ASM_FLAGS} ${CMAKE_CXX_FLAGS} -x assembler-with-cpp") + endif () + endif () + set(source_files ${source_files} + image_loader/png_loader.cpp + image_saver/save_png.cpp + ) + endif () + + if (DLIB_JPEG_SUPPORT) + include(cmake_utils/find_libjpeg.cmake) + if (JPEG_FOUND) + list(APPEND dlib_needed_private_includes ${JPEG_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${JPEG_LIBRARY}) + else () + # If we can't find libjpeg then statically compile it in. + add_definitions(-DDLIB_JPEG_STATIC) + set(source_files ${source_files} + external/libjpeg/jaricom.c + external/libjpeg/jcapimin.c + external/libjpeg/jcapistd.c + external/libjpeg/jcarith.c + external/libjpeg/jccoefct.c + external/libjpeg/jccolor.c + external/libjpeg/jcdctmgr.c + external/libjpeg/jchuff.c + external/libjpeg/jcinit.c + external/libjpeg/jcmainct.c + external/libjpeg/jcmarker.c + external/libjpeg/jcmaster.c + external/libjpeg/jcomapi.c + external/libjpeg/jcparam.c + external/libjpeg/jcprepct.c + external/libjpeg/jcsample.c + external/libjpeg/jdapimin.c + external/libjpeg/jdapistd.c + external/libjpeg/jdarith.c + external/libjpeg/jdatadst.c + external/libjpeg/jdatasrc.c + external/libjpeg/jdcoefct.c + external/libjpeg/jdcolor.c + external/libjpeg/jddctmgr.c + external/libjpeg/jdhuff.c + external/libjpeg/jdinput.c + external/libjpeg/jdmainct.c + external/libjpeg/jdmarker.c + external/libjpeg/jdmaster.c + external/libjpeg/jdmerge.c + external/libjpeg/jdpostct.c + external/libjpeg/jdsample.c + external/libjpeg/jerror.c + external/libjpeg/jfdctflt.c + external/libjpeg/jfdctfst.c + external/libjpeg/jfdctint.c + external/libjpeg/jidctflt.c + external/libjpeg/jidctfst.c + external/libjpeg/jidctint.c + external/libjpeg/jmemmgr.c + external/libjpeg/jmemnobs.c + external/libjpeg/jquant1.c + external/libjpeg/jquant2.c + external/libjpeg/jutils.c + ) + endif () + set(source_files ${source_files} + image_loader/jpeg_loader.cpp + image_saver/save_jpeg.cpp + ) + endif () + if (DLIB_WEBP_SUPPORT) + include(cmake_utils/find_libwebp.cmake) + if (WEBP_FOUND) + list(APPEND dlib_needed_private_includes ${WEBP_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${WEBP_LIBRARY}) + set(source_files ${source_files} + image_loader/webp_loader.cpp + image_saver/save_webp.cpp + ) + else () + set(DLIB_WEBP_SUPPORT OFF CACHE BOOL ${DLIB_WEBP_SUPPORT_STR} FORCE) + toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) + endif () + endif () + + + if (DLIB_USE_BLAS OR DLIB_USE_LAPACK OR DLIB_USE_MKL_FFT) + if (DLIB_USE_MKL_WITH_TBB AND DLIB_USE_MKL_SEQUENTIAL) + set(DLIB_USE_MKL_SEQUENTIAL OFF CACHE STRING ${DLIB_USE_MKL_SEQUENTIAL_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_MKL_SEQUENTIAL) + message(STATUS "Disabling DLIB_USE_MKL_SEQUENTIAL. It cannot be used simultaneously with DLIB_USE_MKL_WITH_TBB.") + endif () + + + # Try to find BLAS, LAPACK and MKL + include(cmake_utils/find_blas.cmake) + + if (DLIB_USE_BLAS) + if (blas_found) + list(APPEND dlib_needed_private_libraries ${blas_libraries}) + else () + set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_BLAS) + endif () + endif () + + if (DLIB_USE_LAPACK) + if (lapack_found) + list(APPEND dlib_needed_private_libraries ${lapack_libraries}) + if (lapack_with_underscore) + set(LAPACK_FORCE_UNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_UNDERSCORE) + elseif (lapack_without_underscore) + set(LAPACK_FORCE_NOUNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_NOUNDERSCORE) + endif () + else () + set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + endif () + endif () + + if (DLIB_USE_MKL_FFT) + if (found_intel_mkl AND found_intel_mkl_headers) + list(APPEND dlib_needed_public_includes ${mkl_include_dir}) + list(APPEND dlib_needed_public_libraries ${mkl_libraries}) + else () + set(DLIB_USE_MKL_FFT OFF CACHE STRING ${DLIB_USE_MKL_FFT_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + endif () + endif () + endif () + + + if (DLIB_USE_CUDA) + find_package(CUDA 7.5) + + if (CUDA_VERSION VERSION_GREATER 9.1 AND CMAKE_VERSION VERSION_LESS 3.12.2) + # This bit of weirdness is to work around a bug in cmake + list(REMOVE_ITEM CUDA_CUBLAS_LIBRARIES "CUDA_cublas_device_LIBRARY-NOTFOUND") + endif () + + message("CUDA_VERSION: " ${CUDA_VERSION}) + + if (CUDA_FOUND AND MSVC AND NOT CUDA_CUBLAS_LIBRARIES AND "${CMAKE_SIZEOF_VOID_P}" EQUAL "4") + message(WARNING "You have CUDA installed, but we can't use it unless you put visual studio in 64bit mode.") + set(CUDA_FOUND 0) + endif () + message("11 CUDA_FOUND: " ${CUDA_VERSION}) + + + if (NOT CUDA_CUBLAS_LIBRARIES) + message(STATUS "Found CUDA, but CMake was unable to find the cuBLAS libraries that should be part of every basic CUDA " + "install. Your CUDA install is somehow broken or incomplete. Since cuBLAS is required for dlib to use CUDA we won't use CUDA.") + set(CUDA_FOUND 0) + endif () + message("22 CUDA_FOUND: " ${CUDA_VERSION}) + + if (CUDA_FOUND) + # There is some bug in cmake that causes it to mess up the + # -std=c++11 option if you let it propagate it to nvcc in some + # cases. So instead we disable this and manually include + # things from CMAKE_CXX_FLAGS in the CUDA_NVCC_FLAGS list below. + if (APPLE) + set(CUDA_PROPAGATE_HOST_FLAGS OFF) + # Grab all the -D flags from CMAKE_CXX_FLAGS so we can pass them + # to nvcc. + string(REGEX MATCHALL "-D[^ ]*" FLAGS_FOR_NVCC "${CMAKE_CXX_FLAGS}") + + # Check if we are being built as part of a pybind11 module. + if (COMMAND pybind11_add_module) + # Don't export unnecessary symbols. + list(APPEND FLAGS_FOR_NVCC "-Xcompiler=-fvisibility=hidden") + endif () + endif () + + set(CUDA_HOST_COMPILATION_CPP ON) + string(REPLACE "," ";" DLIB_CUDA_COMPUTE_CAPABILITIES ${DLIB_USE_CUDA_COMPUTE_CAPABILITIES}) + message("DLIB_CUDA_COMPUTE_CAPABILITIES: " ${DLIB_CUDA_COMPUTE_CAPABILITIES}) + + foreach (CAP ${DLIB_CUDA_COMPUTE_CAPABILITIES}) + list(APPEND CUDA_NVCC_FLAGS "-gencode arch=compute_${CAP},code=[sm_${CAP},compute_${CAP}]") + endforeach () + # Note that we add __STRICT_ANSI__ to avoid freaking out nvcc with gcc specific + # magic in the standard C++ header files (since nvcc uses gcc headers on linux). + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES;${FLAGS_FOR_NVCC}") + list(APPEND CUDA_NVCC_FLAGS ${active_preprocessor_switches}) + if (NOT DLIB_IN_PROJECT_BUILD) + LIST(APPEND CUDA_NVCC_FLAGS -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + endif () + if (NOT MSVC) + list(APPEND CUDA_NVCC_FLAGS "-std=c++14") + endif () + if (CMAKE_POSITION_INDEPENDENT_CODE) + # sometimes this setting isn't propagated to NVCC, which then causes the + # compile to fail. So make sure it's propagated. + if (NOT MSVC) # Visual studio doesn't have -fPIC so don't do it in that case. + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") + endif () + endif () + + message("CUDA_NVCC_FLAGS: " ${CUDA_NVCC_FLAGS}) + include(cmake_utils/test_for_cudnn/find_cudnn.txt) + + if (cudnn AND cudnn_include AND NOT DEFINED cuda_test_compile_worked AND NOT DEFINED cudnn_test_compile_worked) + # make sure cuda is really working by doing a test compile + message(STATUS "Building a CUDA test project to see if your compiler is compatible with CUDA...") + + set(CUDA_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + if (NOT MSVC) # see https://github.com/davisking/dlib/issues/363 + list(APPEND CUDA_TEST_CMAKE_FLAGS "-DCUDA_HOST_COMPILER=${CUDA_HOST_COMPILER}") + endif () + + try_compile(cuda_test_compile_worked + ${PROJECT_BINARY_DIR}/cuda_test_build + ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cuda cuda_test + CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS} + OUTPUT_VARIABLE try_compile_output_message + ) + if (NOT cuda_test_compile_worked) + string(REPLACE "\n" "\n *** " try_compile_output_message "${try_compile_output_message}") + message(STATUS "*****************************************************************************************************************") + message(STATUS "*** CUDA was found but your compiler failed to compile a simple CUDA program so dlib isn't going to use CUDA. ") + message(STATUS "*** The output of the failed CUDA test compile is shown below: ") + message(STATUS "*** ") + message(STATUS "*** ${try_compile_output_message}") + message(STATUS "*****************************************************************************************************************") + else () + message(STATUS "Building a cuDNN test project to check if you have the right version of cuDNN installed...") + try_compile(cudnn_test_compile_worked + ${PROJECT_BINARY_DIR}/cudnn_test_build + ${PROJECT_SOURCE_DIR}/cmake_utils/test_for_cudnn cudnn_test + CMAKE_FLAGS ${CUDA_TEST_CMAKE_FLAGS} + OUTPUT_VARIABLE try_compile_output_message + ) + if (NOT cudnn_test_compile_worked) + string(REPLACE "\n" "\n *** " try_compile_output_message "${try_compile_output_message}") + message(STATUS "*****************************************************************************************************") + message(STATUS "*** Found cuDNN, but we failed to compile the dlib/cmake_utils/test_for_cudnn project. ") + message(STATUS "*** You either have an unsupported version of cuDNN or something is wrong with your cudDNN install.") + message(STATUS "*** Since a functional cuDNN is not found DLIB WILL NOT USE CUDA. ") + message(STATUS "*** The output of the failed test_for_cudnn build is: ") + message(STATUS "*** ") + message(STATUS "*** ${try_compile_output_message}") + message(STATUS "*****************************************************************************************************") + endif () + endif () + endif () + + message("CUDA_cusolver_LIBRARY=" ${CUDA_cusolver_LIBRARY}) + # Find where cuSOLVER is since the FindCUDA cmake package doesn't + # bother to look for it in older versions of cmake. + if (NOT CUDA_cusolver_LIBRARY) + get_filename_component(cuda_blas_path "${CUDA_CUBLAS_LIBRARIES}" DIRECTORY) + find_library(CUDA_cusolver_LIBRARY cusolver HINTS ${cuda_blas_path}) + # CUDA 10.1 doesn't install symbolic links to libcusolver.so in + # the usual place. This is probably a bug in the cuda + # installer. In any case, If we haven't found cusolver yet go + # look in the cuda install folder for it. New versions of cmake + # do this correctly, but older versions need help. + if (NOT CUDA_cusolver_LIBRARY) + find_library(CUDA_cusolver_LIBRARY cusolver HINTS + /usr/local/cuda/lib64/ + ) + endif () + mark_as_advanced(CUDA_cusolver_LIBRARY) + endif () + # Also find OpenMP since cuSOLVER needs it. Importantly, we only + # look for one to link to if our use of BLAS, specifically the + # Intel MKL, hasn't already decided what to use. This is because + # it makes the MKL bug out if you link to another openmp lib other + # than Intel's when you use the MKL. I'm also not really sure when + # explicit linking to openmp became unnecessary, but for + # sufficiently older versions of cuda it was needed. Then in + # versions of cmake newer than 3.11 linking to openmp started to + # mess up the switches passed to nvcc, so you can't just leave + # these "try to link to openmp" statements here going forward. Fun + # times. + if (CUDA_VERSION VERSION_LESS "9.1" AND NOT openmp_libraries AND NOT MSVC AND NOT XCODE AND NOT APPLE) + find_package(OpenMP) + if (OPENMP_FOUND) + set(openmp_libraries ${OpenMP_CXX_FLAGS}) + else () + message(STATUS "*** Didn't find OpenMP, which is required to use CUDA. ***") + set(CUDA_FOUND 0) + endif () + endif () + endif () + + if (CUDA_FOUND AND cudnn AND cuda_test_compile_worked AND cudnn_test_compile_worked AND cudnn_include) + set(source_files ${source_files} + cuda/cuda_dlib.cu + cuda/cudnn_dlibapi.cpp + cuda/cublas_dlibapi.cpp + cuda/cusolver_dlibapi.cu + cuda/curand_dlibapi.cpp + cuda/cuda_data_ptr.cpp + cuda/gpu_data.cpp + ) + list(APPEND dlib_needed_private_libraries ${CUDA_CUBLAS_LIBRARIES}) + list(APPEND dlib_needed_private_libraries ${cudnn}) + list(APPEND dlib_needed_private_libraries ${CUDA_curand_LIBRARY}) + list(APPEND dlib_needed_private_libraries ${CUDA_cusolver_LIBRARY}) + list(APPEND dlib_needed_private_libraries ${CUDA_CUDART_LIBRARY}) + if (openmp_libraries) + list(APPEND dlib_needed_private_libraries ${openmp_libraries}) + endif () + + include_directories(${cudnn_include}) + message(STATUS "Enabling CUDA support for dlib. DLIB WILL USE CUDA, compute capabilities: ${DLIB_CUDA_COMPUTE_CAPABILITIES}") + else () + set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_CUDA) + if (NOT CUDA_FOUND) + message(STATUS "DID NOT FIND CUDA") + endif () + message(STATUS "Disabling CUDA support for dlib. DLIB WILL NOT USE CUDA") + endif () + endif () + + + if (DLIB_LINK_WITH_SQLITE3) + find_library(sqlite sqlite3) + # make sure sqlite3.h is in the include path + find_path(sqlite_path sqlite3.h) + if (sqlite AND sqlite_path) + list(APPEND dlib_needed_public_includes ${sqlite_path}) + list(APPEND dlib_needed_public_libraries ${sqlite}) + else () + set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE) + endif () + mark_as_advanced(sqlite sqlite_path) + endif () + + + if (DLIB_USE_FFTW) + find_library(fftw fftw3) + # make sure fftw3.h is in the include path + find_path(fftw_path fftw3.h) + if (fftw AND fftw_path) + list(APPEND dlib_needed_private_includes ${fftw_path}) + list(APPEND dlib_needed_private_libraries ${fftw}) + else () + set(DLIB_USE_FFTW OFF CACHE STRING ${DLIB_USE_FFTW_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_FFTW) + endif () + mark_as_advanced(fftw fftw_path) + endif () + + if (DLIB_USE_FFMPEG) + include(cmake_utils/find_ffmpeg.cmake) + if (FFMPEG_FOUND) + list(APPEND dlib_needed_public_includes ${FFMPEG_INCLUDE_DIRS}) + list(APPEND dlib_needed_public_libraries ${FFMPEG_LINK_LIBRARIES}) + list(APPEND dlib_needed_public_cflags ${FFMPEG_CFLAGS}) + list(APPEND dlib_needed_public_ldflags ${FFMPEG_LDFLAGS}) + enable_preprocessor_switch(DLIB_USE_FFMPEG) + else () + set(DLIB_USE_FFMPEG OFF CACHE BOOL ${DLIB_USE_FFMPEG_STR} FORCE) + disable_preprocessor_switch(DLIB_USE_FFMPEG) + endif () + endif () + + # Tell CMake to build dlib via add_library()/cuda_add_library() + message("source_files: ${source_files}") + if (DLIB_USE_CUDA) + message("dlib_needed_public_includes: " ${dlib_needed_public_includes}) + # The old cuda_add_library() command doesn't support CMake's newer dependency + # stuff, so we have to set the include path manually still, which we do here. + include_directories(${dlib_needed_public_includes}) + cuda_add_library(dlib ${source_files}) + else () + add_library(dlib ${source_files}) + endif () + + endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ########################################################## + + + target_include_directories(dlib + INTERFACE $ + INTERFACE $ + PUBLIC ${dlib_needed_public_includes} + PRIVATE ${dlib_needed_private_includes} + ) + target_link_libraries(dlib PUBLIC ${dlib_needed_public_libraries} ${dlib_needed_public_ldflags}) + target_link_libraries(dlib PRIVATE ${dlib_needed_private_libraries}) + target_compile_options(dlib PUBLIC ${dlib_needed_public_cflags}) + if (DLIB_IN_PROJECT_BUILD) + target_compile_options(dlib PUBLIC ${active_preprocessor_switches}) + else () + # These are private in this case because they will be controlled by the + # contents of dlib/config.h once it's installed. But for in project + # builds, there is no real config.h so they are public in the above case. + target_compile_options(dlib PRIVATE ${active_preprocessor_switches}) + # Do this so that dlib/config.h won't set DLIB_NOT_CONFIGURED. This will then allow + # the code in dlib/threads_kernel_shared.cpp to emit a linker error for users who + # don't use the configured config.h file generated by cmake. + target_compile_options(dlib PRIVATE -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + + # Do this so that dlib/config.h can record the version of dlib it's configured with + # and ultimately issue a linker error to people who try to use a binary dlib that is + # the wrong version. + set(DLIB_CHECK_FOR_VERSION_MISMATCH + DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_${CPACK_PACKAGE_VERSION_MAJOR}_${CPACK_PACKAGE_VERSION_MINOR}_${CPACK_PACKAGE_VERSION_PATCH}) + target_compile_options(dlib PRIVATE "-DDLIB_CHECK_FOR_VERSION_MISMATCH=${DLIB_CHECK_FOR_VERSION_MISMATCH}") + endif () + + + # Allow the unit tests to ask us to compile the all/source.cpp file just to make sure it compiles. + if (DLIB_TEST_COMPILE_ALL_SOURCE_CPP) + add_library(dlib_all_source_cpp STATIC all/source.cpp) + target_link_libraries(dlib_all_source_cpp dlib) + target_compile_options(dlib_all_source_cpp PUBLIC ${active_preprocessor_switches}) + target_compile_features(dlib_all_source_cpp PUBLIC cxx_std_14) + endif () + + target_compile_features(dlib PUBLIC cxx_std_14) + if ((MSVC AND CMAKE_VERSION VERSION_LESS 3.11)) + target_compile_options(dlib PUBLIC ${active_compile_opts}) + target_compile_options(dlib PRIVATE ${active_compile_opts_private}) + else () + target_compile_options(dlib PUBLIC $<$:${active_compile_opts}>) + target_compile_options(dlib PRIVATE $<$:${active_compile_opts_private}>) + endif () + + # Install the library + if (NOT DLIB_IN_PROJECT_BUILD) + string(REPLACE ";" " " pkg_config_dlib_needed_libraries "${dlib_needed_public_libraries}") + # Make the -I include options for pkg-config + foreach (ITR ${dlib_needed_public_includes}) + set(pkg_config_dlib_needed_includes "${pkg_config_dlib_needed_includes} -I${ITR}") + endforeach () + set_target_properties(dlib PROPERTIES + VERSION ${VERSION}) + install(TARGETS dlib + EXPORT dlib + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # Windows considers .dll to be runtime artifacts + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + + install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.cmake" + PATTERN "*_tutorial.txt" + PATTERN "cassert" + PATTERN "cstring" + PATTERN "fstream" + PATTERN "iomanip" + PATTERN "iosfwd" + PATTERN "iostream" + PATTERN "istream" + PATTERN "locale" + PATTERN "ostream" + PATTERN "sstream" + REGEX "${CMAKE_CURRENT_BINARY_DIR}" EXCLUDE) + + + configure_file(${PROJECT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/config.h) + # overwrite config.h with the configured one + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + configure_file(${PROJECT_SOURCE_DIR}/revision.h.in ${CMAKE_CURRENT_BINARY_DIR}/revision.h) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/revision.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + ## Config.cmake generation and installation + + set(ConfigPackageLocation "${CMAKE_INSTALL_LIBDIR}/cmake/dlib") + install(EXPORT dlib + NAMESPACE dlib:: + DESTINATION ${ConfigPackageLocation}) + + configure_file(cmake_utils/dlibConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" @ONLY) + + include(CMakePackageConfigHelpers) + write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + VERSION ${VERSION} + COMPATIBILITY AnyNewerVersion + ) + + install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + DESTINATION ${ConfigPackageLocation}) + + ## dlib-1.pc generation and installation + + configure_file("cmake_utils/dlib.pc.in" "dlib-1.pc" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/dlib-1.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + + # Add a cpack "package" target. This will create an archive containing + # the built library file, the header files, and cmake and pkgconfig + # configuration files. + include(CPack) + + endif () + +endif () + +if (MSVC) + # Give the output library files names that are unique functions of the + # visual studio mode that compiled them. We do this so that people who + # compile dlib and then copy the .lib files around (which they shouldn't be + # doing in the first place!) will hopefully be slightly less confused by + # what happens since, at the very least, the filenames will indicate what + # visual studio runtime they go with. + math(EXPR numbits ${CMAKE_SIZEOF_VOID_P}*8) + set_target_properties(dlib PROPERTIES DEBUG_POSTFIX "${VERSION}_debug_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELEASE_POSTFIX "${VERSION}_release_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES MINSIZEREL_POSTFIX "${VERSION}_minsizerel_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELWITHDEBINFO_POSTFIX "${VERSION}_relwithdebinfo_${numbits}bit_msvc${MSVC_VERSION}") +endif () + +# Check if we are being built as part of a pybind11 module. +if (COMMAND pybind11_add_module) + # Don't export unnecessary symbols. + set_target_properties(dlib PROPERTIES CXX_VISIBILITY_PRESET "hidden") + set_target_properties(dlib PROPERTIES CUDA_VISIBILITY_PRESET "hidden") +endif () + +if (WIN32 AND mkl_iomp_dll) + # If we are using the Intel MKL on windows then try and copy the iomp dll + # file to the output folder. We do this since a very large number of + # windows users don't understand that they need to add the Intel MKL's + # folders to their PATH to use the Intel MKL. They then complain on the + # dlib forums. Copying the Intel MKL dlls to the output directory removes + # the need to add the Intel MKL to the PATH. + if (CMAKE_LIBRARY_OUTPUT_DIRECTORY) + add_custom_command(TARGET dlib POST_BUILD + # In some newer versions of windows/visual studio the output Config folder doesn't + # exist at first, so you can't copy to it unless you make it yourself. So make + # sure the target folder exists first. + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" + COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" + ) + else () + add_custom_command(TARGET dlib POST_BUILD + # In some newer versions of windows/visual studio the output Config folder doesn't + # exist at first, so you can't copy to it unless you make it yourself. So make + # sure the target folder exists first. + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/$/" + COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_BINARY_DIR}/$/" + ) + endif () +endif () + +add_library(dlib::dlib ALIAS dlib) diff --git a/dlib/LICENSE.txt b/dlib/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..127a5bc39ba030c7cb99cc0aedc4f280ffe27310 --- /dev/null +++ b/dlib/LICENSE.txt @@ -0,0 +1,23 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/dlib/algs.h b/dlib/algs.h new file mode 100644 index 0000000000000000000000000000000000000000..faa5ca1fe884247438a7bce11c7db73dc0386fbf --- /dev/null +++ b/dlib/algs.h @@ -0,0 +1,919 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_ALGs_ +#define DLIB_ALGs_ + +// this file contains miscellaneous stuff + +// Give people who forget the -std=c++14 option a reminder +#if (defined(__GNUC__) && ((__GNUC__ >= 5 && __GNUC_MINOR__ >= 0) || (__GNUC__ > 5))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3))) + #if __cplusplus < 201402L + #error "Dlib requires C++14 support. Give your compiler the -std=c++14 option to enable it." + #endif +#endif + +#if defined __NVCC__ + // Disable the "statement is unreachable" message since it will go off on code that is + // actually reachable but just happens to not be reachable sometimes during certain + // template instantiations. + #ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ + #pragma nv_diag_suppress code_is_unreachable + #else + #pragma diag_suppress code_is_unreachable + #endif +#endif + + +#ifdef _MSC_VER + +#if _MSC_VER < 1900 +#error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer." +#endif + +// Disable the following warnings for Visual Studio + +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) + +// This is a warning you get sometimes when Visual Studio performs a Koenig Lookup. +// This is a bug in visual studio. It is a totally legitimate thing to +// expect from a compiler. +#pragma warning(disable : 4675) + +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) + +// This is a warning you get from visual studio 2003: +// warning C4345: behavior change: an object of POD type constructed with an initializer +// of the form () will be default-initialized. +// I love it when this compiler gives warnings about bugs in previous versions of itself. +#pragma warning(disable : 4345) + + +// Disable warnings about conversion from size_t to unsigned long and long. +#pragma warning(disable : 4267) + +// Disable warnings about conversion from double to float +#pragma warning(disable : 4244) +#pragma warning(disable : 4305) + +// Disable "warning C4180: qualifier applied to function type has no meaning; ignored". +// This warning happens often in generic code that works with functions and isn't useful. +#pragma warning(disable : 4180) + +// Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)" +#pragma warning(disable : 4290) + + +// DNN module uses template-based network declaration that leads to very long +// type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says +// that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occur could be confusing. +#pragma warning( disable: 4503 ) + + +#endif + +#ifdef __BORLANDC__ +// Disable the following warnings for the Borland Compilers +// +// These warnings just say that the compiler is refusing to inline functions with +// loops or try blocks in them. +// +#pragma option -w-8027 +#pragma option -w-8026 +#endif + +#include // for the exceptions + +#ifdef __CYGWIN__ +namespace std +{ + typedef std::basic_string wstring; +} +#endif + +#include "platform.h" +#include "windows_magic.h" + + +#include // for std::swap +#include // for std::bad_alloc +#include +#include +#include +#include // for std::isfinite for is_finite() +#include "assert.h" +#include "error.h" +#include "noncopyable.h" +#include "enable_if.h" +#include "uintn.h" +#include "numeric_constants.h" +#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" // for the default memory manager +#include "type_traits.h" + +// ---------------------------------------------------------------------------------------- +/*!A _dT !*/ + +template +inline charT _dTcast (const char a, const wchar_t b); +template <> +inline char _dTcast (const char a, const wchar_t ) { return a; } +template <> +inline wchar_t _dTcast (const char , const wchar_t b) { return b; } + +template +inline const charT* _dTcast ( const char* a, const wchar_t* b); +template <> +inline const char* _dTcast ( const char* a, const wchar_t* ) { return a; } +template <> +inline const wchar_t* _dTcast ( const char* , const wchar_t* b) { return b; } + + +#define _dT(charT,str) _dTcast(str,L##str) +/*! + requires + - charT == char or wchar_t + - str == a string or character literal + ensures + - returns the literal in the form of a charT type literal. +!*/ + +// ---------------------------------------------------------------------------------------- + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*!A default_memory_manager + + This memory manager just calls new and delete directly. + + !*/ + typedef memory_manager_stateless_kernel_1 default_memory_manager; + +// ---------------------------------------------------------------------------------------- + + /*!A swap !*/ + // make swap available in the dlib namespace + using std::swap; + +// ---------------------------------------------------------------------------------------- + + /*! + Here is where I define my return codes. It is + important that they all be < 0. + !*/ + + enum general_return_codes + { + TIMEOUT = -1, + WOULDBLOCK = -2, + OTHER_ERROR = -3, + SHUTDOWN = -4, + PORTINUSE = -5 + }; + +// ---------------------------------------------------------------------------------------- + + inline unsigned long square_root ( + unsigned long value + ) + /*! + requires + - value <= 2^32 - 1 + ensures + - returns the square root of value. if the square root is not an + integer then it will be rounded up to the nearest integer. + !*/ + { + unsigned long x; + + // set the initial guess for what the root is depending on + // how big value is + if (value < 3) + return value; + else if (value < 4096) // 12 + x = 45; + else if (value < 65536) // 16 + x = 179; + else if (value < 1048576) // 20 + x = 717; + else if (value < 16777216) // 24 + x = 2867; + else if (value < 268435456) // 28 + x = 11469; + else // 32 + x = 45875; + + + + // find the root + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + + + + if (x*x < value) + return x+1; + else + return x; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ); + /*! + requires + - T implements operator< + - T is swappable by a global swap() + ensures + - #one is the median + - #one, #two, and #three is some permutation of one, two, and three. + !*/ + + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ) + { + using std::swap; + using dlib::swap; + + if ( one < two ) + { + // one < two + if ( two < three ) + { + // one < two < three : two + swap(one,two); + + } + else + { + // one < two >= three + if ( one < three) + { + // three + swap(three,one); + } + } + + } + else + { + // one >= two + if ( three < one ) + { + // three <= one >= two + if ( three < two ) + { + // two + swap(two,one); + } + else + { + // three + swap(three,one); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace relational_operators + { + template < + typename A, + typename B + > + constexpr bool operator> ( + const A& a, + const B& b + ) { return b < a; } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator!= ( + const A& a, + const B& b + ) { return !(a == b); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator<= ( + const A& a, + const B& b + ) { return !(b < a); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator>= ( + const A& a, + const B& b + ) { return !(a < b); } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void exchange ( + T& a, + T& b + ) + /*! + This function does the exact same thing that global swap does and it does it by + just calling swap. But a lot of compilers have problems doing a Koenig Lookup + and the fact that this has a different name (global swap has the same name as + the member functions called swap) makes them compile right. + + So this is a workaround but not too ugly of one. But hopefully I can get + rid of this in a few years. So this function is already deprecated. + + This also means you should NOT use this function in your own code unless + you have to support an old buggy compiler that benefits from this hack. + !*/ + { + using std::swap; + using dlib::swap; + swap(a,b); + } + +// ---------------------------------------------------------------------------------------- + + struct general_ {}; + struct special_ : general_ {}; + template struct int_ { typedef int type; }; + +// ---------------------------------------------------------------------------------------- + + + /*!A is_same_object + + This is a templated function which checks if both of its arguments are actually + references to the same object. It returns true if they are and false otherwise. + + !*/ + + // handle the case where T and U are unrelated types. + template < typename T, typename U > + std::enable_if_t::value && !std::is_convertible::value, bool> + is_same_object ( + const T& a, + const U& b + ) + { + return ((void*)&a == (void*)&b); + } + + // handle the case where T and U are related types because their pointers can be + // implicitly converted into one or the other. E.g. a derived class and its base class. + // Or where both T and U are just the same type. This way we make sure that if there is a + // valid way to convert between these two pointer types then we will take that route rather + // than the void* approach used otherwise. + template < typename T, typename U > + std::enable_if_t::value || std::is_convertible::value, bool> + is_same_object ( + const T& a, + const U& b + ) + { + return (&a == &b); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class copy_functor + { + public: + void operator() ( + const T& source, + T& destination + ) const + { + destination = source; + } + }; + +// ---------------------------------------------------------------------------------------- + + /*!A static_switch + + To use this template you give it some number of boolean expressions and it + tells you which one of them is true. If more than one of them is true then + it causes a compile time error. + + for example: + static_switch<1 + 1 == 2, 4 - 1 == 4>::value == 1 // because the first expression is true + static_switch<1 + 1 == 3, 4 == 4>::value == 2 // because the second expression is true + static_switch<1 + 1 == 3, 4 == 5>::value == 0 // 0 here because none of them are true + static_switch<1 + 1 == 2, 4 == 4>::value == compiler error // because more than one expression is true + !*/ + + template < bool v1 = 0, bool v2 = 0, bool v3 = 0, bool v4 = 0, bool v5 = 0, + bool v6 = 0, bool v7 = 0, bool v8 = 0, bool v9 = 0, bool v10 = 0, + bool v11 = 0, bool v12 = 0, bool v13 = 0, bool v14 = 0, bool v15 = 0 > + struct static_switch; + + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 0; }; + template <> struct static_switch<1,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 1; }; + template <> struct static_switch<0,1,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 2; }; + template <> struct static_switch<0,0,1,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 3; }; + template <> struct static_switch<0,0,0,1,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 4; }; + template <> struct static_switch<0,0,0,0,1,0,0,0,0,0,0,0,0,0,0> { const static int value = 5; }; + template <> struct static_switch<0,0,0,0,0,1,0,0,0,0,0,0,0,0,0> { const static int value = 6; }; + template <> struct static_switch<0,0,0,0,0,0,1,0,0,0,0,0,0,0,0> { const static int value = 7; }; + template <> struct static_switch<0,0,0,0,0,0,0,1,0,0,0,0,0,0,0> { const static int value = 8; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,1,0,0,0,0,0,0> { const static int value = 9; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,1,0,0,0,0,0> { const static int value = 10; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,1,0,0,0,0> { const static int value = 11; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,1,0,0,0> { const static int value = 12; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,1,0,0> { const static int value = 13; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,1,0> { const static int value = 14; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,1> { const static int value = 15; }; + +// ---------------------------------------------------------------------------------------- + + template + std::enable_if_t::value, bool> is_finite(T value) + /*! + requires + - value must be some kind of scalar type such as int or double + ensures + - returns true if value is a finite value (e.g. not infinity or NaN) and false + otherwise. + !*/ + { + return std::isfinite(value); + } + + template + std::enable_if_t::value, bool> is_finite(T value) + { + return std::isfinite((double)value); + } + +// ---------------------------------------------------------------------------------------- + + /*!A promote + + This is a template that takes one of the built in scalar types and gives you another + scalar type that should be big enough to hold sums of values from the original scalar + type. The new scalar type will also always be signed. + + For example, promote::type == int32 + !*/ + + template struct promote; + template struct promote { typedef int32 type; }; + template struct promote { typedef int32 type; }; + template struct promote { typedef int64 type; }; + template struct promote { typedef int64 type; }; + + template <> struct promote { typedef double type; }; + template <> struct promote { typedef double type; }; + template <> struct promote { typedef long double type; }; + +// ---------------------------------------------------------------------------------------- + + /*!A assign_zero_if_built_in_scalar_type + + This function assigns its argument the value of 0 if it is a built in scalar + type according to the is_built_in_scalar_type<> template. If it isn't a + built in scalar type then it does nothing. + !*/ + + template inline typename disable_if,void>::type assign_zero_if_built_in_scalar_type (T&){} + template inline typename enable_if,void>::type assign_zero_if_built_in_scalar_type (T& a){a=0;} + +// ---------------------------------------------------------------------------------------- + + template + T put_in_range ( + const T& a, + const T& b, + const T& val + ) + /*! + requires + - T is a type that looks like double, float, int, or so forth + ensures + - if (val is within the range [a,b]) then + - returns val + - else + - returns the end of the range [a,b] that is closest to val + !*/ + { + if (a < b) + { + if (val < a) + return a; + else if (val > b) + return b; + } + else + { + if (val < b) + return b; + else if (val > a) + return a; + } + + return val; + } + + // overload for double + inline double put_in_range(const double& a, const double& b, const double& val) + { return put_in_range(a,b,val); } + +// ---------------------------------------------------------------------------------------- + + /*!A tabs + + This is a template to compute the absolute value a number at compile time. + + For example, + abs<-4>::value == 4 + abs<4>::value == 4 + !*/ + + template + struct tabs { const static long value = x; }; + template + struct tabs::type> { const static long value = -x; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmax + + This is a template to compute the max of two values at compile time + + For example, + abs<4,7>::value == 7 + !*/ + + template + struct tmax { const static long value = x; }; + template + struct tmax x)>::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmin + + This is a template to compute the min of two values at compile time + + For example, + abs<4,7>::value == 4 + !*/ + + template + struct tmin { const static long value = x; }; + template + struct tmin::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + +#define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \ + struct _two_bytes_##testname { char a[2]; }; \ + template < typename T, returnT (T::*funct)args > \ + struct _helper_##testname { typedef char type; }; \ + template \ + static char _has_##testname##_helper( typename _helper_##testname::type ) { return 0;} \ + template \ + static _two_bytes_##testname _has_##testname##_helper(int) { return _two_bytes_##testname();} \ + template struct _##testname##workaroundbug { \ + const static unsigned long U = sizeof(_has_##testname##_helper('a')); }; \ + template ::U > \ + struct testname { static const bool value = false; }; \ + template \ + struct testname { static const bool value = true; }; + /*!A DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST + + The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates + that tell you if a class has a certain member function. For example, to make a + test to see if a class has a public method with the signature void print(int) you + would say: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int)) + + Then you can check if a class, T, has this method by looking at the boolean value: + has_print::value + which will be true if the member function is in the T class. + + Note that you can test for member functions taking no arguments by simply passing + in empty () like so: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()) + This would test for a member of the form: + void print(). + + To test for const member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()const) + This would test for a member of the form: + void print() const. + + To test for const templated member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, template print, ()) + This would test for a member of the form: + template void print(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template class funct_wrap0 + { + public: + funct_wrap0(T (&f_)()):f(f_){} + T operator()() const { return f(); } + private: + T (&f)(); + }; + template class funct_wrap1 + { + public: + funct_wrap1(T (&f_)(A0)):f(f_){} + T operator()(A0 a0) const { return f(a0); } + private: + T (&f)(A0); + }; + template class funct_wrap2 + { + public: + funct_wrap2(T (&f_)(A0,A1)):f(f_){} + T operator()(A0 a0, A1 a1) const { return f(a0,a1); } + private: + T (&f)(A0,A1); + }; + template class funct_wrap3 + { + public: + funct_wrap3(T (&f_)(A0,A1,A2)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2) const { return f(a0,a1,a2); } + private: + T (&f)(A0,A1,A2); + }; + template class funct_wrap4 + { + public: + funct_wrap4(T (&f_)(A0,A1,A2,A3)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3) const { return f(a0,a1,a2,a3); } + private: + T (&f)(A0,A1,A2,A3); + }; + template class funct_wrap5 + { + public: + funct_wrap5(T (&f_)(A0,A1,A2,A3,A4)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3, A4 a4) const { return f(a0,a1,a2,a3,a4); } + private: + T (&f)(A0,A1,A2,A3,A4); + }; + + /*!A wrap_function + + This is a template that allows you to turn a global function into a + function object. The reason for this template's existence is so you can + do stuff like this: + + template + void call_funct(const T& funct) + { cout << funct(); } + + std::string test() { return "asdfasf"; } + + int main() + { + call_funct(wrap_function(test)); + } + + The above code doesn't work right on some compilers if you don't + use wrap_function. + !*/ + + template + funct_wrap0 wrap_function(T (&f)()) { return funct_wrap0(f); } + template + funct_wrap1 wrap_function(T (&f)(A0)) { return funct_wrap1(f); } + template + funct_wrap2 wrap_function(T (&f)(A0, A1)) { return funct_wrap2(f); } + template + funct_wrap3 wrap_function(T (&f)(A0, A1, A2)) { return funct_wrap3(f); } + template + funct_wrap4 wrap_function(T (&f)(A0, A1, A2, A3)) { return funct_wrap4(f); } + template + funct_wrap5 wrap_function(T (&f)(A0, A1, A2, A3, A4)) { return funct_wrap5(f); } + +// ---------------------------------------------------------------------------------------- + + template + class stack_based_memory_block : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple container for a block of memory + of bSIZE bytes. This memory block is located on the stack + and properly aligned to hold any kind of object. + !*/ + public: + static const unsigned long size = bSIZE; + + stack_based_memory_block(): data(mem.data) {} + + void* get () { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + const void* get () const { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + private: + + // You obviously can't have a block of memory that has zero bytes in it. + COMPILE_TIME_ASSERT(bSIZE > 0); + + union mem_block + { + // All of this garbage is to make sure this union is properly aligned + // (a union is always aligned such that everything in it would be properly + // aligned. So the assumption here is that one of these objects has + // a large enough alignment requirement to satisfy any object this + // block of memory might be cast into). + void* void_ptr; + int integer; + struct { + void (stack_based_memory_block::*callback)(); + stack_based_memory_block* o; + } stuff; + long double more_stuff; + + uint64 var1; + uint32 var2; + double var3; + + char data[size]; + } mem; + + // The reason for having this variable is that doing it this way avoids + // warnings from gcc about violations of strict-aliasing rules. + void* const data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto max_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the largest score, + according to score_func(), and returns a std::pair containing that maximal + element along with the score. + - If the container is empty then make_pair(a default initialized object, -infinity) is returned. + !*/ + { + double best_score = -std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score > best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto min_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the smallest score, + according to score_func(), and returns a std::pair containing that minimal + element along with the score. + - If the container is empty then make_pair(a default initialized object, infinity) is returned. + !*/ + { + double best_score = std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score < best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + + namespace detail + { + template + constexpr void for_each_impl(Tuple&& t, F&& f, std::index_sequence) + { +#ifdef __cpp_fold_expressions + (std::forward(f)(std::get(std::forward(t))),...); +#else + (void)std::initializer_list{(std::forward(f)(std::get(std::forward(t))),0)...}; +#endif + } + } + + template + constexpr void for_each_in_tuple(Tuple&& t, F&& f) + { + detail::for_each_impl(std::forward(t), std::forward(f), + std::make_index_sequence>::value>{}); + } + +// ---------------------------------------------------------------------------------------- +} + +#endif // DLIB_ALGs_ + diff --git a/dlib/all/source.cpp b/dlib/all/source.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04756d19cddb624a052513ff40617513b490702c --- /dev/null +++ b/dlib/all/source.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ALL_SOURCe_ +#define DLIB_ALL_SOURCe_ + +#if defined(DLIB_ALGs_) || defined(DLIB_PLATFORm_) +#include "../dlib_basic_cpp_build_tutorial.txt" +#endif + +// ISO C++ code +#include "../base64/base64_kernel_1.cpp" +#include "../bigint/bigint_kernel_1.cpp" +#include "../bigint/bigint_kernel_2.cpp" +#include "../bit_stream/bit_stream_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_2.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_1.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_2.cpp" +#include "../md5/md5_kernel_1.cpp" +#include "../tokenizer/tokenizer_kernel_1.cpp" +#include "../unicode/unicode.cpp" +#include "../test_for_odr_violations.cpp" + + + + +#ifndef DLIB_ISO_CPP_ONLY +// Code that depends on OS specific APIs + +// include this first so that it can disable the older version +// of the winsock API when compiled in windows. +#include "../sockets/sockets_kernel_1.cpp" +#include "../bsp/bsp.cpp" + +#include "../dir_nav/dir_nav_kernel_1.cpp" +#include "../dir_nav/dir_nav_kernel_2.cpp" +#include "../dir_nav/dir_nav_extensions.cpp" +#include "../linker/linker_kernel_1.cpp" +#include "../logger/extra_logger_headers.cpp" +#include "../logger/logger_kernel_1.cpp" +#include "../logger/logger_config_file.cpp" +#include "../misc_api/misc_api_kernel_1.cpp" +#include "../misc_api/misc_api_kernel_2.cpp" +#include "../sockets/sockets_extensions.cpp" +#include "../sockets/sockets_kernel_2.cpp" +#include "../sockstreambuf/sockstreambuf.cpp" +#include "../sockstreambuf/sockstreambuf_unbuffered.cpp" +#include "../server/server_kernel.cpp" +#include "../server/server_iostream.cpp" +#include "../server/server_http.cpp" +#include "../threads/multithreaded_object_extension.cpp" +#include "../threads/threaded_object_extension.cpp" +#include "../threads/threads_kernel_1.cpp" +#include "../threads/threads_kernel_2.cpp" +#include "../threads/threads_kernel_shared.cpp" +#include "../threads/thread_pool_extension.cpp" +#include "../threads/async.cpp" +#include "../timer/timer.cpp" +#include "../stack_trace.cpp" + +#ifdef DLIB_PNG_SUPPORT +#include "../image_loader/png_loader.cpp" +#include "../image_saver/save_png.cpp" +#endif + +#ifdef DLIB_JPEG_SUPPORT +#include "../image_loader/jpeg_loader.cpp" +#include "../image_saver/save_jpeg.cpp" +#endif + +#ifndef DLIB_NO_GUI_SUPPORT +#include "../gui_widgets/fonts.cpp" +#include "../gui_widgets/widgets.cpp" +#include "../gui_widgets/drawable.cpp" +#include "../gui_widgets/canvas_drawing.cpp" +#include "../gui_widgets/style.cpp" +#include "../gui_widgets/base_widgets.cpp" +#include "../gui_core/gui_core_kernel_1.cpp" +#include "../gui_core/gui_core_kernel_2.cpp" +#endif // DLIB_NO_GUI_SUPPORT + +#include "../cuda/cpu_dlib.cpp" +#include "../cuda/tensor_tools.cpp" +#include "../data_io/image_dataset_metadata.cpp" +#include "../data_io/mnist.cpp" +#include "../data_io/cifar.cpp" +#include "../svm/auto.cpp" +#include "../global_optimization/global_function_search.cpp" +#include "../filtering/kalman_filter.cpp" + +#endif // DLIB_ISO_CPP_ONLY + + + + + +#define DLIB_ALL_SOURCE_END + +#endif // DLIB_ALL_SOURCe_ + diff --git a/dlib/any.h b/dlib/any.h new file mode 100644 index 0000000000000000000000000000000000000000..01f047066782b037bbcbb58c6212bdb32fec1007 --- /dev/null +++ b/dlib/any.h @@ -0,0 +1,13 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_ +#define DLIB_AnY_ + +#include "any/any.h" +#include "any/any_trainer.h" +#include "any/any_decision_function.h" +#include "any/any_function.h" + +#endif // DLIB_AnY_ + + diff --git a/dlib/any/any.h b/dlib/any/any.h new file mode 100644 index 0000000000000000000000000000000000000000..f2db32bcaf43a5c2a465ec6ff17f1441f1949cc1 --- /dev/null +++ b/dlib/any/any.h @@ -0,0 +1,72 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_H_ +#define DLIB_AnY_H_ + +#include "any_abstract.h" +#include +#include "storage.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + class any + { + public: + any() = default; + any(const any& other) = default; + any& operator=(const any& other) = default; + any(any&& other) = default; + any& operator=(any&& other) = default; + + template< + typename T, + std::enable_if_t, any>::value, bool> = true + > + any(T&& item) + : storage{std::forward(item)} + { + } + + template< + typename T, + typename T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any& operator=(T&& item) + { + if (contains()) + storage.unsafe_get() = std::forward(item); + else + *this = std::move(any{std::forward(item)}); + return *this; + } + + bool is_empty() const { return storage.is_empty(); } + void clear() { storage.clear(); } + void swap (any& item) { std::swap(*this, item); } + + template bool contains() const { return storage.contains();} + template T& cast_to() { return storage.cast_to(); } + template const T& cast_to() const { return storage.cast_to(); } + template T& get() { return storage.get(); } + + private: + te::storage_heap storage; + }; + +// ---------------------------------------------------------------------------------------- + + template T& any_cast(any& a) { return a.cast_to(); } + template const T& any_cast(const any& a) { return a.cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_H_ + + + diff --git a/dlib/any/any_abstract.h b/dlib/any/any_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..3dcc069e1c28cbc4fa2bb6a22fb95eb53af76e2d --- /dev/null +++ b/dlib/any/any_abstract.h @@ -0,0 +1,209 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_ABSTRACT_H_ +#ifdef DLIB_AnY_ABSTRACT_H_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the exception class used by the any object. + It is used to indicate when someone attempts to cast an any + object into a type which isn't contained in the any object. + !*/ + + public: + virtual const char* what() const throw() { return "bad_any_cast"; } + }; + +// ---------------------------------------------------------------------------------------- + + class any + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is basically a type-safe version of a void*. In particular, + it is a container which can contain only one object but the object may + be of any type. + + It is somewhat like the type_safe_union except you don't have to declare + the set of possible content types beforehand. So in some sense this is + like a less type-strict version of the type_safe_union. + !*/ + + public: + + any( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_function ( + any_function&& item + ); + /*! + ensures + - #item.is_empty() == true + - moves item into *this. + !*/ + + template < typename T > + any ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any& operator= ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any& item + ); + /*! + ensures + - swaps *this and item + - does not invalidate pointers or references to the object contained + inside *this or item. Moreover, a pointer or reference to the object in + *this will now refer to the contents of #item and vice versa. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T& any_cast( + any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const T& any_cast( + const any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_ABSTRACT_H_ + + diff --git a/dlib/any/any_decision_function.h b/dlib/any/any_decision_function.h new file mode 100644 index 0000000000000000000000000000000000000000..27eeb4ff35d94b68635c8d3777b622994fdd0c56 --- /dev/null +++ b/dlib/any/any_decision_function.h @@ -0,0 +1,23 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_DECISION_FUNCTION_Hh_ +#define DLIB_AnY_DECISION_FUNCTION_Hh_ + +#include "any_decision_function_abstract.h" +#include "any_function.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + using any_decision_function = any_function; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_DECISION_FUNCTION_Hh_ diff --git a/dlib/any/any_decision_function_abstract.h b/dlib/any/any_decision_function_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..756690ccfa15eba7ca5d878d3afc803ba33b6c6f --- /dev/null +++ b/dlib/any/any_decision_function_abstract.h @@ -0,0 +1,24 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + +#include "any_function_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + using any_decision_function = any_function; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + + + diff --git a/dlib/any/any_function.h b/dlib/any/any_function.h new file mode 100644 index 0000000000000000000000000000000000000000..b6c9a590f8f505be5dec6ebeff376bd900497411 --- /dev/null +++ b/dlib/any/any_function.h @@ -0,0 +1,126 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_FUNCTION_Hh_ +#define DLIB_AnY_FUNCTION_Hh_ + +#include "../assert.h" +#include "../functional.h" +#include "any.h" +#include "any_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + class Storage, + class F + > + class any_function_basic; + + template < + class Storage, + class R, + class... Args + > + class any_function_basic + { + private: + template + using is_valid = std::enable_if_t, any_function_basic>::value && + dlib::is_invocable_r::value, + bool>; + + template + static auto make_invoker() + { + return [](void* self, Args... args) -> R { + return dlib::invoke(*reinterpret_cast>(self), + std::forward(args)...); + }; + } + + Storage str; + R (*func)(void*, Args...) = nullptr; + + public: + + using result_type = R; + + constexpr any_function_basic(std::nullptr_t) noexcept {} + constexpr any_function_basic() = default; + constexpr any_function_basic(const any_function_basic& other) = default; + constexpr any_function_basic& operator=(const any_function_basic& other) = default; + + constexpr any_function_basic(any_function_basic&& other) + : str{std::move(other.str)}, + func{std::exchange(other.func, nullptr)} + { + } + + constexpr any_function_basic& operator=(any_function_basic&& other) + { + if (this != &other) + { + str = std::move(other.str); + func = std::exchange(other.func, nullptr); + } + + return *this; + } + + template = true> + any_function_basic( + F&& f + ) : str{std::forward(f)}, + func{make_invoker()} + { + } + + template = true> + any_function_basic( + F* f + ) : str{f}, + func{make_invoker()} + { + } + + R operator()(Args... args) const { + return func(const_cast(str.get_ptr()), std::forward(args)...); + } + + void clear() { str.clear(); } + void swap (any_function_basic& item) { std::swap(*this, item); } + bool is_empty() const noexcept { return str.is_empty() || func == nullptr; } + bool is_set() const noexcept { return !is_empty(); } + explicit operator bool() const noexcept { return is_set(); } + + template bool contains() const { return str.template contains();} + template T& cast_to() { return str.template cast_to(); } + template const T& cast_to() const { return str.template cast_to(); } + template T& get() { return str.template get(); } + }; + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_function_basic& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_function_basic& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + + template + using any_function = any_function_basic, F>; + + template + using any_function_view = any_function_basic; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_Hh_ + diff --git a/dlib/any/any_function_abstract.h b/dlib/any/any_function_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..8c0da073659a394a6bb2ae3d4e69ccf54af6b099 --- /dev/null +++ b/dlib/any/any_function_abstract.h @@ -0,0 +1,274 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_FUNCTION_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + class Storage, + typename function_type + > + class any_function_basic + { + /*! + REQUIREMENTS ON Storage + This must be one of the storage types from dlib/any/storage.hh + E.g. storage_heap, storage_stack, etc. + + It determines the method by which any_function_basic holds onto the function it uses. + + REQUIREMENTS ON function_type + This type should be a function signature. Some examples are: + void (int,int) // a function returning nothing and taking two ints + void () // a function returning nothing and taking no arguments + char (double&) // a function returning a char and taking a reference to a double + + The number of arguments in the function must be no greater than 10. + + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function object with an operator() which + matches the function signature defined by function_type. + + + Here is an example: + #include + #include + #include "dlib/any.h" + using namespace std; + void print_message(string str) { cout << str << endl; } + + int main() + { + dlib::any_function f; + f = print_message; + f("hello world"); // calls print_message("hello world") + } + + Note that any_function_basic objects can be used to store general function + objects (i.e. defined by a class with an overloaded operator()) in + addition to regular global functions. + !*/ + + public: + + // This is the type of object returned by function_type functions. + typedef result_type_for_function_type result_type; + + any_function_basic( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_function_basic ( + const any_function_basic& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_function_basic ( + any_function_basic&& item + ); + /*! + ensures + - moves item into *this. + - The exact move semantics are determined by which Storage type is used. E.g. + storage_heap will result in #item.is_empty()==true but storage_view would result + in #item.is_empty() == false + !*/ + + template < typename Funct > + any_function_basic ( + Funct&& funct + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. calling operator() will invoke funct()) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + bool is_set ( + ) const; + /*! + ensures + - returns !is_empty() + !*/ + + explicit operator bool( + ) const; + /*! + ensures + - returns is_set() + !*/ + + result_type operator(Args... args) ( + ) const; + /*! + requires + - is_empty() == false + - the signature defined by function_type takes no arguments + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F(std::forward(args)...) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_function_basic object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_function_basic& operator= ( + const any_function_basic& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that the type of copy is determined by the Storage template argument. E.g. + storage_sbo will result in a deep copy, while storage_view would result in *this + and item referring to the same underlying function. + !*/ + + void swap ( + any_function_basic& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + T& any_cast( + any_function_basic& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + const T& any_cast( + const any_function_basic& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A any_function + + A version of any_function_basic (defined above) that owns the function it contains. Uses + the small buffer optimization to make working with small lambdas faster. + !*/ + template + using any_function = any_function_basic, F>; + + /*!A any_function_view + + A version of any_function_basic (defined above) that *DOES NOT* own the function it + contains. It merely holds a pointer to the function given to its constructor. + !*/ + template + using any_function_view = any_function_basic; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_ABSTRACT_H_ + diff --git a/dlib/any/any_trainer.h b/dlib/any/any_trainer.h new file mode 100644 index 0000000000000000000000000000000000000000..cda7e02890d687d3a4b2d323729ebc158b430c10 --- /dev/null +++ b/dlib/any/any_trainer.h @@ -0,0 +1,119 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_TRAINER_H_ +#define DLIB_AnY_TRAINER_H_ + +#include "any.h" +#include "any_decision_function.h" +#include "any_trainer_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + public: + using sample_type = sample_type_; + using scalar_type = scalar_type_; + using mem_manager_type = default_memory_manager; + using trained_function_type = any_decision_function; + + any_trainer() = default; + any_trainer(const any_trainer& other) = default; + any_trainer& operator=(const any_trainer& other) = default; + any_trainer(any_trainer&& other) = default; + any_trainer& operator=(any_trainer&& other) = default; + + template < + class T, + class T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any_trainer ( + T&& item + ) : storage{std::forward(item)}, + train_func{[]( + const void* ptr, + const std::vector& samples, + const std::vector& labels + ) -> trained_function_type { + const T_& f = *reinterpret_cast(ptr); + return f.train(samples, labels); + }} + { + } + + template < + class T, + class T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any_trainer& operator= ( + T&& item + ) + { + if (contains()) + storage.unsafe_get() = std::forward(item); + else + *this = std::move(any_trainer{std::forward(item)}); + return *this; + } + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_empty() == false, + "\t trained_function_type any_trainer::train()" + << "\n\t You can't call train() on an empty any_trainer" + << "\n\t this: " << this + ); + + return train_func(storage.get_ptr(), samples, labels); + } + + bool is_empty() const { return storage.is_empty(); } + void clear() { storage.clear(); } + void swap (any_trainer& item) { std::swap(*this, item); } + + template bool contains() const { return storage.contains();} + template T& cast_to() { return storage.cast_to(); } + template const T& cast_to() const { return storage.cast_to(); } + template T& get() { return storage.get(); } + + private: + te::storage_heap storage; + trained_function_type (*train_func) ( + const void* self, + const std::vector& samples, + const std::vector& labels + ) = nullptr; + }; + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_trainer& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_trainer& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_TRAINER_H_ + + + + diff --git a/dlib/any/any_trainer_abstract.h b/dlib/any/any_trainer_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..30f013c00eb8d145d7e058b4698972ace6dc903e --- /dev/null +++ b/dlib/any/any_trainer_abstract.h @@ -0,0 +1,243 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_TRAINER_ABSTRACT_H_ +#ifdef DLIB_AnY_TRAINER_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" +#include "any_decision_function_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of object with a .train() method compatible + with the following signature: + + decision_function train( + const std::vector& samples, + const std::vector& labels + ) const + + Where decision_function is a type capable of being stored in an + any_decision_function object. + + any_trainer is intended to be used to contain objects such as the svm_nu_trainer + and other similar types which represent supervised machine learning algorithms. + It allows you to write code which contains and processes these trainer objects + without needing to know the specific types of trainer objects used. + !*/ + + public: + + typedef sample_type_ sample_type; + typedef scalar_type_ scalar_type; + typedef default_memory_manager mem_manager_type; + typedef any_decision_function trained_function_type; + + any_trainer( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_trainer ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_trainer ( + any_trainer&& item + ); + /*! + ensures + - #item.is_empty() == true + - moves item into *this. + !*/ + + template < typename T > + any_trainer ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + /*! + requires + - is_empty() == false + ensures + - Let TRAINER denote the object contained within *this. Then + this function performs: + return TRAINER.train(samples, labels) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_trainer object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_trainer& operator= ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename scalar_type + > + inline void swap ( + any_trainer& a, + any_trainer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + T& any_cast( + any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + const T& any_cast( + const any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_TRAINER_ABSTRACT_H_ + + diff --git a/dlib/any/storage.h b/dlib/any/storage.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3a073ca6494c6c62e5153efcc7d2c1a2c86e79 --- /dev/null +++ b/dlib/any/storage.h @@ -0,0 +1,966 @@ +// Copyright (C) 2022 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TYPE_ERASURE_H_ +#define DLIB_TYPE_ERASURE_H_ + +#include +#include +#include +#include +#include + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the exception class used by the storage objects. + It is used to indicate when someone attempts to cast a storage + object into a type which isn't contained in the object. + !*/ + + public: + virtual const char * what() const throw() + { + return "bad_any_cast"; + } + }; + +// ----------------------------------------------------------------------------------------------------- + + namespace te + { + + /*! + This is used as a SFINAE tool to prevent a function taking a universal reference from + binding to some undesired type. For example: + template < + typename T, + T_is_not_this_type = true + > + void foo(T&&); + prevents foo() from binding to an object of type SomeExcludedType. + !*/ + template + using T_is_not_this_type = std::enable_if_t, Storage>::value, bool>; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This class defines functionality common to all type erasure storage objects + (defined below in this file). These objects are essentially type-safe versions of + a void*. In particular, they are containers which can contain only one object + but the object may be of any type. + + Each storage object implements a different way of storing the underlying object. + E.g. on the heap or stack or some other more specialized method. + !*/ + + public: + + bool is_empty() const + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + { + const Storage& me = *static_cast(this); + return me.get_ptr() == nullptr; + } + + template + bool contains() const + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + { + const Storage& me = *static_cast(this); + return !is_empty() && me.type_id() == std::type_index{typeid(T)}; + } + + template + T& unsafe_get() + /*! + requires + - contains() == true + ensures + - returns a reference to the object contained within *this. + !*/ + { + Storage& me = *static_cast(this); + return *reinterpret_cast(me.get_ptr()); + } + + template + const T& unsafe_get() const + /*! + requires + - contains() == true + ensures + - returns a const reference to the object contained within *this. + !*/ + { + const Storage& me = *static_cast(this); + return *reinterpret_cast(me.get_ptr()); + } + + template + T& get( + ) + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + { + Storage& me = *static_cast(this); + + if (!contains()) + me = T{}; + return unsafe_get(); + } + + template + T& cast_to( + ) + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + { + if (!contains()) + throw bad_any_cast{}; + return unsafe_get(); + } + + template + const T& cast_to( + ) const + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + { + if (!contains()) + throw bad_any_cast{}; + return unsafe_get(); + } + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_heap : public storage_base + { + public: + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses heap allocation only. + !*/ + + storage_heap() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_heap(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - #unsafe_get() will yield the provided t. + !*/ + : ptr{new T_{std::forward(t)}}, + del{[](void *self) { + delete reinterpret_cast(self); + }}, + copy{[](const void *self) -> void * { + return new T_{*reinterpret_cast(self)}; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + storage_heap(const storage_heap& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + : ptr{other.ptr ? other.copy(other.ptr) : nullptr}, + del{other.del}, + copy{other.copy}, + type_id_{other.type_id_} + { + } + + storage_heap& operator=(const storage_heap& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + { + if (this != &other) + *this = std::move(storage_heap{other}); + return *this; + } + + storage_heap(storage_heap&& other) noexcept + /*! + ensures + - The state of other is moved into *this. + - #other.is_empty() == true + !*/ + : ptr{std::exchange(other.ptr, nullptr)}, + del{std::exchange(other.del, nullptr)}, + copy{std::exchange(other.copy, nullptr)}, + type_id_{std::exchange(other.type_id_, nullptr)} + { + } + + storage_heap& operator=(storage_heap&& other) noexcept + /*! + ensures + - The state of other is moved into *this. + - #other.is_empty() == true + - returns *this + !*/ + { + if (this != &other) + { + clear(); + ptr = std::exchange(other.ptr, nullptr); + del = std::exchange(other.del, nullptr); + copy = std::exchange(other.copy, nullptr); + type_id_ = std::exchange(other.type_id_, nullptr); + } + return *this; + } + + ~storage_heap() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + if (ptr) + del(ptr); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + storage_heap{std::move(*this)}; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + void* ptr = nullptr; + void (*del)(void*) = nullptr; + void* (*copy)(const void*) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_stack : public storage_base> + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses stack allocation using a template size and alignment. + Therefore, only objects whose size and alignment fits the template parameters can be + erased and absorbed into this object. Attempting to store a type not + representable on the stack with those settings will result in a build error. + !*/ + + public: + storage_stack() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_stack(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + !*/ + : del{[](storage_stack& self) { + reinterpret_cast(&self.data)->~T_(); + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_stack& src, storage_stack& dst) { + new (&dst.data) T_{*reinterpret_cast(&src.data)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_stack& src, storage_stack& dst) { + new (&dst.data) T_{std::move(*reinterpret_cast(&src.data))}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + static_assert(sizeof(T_) <= Size, "insufficient size"); + static_assert(Alignment % alignof(T_) == 0, "bad alignment"); + new (&data) T_{std::forward(t)}; + } + + storage_stack(const storage_stack& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + { + if (other.copy) + other.copy(other, *this); + } + + storage_stack& operator=(const storage_stack& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if is_empty() == false then + - destructs the object contained in this class. + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (this != &other) + { + clear(); + if (other.copy) + other.copy(other, *this); + } + return *this; + } + + storage_stack(storage_stack&& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is moved using erased type's moved constructor + !*/ + { + if (other.move) + other.move(other, *this); + } + + storage_stack& operator=(storage_stack&& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from object + of the underlying type in whatever that object's moved from state is. + - #other.is_empty() == false + !*/ + { + if (this != &other) + { + clear(); + if (other.move) + other.move(other, *this); + } + return *this; + } + + ~storage_stack() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + clear(); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + if (del) + del(*this); + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return del ? (void*)&data : nullptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return del ? (const void*)&data : nullptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::aligned_storage_t data; + void (*del)(storage_stack&) = nullptr; + void (*copy)(const storage_stack&, storage_stack&) = nullptr; + void (*move)(storage_stack&, storage_stack&) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_sbo : public storage_base> + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses small buffer optimization (SBO), i.e. optional + stack allocation if the erased type has sizeof <= Size and alignment + requirements no greater than the given Alignment template value. If not it + allocates the object on the heap. + !*/ + + public: + // type_fits::value tells us if our SBO can hold T. + template + struct type_fits : std::integral_constant{}; + + storage_sbo() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true, + std::enable_if_t::value, bool> = true + > + storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - stack allocation is used + !*/ + : ptr{new (&data) T_{std::forward(t)}}, + del{[](storage_sbo& self) { + reinterpret_cast(&self.data)->~T_(); + self.ptr = nullptr; + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_sbo& src, storage_sbo& dst) { + dst.ptr = new (&dst.data) T_{*reinterpret_cast(src.ptr)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_sbo& src, storage_sbo& dst) { + dst.ptr = new (&dst.data) T_{std::move(*reinterpret_cast(src.ptr))}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true, + std::enable_if_t::value, bool> = true + > + storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - heap allocation is used + !*/ + : ptr{new T_{std::forward(t)}}, + del{[](storage_sbo& self) { + delete reinterpret_cast(self.ptr); + self.ptr = nullptr; + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_sbo& src, storage_sbo& dst) { + dst.ptr = new T_{*reinterpret_cast(src.ptr)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_sbo& src, storage_sbo& dst) { + dst.ptr = std::exchange(src.ptr, nullptr); + dst.del = std::exchange(src.del, nullptr); + dst.copy = std::exchange(src.copy, nullptr); + dst.move = std::exchange(src.move, nullptr); + dst.type_id_ = std::exchange(src.type_id_, nullptr); + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + storage_sbo(const storage_sbo& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (other.copy) + other.copy(other, *this); + } + + storage_sbo& operator=(const storage_sbo& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (this != &other) + { + clear(); + if (other.copy) + other.copy(other, *this); + } + return *this; + } + + storage_sbo(storage_sbo&& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - if underlying object of other is allocated on stack then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from + object of the underlying type in whatever that object's moved from + state is. + - #other.is_empty() == false + - else + - storage heap pointer is moved. + - #other.is_empty() == true + !*/ + { + if (other.move) + other.move(other, *this); + } + + storage_sbo& operator=(storage_sbo&& other) + /*! + ensures + - underlying object is destructed if is_empty() == false + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - if underlying object of other is allocated on stack then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from + object of the underlying type in whatever that object's moved from + state is. + - #other.is_empty() == false + - else + - storage heap pointer is moved. + - #other.is_empty() == true + !*/ + { + if (this != &other) + { + clear(); + if (other.move) + other.move(other, *this); + } + return *this; + } + + ~storage_sbo() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + clear(); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + if (ptr) + del(*this); + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::aligned_storage_t data; + void* ptr = nullptr; + void (*del)(storage_sbo&) = nullptr; + void (*copy)(const storage_sbo&, storage_sbo&) = nullptr; + void (*move)(storage_sbo&, storage_sbo&) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_shared : public storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses std::shared_ptr to store and erase + incoming objects. Therefore, it uses heap allocation and reference counting. + Moreover, it has the same copying and move semantics as std::shared_ptr. I.e. + it results in the underlying object being held by reference rather than by + value. + !*/ + + public: + storage_shared() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_shared(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == true + - #contains>() == true + !*/ + : ptr{std::make_shared(std::forward(t))}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + // This object has the same copy/move semantics as a std::shared_ptr + storage_shared(const storage_shared& other) = default; + storage_shared& operator=(const storage_shared& other) = default; + storage_shared(storage_shared&& other) noexcept = default; + storage_shared& operator=(storage_shared&& other) noexcept = default; + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + ptr = nullptr; + type_id_ = nullptr; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr.get(); + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr.get(); + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::shared_ptr ptr = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_view : public storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type is a view type, similar to std::string_view or + std::span. So underlying objects are only ever referenced, not copied, moved or + destructed. That is, instances of this object take no ownership of the objects + they contain. So they are only valid as long as the contained object exists. + So storage_view merely holds a pointer to the underlying object. + !*/ + + public: + storage_view() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_view(T &&t) noexcept + /*! + ensures + - #get_ptr() == &t + - #is_empty() == false + - #contains>() == true + !*/ + : ptr{&t}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + // This object has the same copy/move semantics as a void*. + storage_view(const storage_view& other) = default; + storage_view& operator=(const storage_view& other) = default; + storage_view(storage_view&& other) noexcept = default; + storage_view& operator=(storage_view&& other) noexcept = default; + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + ptr = nullptr; + type_id_ = nullptr; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + void* ptr = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + } +} + +#endif //DLIB_TYPE_ERASURE_H_ diff --git a/dlib/array.h b/dlib/array.h new file mode 100644 index 0000000000000000000000000000000000000000..ecdafc497952fdb4865886a46b4978d1eaebd3ea --- /dev/null +++ b/dlib/array.h @@ -0,0 +1,10 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAy_ +#define DLIB_ARRAy_ + +#include "array/array_kernel.h" +#include "array/array_tools.h" + +#endif // DLIB_ARRAy_ + diff --git a/dlib/array/array_kernel.h b/dlib/array/array_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..0d18d1f45f4c4902fb7f2832cd154779ba2cee8f --- /dev/null +++ b/dlib/array/array_kernel.h @@ -0,0 +1,809 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_KERNEl_2_ +#define DLIB_ARRAY_KERNEl_2_ + +#include "array_kernel_abstract.h" +#include "../interfaces/enumerable.h" +#include "../algs.h" +#include "../serialize.h" +#include "../sort.h" +#include "../is_kind.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + INITIAL VALUE + - array_size == 0 + - max_array_size == 0 + - array_elements == 0 + - pos == 0 + - last_pos == 0 + - _at_start == true + + CONVENTION + - array_size == size() + - max_array_size == max_size() + - if (max_array_size > 0) + - array_elements == pointer to max_array_size elements of type T + - else + - array_elements == 0 + + - if (array_size > 0) + - last_pos == array_elements + array_size - 1 + - else + - last_pos == 0 + + + - at_start() == _at_start + - current_element_valid() == pos != 0 + - if (current_element_valid()) then + - *pos == element() + !*/ + + public: + + // These typedefs are here for backwards compatibility with old versions of dlib. + typedef array kernel_1a; + typedef array kernel_1a_c; + typedef array kernel_2a; + typedef array kernel_2a_c; + typedef array sort_1a; + typedef array sort_1a_c; + typedef array sort_1b; + typedef array sort_1b_c; + typedef array sort_2a; + typedef array sort_2a_c; + typedef array sort_2b; + typedef array sort_2b_c; + typedef array expand_1a; + typedef array expand_1a_c; + typedef array expand_1b; + typedef array expand_1b_c; + typedef array expand_1c; + typedef array expand_1c_c; + typedef array expand_1d; + typedef array expand_1d_c; + + + + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + {} + + array(const array&) = delete; + array& operator=(array&) = delete; + + array( + array&& item + ) : array() + { + swap(item); + } + + array& operator=( + array&& item + ) + { + swap(item); + return *this; + } + + explicit array ( + size_t new_size + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + { + resize(new_size); + } + + ~array ( + ); + + void clear ( + ); + + inline const T& operator[] ( + size_t pos + ) const; + + inline T& operator[] ( + size_t pos + ); + + void set_size ( + size_t size + ); + + inline size_t max_size( + ) const; + + void set_max_size( + size_t max + ); + + void swap ( + array& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + void sort ( + ); + + void resize ( + size_t new_size + ); + + const T& back ( + ) const; + + T& back ( + ); + + void pop_back ( + ); + + void pop_back ( + T& item + ); + + void push_back ( + T& item + ); + + void push_back ( + T&& item + ); + + typedef T* iterator; + typedef const T* const_iterator; + iterator begin() { return array_elements; } + const_iterator begin() const { return array_elements; } + iterator end() { return array_elements+array_size; } + const_iterator end() const { return array_elements+array_size; } + + private: + + typename mem_manager::template rebind::other pool; + + // data members + size_t array_size; + size_t max_array_size; + T* array_elements; + + mutable T* pos; + T* last_pos; + mutable bool _at_start; + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void serialize ( + const array& item, + std::ostream& out + ) + { + try + { + serialize(item.max_size(),out); + serialize(item.size(),out); + + for (size_t i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array& item, + std::istream& in + ) + { + try + { + size_t max_size, size; + deserialize(max_size,in); + deserialize(size,in); + item.set_max_size(max_size); + item.set_size(size); + for (size_t i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + array:: + ~array ( + ) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + clear ( + ) + { + reset(); + last_pos = 0; + array_size = 0; + if (array_elements) + { + pool.deallocate_array(array_elements); + } + array_elements = 0; + max_array_size = 0; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + operator[] ( + size_t pos + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tconst T& array::operator[]" + << "\n\tpos must < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + operator[] ( + size_t pos + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tT& array::operator[]" + << "\n\tpos must be < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_size ( + size_t size + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( size <= this->max_size() ), + "\tvoid array::set_size" + << "\n\tsize must be <= max_size()" + << "\n\tsize: " << size + << "\n\tmax size: " << this->max_size() + << "\n\tthis: " << this + ); + + reset(); + array_size = size; + if (size > 0) + last_pos = array_elements + size - 1; + else + last_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + size ( + ) const + { + return array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_max_size( + size_t max + ) + { + reset(); + array_size = 0; + last_pos = 0; + if (max != 0) + { + // if new max size is different + if (max != max_array_size) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + // try to get more memroy + try { array_elements = pool.allocate_array(max); } + catch (...) { array_elements = 0; max_array_size = 0; throw; } + max_array_size = max; + } + + } + // if the array is being made to be zero + else + { + if (array_elements) + pool.deallocate_array(array_elements); + max_array_size = 0; + array_elements = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + max_size ( + ) const + { + return max_array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + swap ( + array& item + ) + { + auto array_size_temp = item.array_size; + auto max_array_size_temp = item.max_array_size; + T* array_elements_temp = item.array_elements; + + item.array_size = array_size; + item.max_array_size = max_array_size; + item.array_elements = array_elements; + + array_size = array_size_temp; + max_array_size = max_array_size_temp; + array_elements = array_elements_temp; + + exchange(_at_start,item._at_start); + exchange(pos,item.pos); + exchange(last_pos,item.last_pos); + pool.swap(item.pool); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + at_start ( + ) const + { + return _at_start; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + reset ( + ) const + { + _at_start = true; + pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + current_element_valid ( + ) const + { + return pos != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tconst T& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tT& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + move_next ( + ) const + { + if (!_at_start) + { + if (pos < last_pos) + { + ++pos; + return true; + } + else + { + pos = 0; + return false; + } + } + else + { + _at_start = false; + if (array_size > 0) + { + pos = array_elements; + return true; + } + else + { + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Yet more functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + sort ( + ) + { + if (this->size() > 1) + { + // call the quick sort function for arrays that is in algs.h + dlib::qsort_array(*this,0,this->size()-1); + } + this->reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + resize ( + size_t new_size + ) + { + if (this->max_size() < new_size) + { + array temp; + temp.set_max_size(new_size); + temp.set_size(new_size); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + temp.swap(*this); + } + else + { + this->set_size(new_size); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tT& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + back ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tconst T& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + exchange(item,(*this)[this->size()-1]); + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T& item + ) + { + if (this->max_size() == this->size()) + { + // double the size of the array + array temp; + temp.set_max_size(this->size()*2 + 1); + temp.set_size(this->size()+1); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + exchange(item,temp[temp.size()-1]); + temp.swap(*this); + } + else + { + this->set_size(this->size()+1); + exchange(item,(*this)[this->size()-1]); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T&& item + ) { push_back(item); } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY_KERNEl_2_ + diff --git a/dlib/array/array_kernel_abstract.h b/dlib/array/array_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..5cfdd483ac9da2e4b27916d05c5403f87dc6977b --- /dev/null +++ b/dlib/array/array_kernel_abstract.h @@ -0,0 +1,360 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + front(), back(), swap(), max_size(), set_size(), and operator[] + functions do not invalidate pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + max_size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array in the + order (*this)[0], (*this)[1], (*this)[2], ... + + WHAT THIS OBJECT REPRESENTS + This object represents an ordered 1-dimensional array of items, + each item is associated with an integer value. The items are + numbered from 0 though size() - 1 and the operator[] functions + run in constant time. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + explicit array ( + size_t new_size + ); + /*! + ensures + - #*this is properly initialized + - #size() == new_size + - #max_size() == new_size + - All elements of the array will have initial values for their type. + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + ~array ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + array( + array&& item + ); + /*! + ensures + - move constructs *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + !*/ + + array& operator=( + array&& item + ); + /*! + ensures + - move assigns *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + - returns a reference to #*this + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then the array object is unusable + until clear() is called and succeeds + !*/ + + const T& operator[] ( + size_t pos + ) const; + /*! + requires + - pos < size() + ensures + - returns a const reference to the element at position pos + !*/ + + T& operator[] ( + size_t pos + ); + /*! + requires + - pos < size() + ensures + - returns a non-const reference to the element at position pos + !*/ + + void set_size ( + size_t size + ); + /*! + requires + - size <= max_size() + ensures + - #size() == size + - any element with index between 0 and size - 1 which was in the + array before the call to set_size() retains its value and index. + All other elements have undetermined (but valid for their type) + values. (e.g. this object might buffer old T objects and reuse + them without reinitializing them between calls to set_size()) + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough memory and + if it does throw then the call to set_size() has no effect + !*/ + + size_t max_size( + ) const; + /*! + ensures + - returns the maximum size of *this + !*/ + + void set_max_size( + size_t max + ); + /*! + ensures + - #max_size() == max + - #size() == 0 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough + memory and if it does throw then max_size() == 0 + !*/ + + void swap ( + array& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void sort ( + ); + /*! + requires + - T must be a type with that is comparable via operator< + ensures + - for all elements in #*this the ith element is <= the i+1 element + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + + void resize ( + size_t new_size + ); + /*! + ensures + - #size() == new_size + - #max_size() == max(new_size,max_size()) + - for all i < size() && i < new_size: + - #(*this)[i] == (*this)[i] + (i.e. All the original elements of *this which were at index + values less than new_size are unmodified.) + - for all valid i >= size(): + - #(*this)[i] has an undefined value + (i.e. any new elements of the array have an undefined value) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + + const T& back ( + ) const; + /*! + requires + - size() != 0 + ensures + - returns a const reference to (*this)[size()-1] + !*/ + + T& back ( + ); + /*! + requires + - size() != 0 + ensures + - returns a non-const reference to (*this)[size()-1] + !*/ + + void pop_back ( + T& item + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - swaps (*this)[size()-1] into item + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void pop_back ( + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void push_back ( + T& item + ); + /*! + ensures + - #size() == size()+1 + - swaps item into (*this)[#size()-1] + - #back() == item + - #item has some undefined value (whatever happens to + get swapped out of the array) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + void push_back (T&& item) { push_back(item); } + /*! + enable push_back from rvalues + !*/ + + typedef T* iterator; + typedef const T* const_iterator; + + iterator begin( + ); + /*! + ensures + - returns an iterator that points to the first element in this array or + end() if the array is empty. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a const iterator that points to the first element in this + array or end() if the array is empty. + !*/ + + iterator end( + ); + /*! + ensures + - returns an iterator that points to one past the end of the array. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a const iterator that points to one past the end of the + array. + !*/ + + private: + + // restricted functions + array(array&); // copy constructor + array& operator=(array&); // assignment operator + + }; + + template < + typename T + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T + > + void serialize ( + const array& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + array& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY_KERNEl_ABSTRACT_ + diff --git a/dlib/array/array_tools.h b/dlib/array/array_tools.h new file mode 100644 index 0000000000000000000000000000000000000000..fce6343968da48774e9de4e3aada851026e1de8c --- /dev/null +++ b/dlib/array/array_tools.h @@ -0,0 +1,38 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_tOOLS_H_ +#define DLIB_ARRAY_tOOLS_H_ + +#include "../assert.h" +#include "array_tools_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= frac && frac <= 1, + "\t void split_array()" + << "\n\t frac must be between 0 and 1." + << "\n\t frac: " << frac + ); + + const unsigned long asize = static_cast(a.size()*frac); + const unsigned long bsize = a.size()-asize; + + b.resize(bsize); + for (unsigned long i = 0; i < b.size(); ++i) + { + swap(b[i], a[i+asize]); + } + a.resize(asize); + } +} + +#endif // DLIB_ARRAY_tOOLS_H_ + diff --git a/dlib/array/array_tools_abstract.h b/dlib/array/array_tools_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..e9b95751891abaee53eeb871fe5255f21c301e31 --- /dev/null +++ b/dlib/array/array_tools_abstract.h @@ -0,0 +1,33 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_tOOLS_ABSTRACT_H_ +#ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_ + +#include "array_kernel_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ); + /*! + requires + - 0 <= frac <= 1 + - T must be an array type such as dlib::array or std::vector + ensures + - This function takes the elements of a and splits them into two groups. The + first group remains in a and the second group is put into b. The ordering of + elements in a is preserved. In particular, concatenating #a with #b will + reproduce the original contents of a. + - The elements in a are moved around using global swap(). So they must be + swappable, but do not need to be copyable. + - #a.size() == floor(a.size()*frac) + - #b.size() == a.size()-#a.size() + !*/ +} + +#endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_ + diff --git a/dlib/array2d.h b/dlib/array2d.h new file mode 100644 index 0000000000000000000000000000000000000000..f5325e4a2f2720e5d01c71cf77fea4ee18598bb8 --- /dev/null +++ b/dlib/array2d.h @@ -0,0 +1,12 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2d_ +#define DLIB_ARRAY2d_ + + +#include "array2d/array2d_kernel.h" +#include "array2d/serialize_pixel_overloads.h" +#include "array2d/array2d_generic_image.h" + +#endif // DLIB_ARRAY2d_ + diff --git a/dlib/array2d/array2d_generic_image.h b/dlib/array2d/array2d_generic_image.h new file mode 100644 index 0000000000000000000000000000000000000000..a96f5e3c2548046a195c0045f73ffffda20f8323 --- /dev/null +++ b/dlib/array2d/array2d_generic_image.h @@ -0,0 +1,67 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ +#define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + +#include "array2d_kernel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template + struct image_traits > + { + typedef T pixel_type; + }; + template + struct image_traits > + { + typedef T pixel_type; + }; + + template + inline long num_rows( const array2d& img) { return img.nr(); } + template + inline long num_columns( const array2d& img) { return img.nc(); } + + template + inline void set_image_size( + array2d& img, + long rows, + long cols + ) { img.set_size(rows,cols); } + + template + inline void* image_data( + array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline const void* image_data( + const array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline long width_step( + const array2d& img + ) + { + return img.width_step(); + } + +} + +#endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + diff --git a/dlib/array2d/array2d_kernel.h b/dlib/array2d/array2d_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b4736ed18e66ab1e6a334d5bb3e99a240b372863 --- /dev/null +++ b/dlib/array2d/array2d_kernel.h @@ -0,0 +1,524 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_KERNEl_1_ +#define DLIB_ARRAY2D_KERNEl_1_ + +#include "array2d_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + INITIAL VALUE + - nc_ == 0 + - nr_ == 0 + - data == 0 + - at_start_ == true + - cur == 0 + - last == 0 + + CONVENTION + - nc_ == nc() + - nr_ == nc() + - if (data != 0) then + - last == a pointer to the last element in the data array + - data == pointer to an array of nc_*nr_ T objects + - else + - nc_ == 0 + - nr_ == 0 + - data == 0 + - last == 0 + + + - nr_ * nc_ == size() + - if (cur == 0) then + - current_element_valid() == false + - else + - current_element_valid() == true + - *cur == element() + + - at_start_ == at_start() + !*/ + + + class row_helper; + public: + + // These typedefs are here for backwards compatibility with older versions of dlib. + typedef array2d kernel_1a; + typedef array2d kernel_1a_c; + + typedef T type; + typedef mem_manager mem_manager_type; + typedef T* iterator; + typedef const T* const_iterator; + + + // ----------------------------------- + + class row + { + /*! + CONVENTION + - nc_ == nc() + - for all x < nc_: + - (*this)[x] == data[x] + !*/ + + friend class array2d; + friend class row_helper; + + public: + long nc ( + ) const { return nc_; } + + const T& operator[] ( + long column + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tconst T& array2d::operator[](long column) const" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + T& operator[] ( + long column + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tT& array2d::operator[](long column)" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + private: + + row(T* data_, long cols) : data(data_), nc_(cols) {} + row(row&& r) = default; + row& operator=(row&& r) = default; + + T* data = nullptr; + long nc_ = 0; + + + // restricted functions + row(const row&) = delete; + row& operator=(const row&) = delete; + }; + + // ----------------------------------- + + array2d ( + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + } + + array2d( + long rows, + long cols + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0), + "\t array2d::array2d(long rows, long cols)" + << "\n\t The array2d can't have negative rows or columns." + << "\n\t this: " << this + << "\n\t cols: " << cols + << "\n\t rows: " << rows + ); + + set_size(rows,cols); + } + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + +#ifdef DLIB_HAS_RVALUE_REFERENCES + array2d(array2d&& item) : array2d() + { + swap(item); + } + + array2d& operator= ( + array2d&& rhs + ) + { + swap(rhs); + return *this; + } +#endif + + virtual ~array2d ( + ) { clear(); } + + long nc ( + ) const { return nc_; } + + long nr ( + ) const { return nr_; } + + row operator[] ( + long row_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\trow array2d::operator[](long row_)" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + const row operator[] ( + long row_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\tconst row array2d::operator[](long row_) const" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + void swap ( + array2d& item + ) + { + exchange(data,item.data); + exchange(nr_,item.nr_); + exchange(nc_,item.nc_); + exchange(at_start_,item.at_start_); + exchange(cur,item.cur); + exchange(last,item.last); + pool.swap(item.pool); + } + + void clear ( + ) + { + if (data != 0) + { + pool.deallocate_array(data); + nc_ = 0; + nr_ = 0; + data = 0; + at_start_ = true; + cur = 0; + last = 0; + } + } + + void set_size ( + long rows, + long cols + ); + + bool at_start ( + ) const { return at_start_; } + + void reset ( + ) const { at_start_ = true; cur = 0; } + + bool current_element_valid ( + ) const { return (cur != 0); } + + const T& element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst T& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + T& element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tT& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + bool move_next ( + ) const + { + if (cur != 0) + { + if (cur != last) + { + ++cur; + return true; + } + cur = 0; + return false; + } + else if (at_start_) + { + cur = data; + at_start_ = false; + return (data != 0); + } + else + { + return false; + } + } + + size_t size ( + ) const { return static_cast(nc_) * static_cast(nr_); } + + long width_step ( + ) const + { + return nc_*sizeof(T); + } + + iterator begin() + { + return data; + } + + iterator end() + { + return data+size(); + } + + const_iterator begin() const + { + return data; + } + + const_iterator end() const + { + return data+size(); + } + + + private: + + + T* data; + long nc_; + long nr_; + + typename mem_manager::template rebind::other pool; + mutable T* cur; + T* last; + mutable bool at_start_; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + item.reset(); + while (item.move_next()) + serialize(item.element(),out); + item.reset(); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + while (item.move_next()) + deserialize(item.element(),in); + item.reset(); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array2d:: + set_size ( + long rows, + long cols + ) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0) , + "\tvoid array2d::set_size(long rows, long cols)" + << "\n\tThe array2d can't have negative rows or columns." + << "\n\tthis: " << this + << "\n\tcols: " << cols + << "\n\trows: " << rows + ); + + // set the enumerator back at the start + at_start_ = true; + cur = 0; + + // don't do anything if we are already the right size. + if (nc_ == cols && nr_ == rows) + { + return; + } + + nc_ = cols; + nr_ = rows; + + // free any existing memory + if (data != 0) + { + pool.deallocate_array(data); + data = 0; + } + + // now setup this object to have the new size + try + { + if (nr_ > 0) + { + data = pool.allocate_array(nr_*nc_); + last = data + nr_*nc_ - 1; + } + } + catch (...) + { + if (data) + pool.deallocate_array(data); + + data = 0; + nc_ = 0; + nr_ = 0; + last = 0; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array2d > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_KERNEl_1_ + diff --git a/dlib/array2d/array2d_kernel_abstract.h b/dlib/array2d/array2d_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..daccfc600436f2a6f59ea58f37fc67537a3d94dd --- /dev/null +++ b/dlib/array2d/array2d_kernel_abstract.h @@ -0,0 +1,339 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY2D_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" +#include "../geometry/rectangle_abstract.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data except for the set_size() + and clear() member functions. + + INITIAL VALUE + nr() == 0 + nc() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array starting + with row 0 and then proceeding to row 1 and so on. Each row will be + fully enumerated before proceeding on to the next row and the elements + in a row will be enumerated beginning with the 0th column, then the 1st + column and so on. + + WHAT THIS OBJECT REPRESENTS + This object represents a 2-Dimensional array of objects of + type T. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + + Finally, note that this object stores its data contiguously and in + row major order. Moreover, there is no padding at the end of each row. + This means that its width_step() value is always equal to sizeof(type)*nc(). + !*/ + + + public: + + // ---------------------------------------- + + typedef T type; + typedef mem_manager mem_manager_type; + typedef T* iterator; + typedef const T* const_iterator; + + // ---------------------------------------- + + class row + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents a row of Ts in an array2d object. + !*/ + public: + long nc ( + ) const; + /*! + ensures + - returns the number of columns in this row + !*/ + + const T& operator[] ( + long column + ) const; + /*! + requires + - 0 <= column < nc() + ensures + - returns a const reference to the T in the given column + !*/ + + T& operator[] ( + long column + ); + /*! + requires + - 0 <= column < nc() + ensures + - returns a non-const reference to the T in the given column + !*/ + + private: + // restricted functions + row(); + row& operator=(row&); + }; + + // ---------------------------------------- + + array2d ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + + array2d( + array2d&& item + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + !*/ + + array2d ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - all elements in this array have initial values for their type + throws + - std::bad_alloc + !*/ + + virtual ~array2d ( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has an initial value for its type + !*/ + + long nc ( + ) const; + /*! + ensures + - returns the number of elements there are in a row. i.e. returns + the number of columns in *this + !*/ + + long nr ( + ) const; + /*! + ensures + - returns the number of rows in *this + !*/ + + void set_size ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - if (the call to set_size() doesn't change the dimensions of this array) then + - all elements in this array retain their values from before this function was called + - else + - all elements in this array have initial values for their type + throws + - std::bad_alloc + If this exception is thrown then #*this will have an initial + value for its type. + !*/ + + row operator[] ( + long row_index + ); + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a non-const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + const row operator[] ( + long row_index + ) const; + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + void swap ( + array2d& item + ); + /*! + ensures + - swaps *this and item + !*/ + + array2d& operator= ( + array2d&& rhs + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + - returns #*this + !*/ + + long width_step ( + ) const; + /*! + ensures + - returns the size of one row of the image, in bytes. + More precisely, return a number N such that: + (char*)&item[0][0] + N == (char*)&item[1][0]. + - for dlib::array2d objects, the returned value + is always equal to sizeof(type)*nc(). However, + other objects which implement dlib::array2d style + interfaces might have padding at the ends of their + rows and therefore might return larger numbers. + An example of such an object is the dlib::cv_image. + !*/ + + iterator begin( + ); + /*! + ensures + - returns a random access iterator pointing to the first element in this + object. + - The iterator will iterate over the elements of the object in row major + order. + !*/ + + iterator end( + ); + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this object. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a random access iterator pointing to the first element in this + object. + - The iterator will iterate over the elements of the object in row major + order. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this object. + !*/ + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ); + /*! + Provides serialization support. Note that the serialization formats used by the + dlib::matrix and dlib::array2d objects are compatible. That means you can load the + serialized data from one into another and it will work properly. + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_ + diff --git a/dlib/array2d/serialize_pixel_overloads.h b/dlib/array2d/serialize_pixel_overloads.h new file mode 100644 index 0000000000000000000000000000000000000000..91383a66654553b2356a677e848f2c772a72aa22 --- /dev/null +++ b/dlib/array2d/serialize_pixel_overloads.h @@ -0,0 +1,371 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ +#define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + +#include "array2d_kernel.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /* + This file contains overloads of the serialize functions for array2d object + for the case where they contain simple 8bit POD pixel types. In these + cases we can perform a much faster serialization by writing data in chunks + instead of one pixel at a time (this avoids a lot of function call overhead + inside the iostreams). + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + diff --git a/dlib/assert.h b/dlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..67dc634144130757bd61fa5e9668ff822c1f7e11 --- /dev/null +++ b/dlib/assert.h @@ -0,0 +1,217 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ASSERt_ +#define DLIB_ASSERt_ + +#include "config.h" +#include +#include +#include "error.h" + +// ----------------------------- + +// Use some stuff from boost here +// (C) Copyright John Maddock 2001 - 2003. +// (C) Copyright Darin Adler 2001. +// (C) Copyright Peter Dimov 2001. +// (C) Copyright Bill Kempf 2002. +// (C) Copyright Jens Maurer 2002. +// (C) Copyright David Abrahams 2002 - 2003. +// (C) Copyright Gennaro Prota 2003. +// (C) Copyright Eric Friedman 2003. +// License: Boost Software License See LICENSE.txt for the full license. +// +#ifndef DLIB_BOOST_JOIN +#define DLIB_BOOST_JOIN( X, Y ) DLIB_BOOST_DO_JOIN( X, Y ) +#define DLIB_BOOST_DO_JOIN( X, Y ) DLIB_BOOST_DO_JOIN2(X,Y) +#define DLIB_BOOST_DO_JOIN2( X, Y ) X##Y +#endif + +// figure out if the compiler has rvalue references. +#if defined(__clang__) +# if __has_feature(cxx_rvalue_references) +# define DLIB_HAS_RVALUE_REFERENCES +# endif +# if __has_feature(cxx_generalized_initializers) +# define DLIB_HAS_INITIALIZER_LISTS +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#elif defined(_MSC_VER) && _MSC_VER >= 1800 +# define DLIB_HAS_INITIALIZER_LISTS +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#endif + +#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402) + // Apple has not updated libstdc++ in some time and anything under 4.02 does not have for sure. +# undef DLIB_HAS_INITIALIZER_LISTS +#endif + +// figure out if the compiler has static_assert. +#if defined(__clang__) +# if __has_feature(cxx_static_assert) +# define DLIB_HAS_STATIC_ASSERT +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_STATIC_ASSERT +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_STATIC_ASSERT +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_STATIC_ASSERT +#endif + + +// ----------------------------- + +namespace dlib +{ + template struct compile_time_assert; + template <> struct compile_time_assert { enum {value=1}; }; + + template struct assert_are_same_type; + template struct assert_are_same_type {enum{value=1};}; + template struct assert_are_not_same_type {enum{value=1}; }; + template struct assert_are_not_same_type {}; + + template struct assert_types_match {enum{value=0};}; + template struct assert_types_match {enum{value=1};}; +} + + +// gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile +// time assert macros so we need to make it not complain about them "not being used". +#ifdef __GNUC__ +#define DLIB_NO_WARN_UNUSED __attribute__ ((unused)) +#else +#define DLIB_NO_WARN_UNUSED +#endif + +// Use the newer static_assert if it's available since it produces much more readable error +// messages. +#ifdef DLIB_HAS_STATIC_ASSERT + #define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion") + #define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match::value, "These types should be the same but aren't.") + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match::value, "These types should NOT be the same.") +#else + #define COMPILE_TIME_ASSERT(expression) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value] + + #define ASSERT_ARE_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type::value] + + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type::value] +#endif + +// ----------------------------- + +#if defined DLIB_DISABLE_ASSERTS + // if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what. + #undef ENABLE_ASSERTS +#endif + +#if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG) + // make sure ENABLE_ASSERTS is defined if we are indeed using them. + #ifndef ENABLE_ASSERTS + #define ENABLE_ASSERTS + #endif +#endif + +// ----------------------------- + +#ifdef __GNUC__ +// There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault +// when __PRETTY_FUNCTION__ is used within certain templated functions. So just +// don't use it with this version of GCC. +# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5) +# define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__ +# else +# define DLIB_FUNCTION_NAME "unknown function" +# endif +#elif defined(_MSC_VER) +#define DLIB_FUNCTION_NAME __FUNCSIG__ +#else +#define DLIB_FUNCTION_NAME "unknown function" +#endif + +#define DLIBM_CASSERT(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + dlib_assert_breakpoint(); \ + std::ostringstream dlib_o_out; \ + dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \ + dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \ + dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \ + dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ + dlib_o_out << std::boolalpha << _message << "\n"; \ + throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \ + }} + +// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor. +#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x +// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like +// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message). +#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"") +#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message) +#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3 +#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS, DLIB_CASSERT_NEVER_USED)) +#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__)) + + +#ifdef ENABLE_ASSERTS + #define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__) + #define DLIB_IF_ASSERT(exp) exp +#else + #define DLIB_ASSERT(...) {} + #define DLIB_IF_ASSERT(exp) +#endif + +// ---------------------------------------------------------------------------------------- + + /*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT + + This macro is meant to cause a compiler error if a type doesn't have a simple + memory layout (like a C struct). In particular, types with simple layouts are + ones which can be copied via memcpy(). + + + This was called a POD type in C++03 and in C++0x we are looking to check if + it is a "standard layout type". Once we can use C++0x we can change this macro + to something that uses the std::is_standard_layout type_traits class. + See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs + !*/ + // Use the fact that in C++03 you can't put non-PODs into a union. +#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \ + union DLIB_BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(DLIB_BOOST_JOIN(DAHSL_,__LINE__))]; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// breakpoints +extern "C" +{ + inline void dlib_assert_breakpoint( + ) {} + /*! + ensures + - this function does nothing + It exists just so you can put breakpoints on it in a debugging tool. + It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to + throw an exception. + !*/ +} + +// ----------------------------- + +#include "stack_trace.h" + +#endif // DLIB_ASSERt_ + diff --git a/dlib/base64.h b/dlib/base64.h new file mode 100644 index 0000000000000000000000000000000000000000..8308920d6e133a403a08160cde5bd5025408f2c9 --- /dev/null +++ b/dlib/base64.h @@ -0,0 +1,9 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASe64_ +#define DLIB_BASe64_ + +#include "base64/base64_kernel_1.h" + +#endif // DLIB_BASe64_ + diff --git a/dlib/base64/base64_kernel_1.cpp b/dlib/base64/base64_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b48c789e8baa9e15681e02d35b28cdf64e711be --- /dev/null +++ b/dlib/base64/base64_kernel_1.cpp @@ -0,0 +1,403 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEL_1_CPp_ +#define DLIB_BASE64_KERNEL_1_CPp_ + +#include "base64_kernel_1.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + base64::line_ending_type base64:: + line_ending ( + ) const + { + return eol_style; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + set_line_ending ( + line_ending_type eol_style_ + ) + { + eol_style = eol_style_; + } + +// ---------------------------------------------------------------------------------------- + + base64:: + base64 ( + ) : + encode_table(0), + decode_table(0), + bad_value(100), + eol_style(LF) + { + try + { + encode_table = new char[64]; + decode_table = new unsigned char[UCHAR_MAX]; + } + catch (...) + { + if (encode_table) delete [] encode_table; + if (decode_table) delete [] decode_table; + throw; + } + + // now set up the tables with the right stuff + encode_table[0] = 'A'; + encode_table[17] = 'R'; + encode_table[34] = 'i'; + encode_table[51] = 'z'; + + encode_table[1] = 'B'; + encode_table[18] = 'S'; + encode_table[35] = 'j'; + encode_table[52] = '0'; + + encode_table[2] = 'C'; + encode_table[19] = 'T'; + encode_table[36] = 'k'; + encode_table[53] = '1'; + + encode_table[3] = 'D'; + encode_table[20] = 'U'; + encode_table[37] = 'l'; + encode_table[54] = '2'; + + encode_table[4] = 'E'; + encode_table[21] = 'V'; + encode_table[38] = 'm'; + encode_table[55] = '3'; + + encode_table[5] = 'F'; + encode_table[22] = 'W'; + encode_table[39] = 'n'; + encode_table[56] = '4'; + + encode_table[6] = 'G'; + encode_table[23] = 'X'; + encode_table[40] = 'o'; + encode_table[57] = '5'; + + encode_table[7] = 'H'; + encode_table[24] = 'Y'; + encode_table[41] = 'p'; + encode_table[58] = '6'; + + encode_table[8] = 'I'; + encode_table[25] = 'Z'; + encode_table[42] = 'q'; + encode_table[59] = '7'; + + encode_table[9] = 'J'; + encode_table[26] = 'a'; + encode_table[43] = 'r'; + encode_table[60] = '8'; + + encode_table[10] = 'K'; + encode_table[27] = 'b'; + encode_table[44] = 's'; + encode_table[61] = '9'; + + encode_table[11] = 'L'; + encode_table[28] = 'c'; + encode_table[45] = 't'; + encode_table[62] = '+'; + + encode_table[12] = 'M'; + encode_table[29] = 'd'; + encode_table[46] = 'u'; + encode_table[63] = '/'; + + encode_table[13] = 'N'; + encode_table[30] = 'e'; + encode_table[47] = 'v'; + + encode_table[14] = 'O'; + encode_table[31] = 'f'; + encode_table[48] = 'w'; + + encode_table[15] = 'P'; + encode_table[32] = 'g'; + encode_table[49] = 'x'; + + encode_table[16] = 'Q'; + encode_table[33] = 'h'; + encode_table[50] = 'y'; + + + + // we can now fill out the decode_table by using the encode_table + for (int i = 0; i < UCHAR_MAX; ++i) + { + decode_table[i] = bad_value; + } + for (unsigned char i = 0; i < 64; ++i) + { + decode_table[(unsigned char)encode_table[i]] = i; + } + } + +// ---------------------------------------------------------------------------------------- + + base64:: + ~base64 ( + ) + { + delete [] encode_table; + delete [] decode_table; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + encode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[3]; + unsigned char outbuf[4]; + streamsize status = in.sgetn(reinterpret_cast(&inbuf),3); + + unsigned char c1, c2, c3, c4, c5, c6; + + int counter = 19; + + // while we haven't hit the end of the input stream + while (status != 0) + { + if (counter == 0) + { + counter = 19; + // write a newline + char ch; + switch (eol_style) + { + case CR: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case LF: + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case CRLF: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + default: + DLIB_CASSERT(false,"this should never happen"); + } + } + --counter; + + if (status == 3) + { + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = inbuf[2]&0xc0; + c6 = inbuf[2]&0x3f; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = c6; + + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + outbuf[3] = encode_table[outbuf[3]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + // get 3 more input bytes + status = in.sgetn(reinterpret_cast(&inbuf),3); + continue; + } + else if (status == 2) + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + + break; + } + else // in this case status must be 1 + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = '='; + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + break; + } + } // while (status != 0) + + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + decode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[4]; + unsigned char outbuf[3]; + int inbuf_pos = 0; + streamsize status = in.sgetn(reinterpret_cast(inbuf),1); + + // only count this character if it isn't some kind of filler + if (status == 1 && decode_table[inbuf[0]] != bad_value ) + ++inbuf_pos; + + unsigned char c1, c2, c3, c4, c5, c6; + streamsize outsize; + + // while we haven't hit the end of the input stream + while (status != 0) + { + // if we have 4 valid characters + if (inbuf_pos == 4) + { + inbuf_pos = 0; + + // this might be the end of the encoded data so we need to figure out if + // there was any padding applied. + outsize = 3; + if (inbuf[3] == '=') + { + if (inbuf[2] == '=') + outsize = 1; + else + outsize = 2; + } + + // decode the incoming characters + inbuf[0] = decode_table[inbuf[0]]; + inbuf[1] = decode_table[inbuf[1]]; + inbuf[2] = decode_table[inbuf[2]]; + inbuf[3] = decode_table[inbuf[3]]; + + + // now pack these guys into bytes rather than 6 bit chunks + c1 = inbuf[0]<<2; + c2 = inbuf[1]>>4; + c3 = inbuf[1]<<4; + c4 = inbuf[2]>>2; + c5 = inbuf[2]<<6; + c6 = inbuf[3]; + + outbuf[0] = c1|c2; + outbuf[1] = c3|c4; + outbuf[2] = c5|c6; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),outsize)!=outsize) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + } + + // get more input characters + status = in.sgetn(reinterpret_cast(inbuf + inbuf_pos),1); + // only count this character if it isn't some kind of filler + if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') && + status != 0) + ++inbuf_pos; + } // while (status != 0) + + if (inbuf_pos != 0) + { + ostringstream sout; + sout << inbuf_pos << " extra characters were found at the end of the encoded data." + << " This may indicate that the data stream has been truncated."; + // this happens if we hit EOF in the middle of decoding a 24bit block. + throw decode_error(sout.str()); + } + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BASE64_KERNEL_1_CPp_ + diff --git a/dlib/base64/base64_kernel_1.h b/dlib/base64/base64_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..d8f49b1b8cad201499aead19ca759bec2fe8e919 --- /dev/null +++ b/dlib/base64/base64_kernel_1.h @@ -0,0 +1,92 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEl_1_ +#define DLIB_BASE64_KERNEl_1_ + +#include "../algs.h" +#include "base64_kernel_abstract.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - bad_value == 100 + - encode_table == a pointer to an array of 64 chars + - where x is a 6 bit value the following is true: + - encode_table[x] == the base64 encoding of x + - decode_table == a pointer to an array of UCHAR_MAX chars + - where x is any char value: + - if (x is a valid character in the base64 coding scheme) then + - decode_table[x] == the 6 bit value that x encodes + - else + - decode_table[x] == bad_value + + CONVENTION + - The state of this object never changes so just refer to its + initial value. + + + !*/ + + public: + // this is here for backwards compatibility with older versions of dlib. + typedef base64 kernel_1a; + + class decode_error : public dlib::error { public: + decode_error( const std::string& e) : error(e) {}}; + + base64 ( + ); + + virtual ~base64 ( + ); + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + + void set_line_ending ( + line_ending_type eol_style_ + ); + + void encode ( + std::istream& in, + std::ostream& out + ) const; + + void decode ( + std::istream& in, + std::ostream& out + ) const; + + private: + + char* encode_table; + unsigned char* decode_table; + const unsigned char bad_value; + line_ending_type eol_style; + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#ifdef NO_MAKEFILE +#include "base64_kernel_1.cpp" +#endif + +#endif // DLIB_BASE64_KERNEl_1_ + diff --git a/dlib/base64/base64_kernel_abstract.h b/dlib/base64/base64_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..0a63d3b87be54f88dcfbf122850e81d9987034a2 --- /dev/null +++ b/dlib/base64/base64_kernel_abstract.h @@ -0,0 +1,121 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BASE64_KERNEl_ABSTRACT_ +#ifdef DLIB_BASE64_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - line_ending() == LF + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions encode and decode. + These functions allow you to encode and decode data to and from + the Base64 Content-Transfer-Encoding defined in section 6.8 of + rfc2045. + !*/ + + public: + + class decode_error : public dlib::error {}; + + base64 ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~base64 ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + /*! + ensures + - returns the type of end of line bytes the encoder + will use when encoding data to base64 blocks. Note that + the ostream object you use might apply some sort of transform + to line endings as well. For example, C++ ofstream objects + usually convert '\n' into whatever a normal newline is for + your platform unless you open a file in binary mode. But + aside from file streams the ostream objects usually don't + modify the data you pass to them. + !*/ + + void set_line_ending ( + line_ending_type eol_style + ); + /*! + ensures + - #line_ending() == eol_style + !*/ + + void encode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and encodes it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + void decode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in (until EOF is reached), decodes it, + and writes it to out. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decode_error + if an error was detected in the encoded data that prevented + it from being correctly decoded then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + private: + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#endif // DLIB_BASE64_KERNEl_ABSTRACT_ + diff --git a/dlib/bayes_utils.h b/dlib/bayes_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..51ef6d2ed3d02cd1fe56b0bb2a6d2a7f7682e8ff --- /dev/null +++ b/dlib/bayes_utils.h @@ -0,0 +1,11 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_H_ +#define DLIB_BAYES_UTILs_H_ + +#include "bayes_utils/bayes_utils.h" + +#endif // DLIB_BAYES_UTILs_H_ + + + diff --git a/dlib/bayes_utils/bayes_utils.h b/dlib/bayes_utils/bayes_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c5ebc19b92a71e836ed787ad054dd376139f23e3 --- /dev/null +++ b/dlib/bayes_utils/bayes_utils.h @@ -0,0 +1,1677 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_ +#define DLIB_BAYES_UTILs_ + +#include "bayes_utils_abstract.h" + +#include +#include +#include +#include + +#include "../string.h" +#include "../map.h" +#include "../matrix.h" +#include "../rand.h" +#include "../array.h" +#include "../set.h" +#include "../algs.h" +#include "../noncopyable.h" +#include "../graph.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment + { + public: + + assignment() + { + } + + assignment( + const assignment& a + ) + { + a.reset(); + while (a.move_next()) + { + unsigned long idx = a.element().key(); + unsigned long value = a.element().value(); + vals.add(idx,value); + } + } + + assignment& operator = ( + const assignment& rhs + ) + { + if (this == &rhs) + return *this; + + assignment(rhs).swap(*this); + return *this; + } + + void clear() + { + vals.clear(); + } + + bool operator < ( + const assignment& item + ) const + { + if (size() < item.size()) + return true; + else if (size() > item.size()) + return false; + + reset(); + item.reset(); + while (move_next()) + { + item.move_next(); + if (element().key() < item.element().key()) + return true; + else if (element().key() > item.element().key()) + return false; + else if (element().value() < item.element().value()) + return true; + else if (element().value() > item.element().value()) + return false; + } + + return false; + } + + bool has_index ( + unsigned long idx + ) const + { + return vals.is_in_domain(idx); + } + + void add ( + unsigned long idx, + unsigned long value = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == false , + "\tvoid assignment::add(idx)" + << "\n\tYou can't add the same index to an assignment object more than once" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.add(idx, value); + } + + unsigned long& operator[] ( + const long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + const unsigned long& operator[] ( + const long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + void swap ( + assignment& item + ) + { + vals.swap(item.vals); + } + + void remove ( + unsigned long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::remove(idx)" + << "\n\tYou can't remove an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.destroy(idx); + } + + unsigned long size() const { return vals.size(); } + + void reset() const { vals.reset(); } + + bool move_next() const { return vals.move_next(); } + + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& assignment::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + return vals.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& assignment::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return vals.element(); + } + + bool at_start() const { return vals.at_start(); } + + bool current_element_valid() const { return vals.current_element_valid(); } + + friend inline void serialize ( + const assignment& item, + std::ostream& out + ) + { + serialize(item.vals, out); + } + + friend inline void deserialize ( + assignment& item, + std::istream& in + ) + { + deserialize(item.vals, in); + } + + private: + mutable dlib::map::kernel_1b_c vals; + }; + + inline std::ostream& operator << ( + std::ostream& out, + const assignment& a + ) + { + a.reset(); + out << "("; + if (a.move_next()) + out << a.element().key() << ":" << a.element().value(); + + while (a.move_next()) + { + out << ", " << a.element().key() << ":" << a.element().value(); + } + + out << ")"; + return out; + } + + + inline void swap ( + assignment& a, + assignment& b + ) + { + a.swap(b); + } + + +// ------------------------------------------------------------------------ + + class joint_probability_table + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - size() == table.size() + - probability(a) == table[a] + !*/ + public: + + joint_probability_table ( + const joint_probability_table& t + ) + { + t.reset(); + while (t.move_next()) + { + assignment a = t.element().key(); + double p = t.element().value(); + set_probability(a,p); + } + } + + joint_probability_table() {} + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ) + { + if (this == &rhs) + return *this; + joint_probability_table(rhs).swap(*this); + return *this; + } + + void set_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::set_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] = p; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + bool has_entry_for ( + const assignment& a + ) const + { + return table.is_in_domain(a); + } + + void add_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::add_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] += p; + if (table[a] > 1.0) + table[a] = 1.0; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + double probability ( + const assignment& a + ) const + { + return table[a]; + } + + void clear() + { + table.clear(); + } + + size_t size () const { return table.size(); } + bool move_next() const { return table.move_next(); } + void reset() const { table.reset(); } + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& joint_probability_table::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& joint_probability_table::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + bool at_start() const { return table.at_start(); } + + bool current_element_valid() const { return table.current_element_valid(); } + + + template + void marginalize ( + const T& vars, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (vars.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void marginalize ( + const unsigned long var, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (var == asrc.element().key()) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void normalize ( + ) + { + double sum = 0; + + reset(); + while (move_next()) + sum += element().value(); + + reset(); + while (move_next()) + element().value() /= sum; + } + + void swap ( + joint_probability_table& item + ) + { + table.swap(item.table); + } + + friend inline void serialize ( + const joint_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + } + + friend inline void deserialize ( + joint_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + } + + private: + + dlib::map::kernel_1b_c table; + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then + - has_entry_for(value,ps) == true + - probability(value,ps) == table[ps](value) + - else + - has_entry_for(value,ps) == false + + - num_values() == num_vals + !*/ + public: + + conditional_probability_table() + { + clear(); + } + + void set_num_values ( + unsigned long num + ) + { + num_vals = num; + table.clear(); + } + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const + { + if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) + return true; + else + return false; + } + + unsigned long num_values ( + ) const { return num_vals; } + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 , + "\tvoid conditional_probability_table::set_probability()" + << "\n\tinvalid arguments to set_probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tp: " << p + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + if (table.is_in_domain(ps)) + { + table[ps](value) = p; + } + else + { + matrix dist(num_vals); + set_all_elements(dist,-1); + dist(value) = p; + assignment temp(ps); + table.add(temp,dist); + } + } + + double probability( + unsigned long value, + const assignment& ps + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) , + "\tvoid conditional_probability_table::probability()" + << "\n\tinvalid arguments to probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + return table[ps](value); + } + + void clear() + { + table.clear(); + num_vals = 0; + } + + void empty_table () + { + table.clear(); + } + + void swap ( + conditional_probability_table& item + ) + { + exchange(num_vals, item.num_vals); + table.swap(item.table); + } + + friend inline void serialize ( + const conditional_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + serialize(item.num_vals, out); + } + + friend inline void deserialize ( + conditional_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + deserialize(item.num_vals, in); + } + + private: + dlib::map >::kernel_1b_c table; + unsigned long num_vals; + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + public: + bayes_node () + { + is_instantiated = false; + value_ = 0; + } + + unsigned long value ( + ) const { return value_;} + + void set_value ( + unsigned long new_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( new_value < table().num_values(), + "\tvoid bayes_node::set_value(new_value)" + << "\n\tnew_value must be less than the number of possible values for this node" + << "\n\tnew_value: " << new_value + << "\n\ttable().num_values(): " << table().num_values() + << "\n\tthis: " << this + ); + + value_ = new_value; + } + + conditional_probability_table& table ( + ) { return table_; } + + const conditional_probability_table& table ( + ) const { return table_; } + + bool is_evidence ( + ) const { return is_instantiated; } + + void set_as_nonevidence ( + ) { is_instantiated = false; } + + void set_as_evidence ( + ) { is_instantiated = true; } + + void swap ( + bayes_node& item + ) + { + exchange(value_, item.value_); + exchange(is_instantiated, item.is_instantiated); + table_.swap(item.table_); + } + + friend inline void serialize ( + const bayes_node& item, + std::ostream& out + ) + { + serialize(item.value_, out); + serialize(item.is_instantiated, out); + serialize(item.table_, out); + } + + friend inline void deserialize ( + bayes_node& item, + std::istream& in + ) + { + deserialize(item.value_, in); + deserialize(item.is_instantiated, in); + deserialize(item.table_, in); + } + + private: + + unsigned long value_; + bool is_instantiated; + conditional_probability_table table_; + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + namespace bayes_node_utils + { + + template + unsigned long node_num_values ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::node_num_values(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.table().num_values(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_value(bn, n, val)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tval: " << val + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + bn.node(n).data.set_value(val); + } + + // ---------------------------------------------------------------------------------------- + template + unsigned long node_value ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tunsigned long bayes_node_utils::node_value(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.value(); + } + // ---------------------------------------------------------------------------------------- + + template + bool node_is_evidence ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_is_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.is_evidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_as_evidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_evidence(); + } + + // ---------------------------------------------------------------------------------------- + template + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_nonevidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.table().set_num_values(num); + } + + // ---------------------------------------------------------------------------------------- + + template + double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tparents.size(): " << parents.size() + << "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + return bn.node(n).data.table().probability(value, parents); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_probability ( + T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tparents.size(): " << parents.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + + DLIB_ASSERT( 0.0 <= p && p <= 1.0, + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + bn.node(n).data.table().set_probability(value,parents,p); + } + +// ---------------------------------------------------------------------------------------- + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + assignment a; + const unsigned long num_parents = bn.node(n).number_of_parents(); + for (unsigned long i = 0; i < num_parents; ++i) + { + a.add(bn.node(n).parent(i).index(), 0); + } + return a; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& a + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\ta.size(): " << a.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + a.reset(); + while (a.move_next()) + { + const unsigned long x = a.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( a[x] < node_num_values(bn,x), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\ta[x]: " << a[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + // basically this loop just adds 1 to the assignment but performs + // carries if necessary + for (unsigned long p = 0; p < a.size(); ++p) + { + const unsigned long pindex = bn.node(n).parent(p).index(); + a[pindex] += 1; + + // if we need to perform a carry + if (a[pindex] >= node_num_values(bn,pindex)) + { + a[pindex] = 0; + } + else + { + // no carry necessary so we are done + return true; + } + } + + // we got through the entire loop which means a carry propagated all the way out + // so there must not be any more valid assignments left + return false; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_cpt_filled_out(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + const unsigned long num_values = node_num_values(bn,n); + + + const conditional_probability_table& table = bn.node(n).data.table(); + + // now loop over all the possible parent assignments for this node + assignment a(node_first_parent_assignment(bn,n)); + do + { + double sum = 0; + // make sure that this assignment has an entry for all the values this node can take one + for (unsigned long value = 0; value < num_values; ++value) + { + if (table.has_entry_for(value,a) == false) + return false; + else + sum += table.probability(value,a); + } + + // check if the sum of probabilities equals 1 as it should + if (std::abs(sum-1.0) > 1e-5) + return false; + } while (node_next_parent_assignment(bn,n,a)); + + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + public: + + bayesian_network_gibbs_sampler () + { + rnd.set_seed(cast_to_string(std::time(0))); + } + + + template < + typename T + > + void sample_graph ( + T& bn + ) + { + using namespace bayes_node_utils; + for (unsigned long n = 0; n < bn.number_of_nodes(); ++n) + { + if (node_is_evidence(bn, n)) + continue; + + samples.set_size(node_num_values(bn,n)); + // obtain the probability distribution for this node + for (long i = 0; i < samples.nc(); ++i) + { + set_node_value(bn, n, i); + samples(i) = node_probability(bn, n); + + for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j) + samples(i) *= node_probability(bn, bn.node(n).child(j).index()); + } + + //normalize samples + samples /= sum(samples); + + + // select a random point in the probability distribution + double prob = rnd.get_random_double(); + + // now find the point in the distribution this probability corresponds to + long j; + for (j = 0; j < samples.nc()-1; ++j) + { + if (prob <= samples(j)) + break; + else + prob -= samples(j); + } + + set_node_value(bn, n, j); + } + } + + + private: + + template < + typename T + > + double node_probability ( + const T& bn, + unsigned long n + ) + /*! + requires + - n < bn.number_of_nodes() + ensures + - computes the probability of node n having its current value given + the current values of its parents in the network bn + !*/ + { + v.clear(); + for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i) + { + v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value()); + } + return bn.node(n).data.table().probability(bn.node(n).data.value(), v); + } + + assignment v; + + dlib::rand rnd; + matrix samples; + }; + +// ---------------------------------------------------------------------------------------- + + namespace bayesian_network_join_tree_helpers + { + class bnjt + { + /*! + this object is the base class used in this pimpl idiom + !*/ + public: + virtual ~bnjt() {} + + virtual const matrix probability( + unsigned long idx + ) const = 0; + }; + + template + class bnjt_impl : public bnjt + { + /*! + This object is the implementation in the pimpl idiom + !*/ + + public: + + bnjt_impl ( + const T& bn, + const U& join_tree + ) + { + create_bayesian_network_join_tree(bn, join_tree, join_tree_values); + + cliques.resize(bn.number_of_nodes()); + + // figure out which cliques contain each node + for (unsigned long i = 0; i < cliques.size(); ++i) + { + // find the smallest clique that contains node with index i + unsigned long smallest_clique = 0; + unsigned long size = std::numeric_limits::max(); + + for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n) + { + if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size) + { + size = join_tree.node(n).data.size(); + smallest_clique = n; + } + } + + cliques[i] = smallest_clique; + } + } + + virtual const matrix probability( + unsigned long idx + ) const + { + join_tree_values.node(cliques[idx]).data.marginalize(idx, table); + table.normalize(); + var.clear(); + var.add(idx); + dist.set_size(table.size()); + + // read the probabilities out of the table and into the row matrix + for (unsigned long i = 0; i < table.size(); ++i) + { + var[idx] = i; + dist(i) = table.probability(var); + } + + return dist; + } + + private: + + graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values; + array cliques; + mutable joint_probability_table table; + mutable assignment var; + mutable matrix dist; + + + // ---------------------------------------------------------------------------------------- + + template + bool set_contains_all_parents_of_node ( + const set_type& set, + const node_type& node + ) + { + for (unsigned long i = 0; i < node.number_of_parents(); ++i) + { + if (set.is_member(node.parent(i).index()) == false) + return false; + } + return true; + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void pass_join_tree_message ( + const U& join_tree, + V& bn_join_tree , + unsigned long from, + unsigned long to + ) + { + using namespace bayes_node_utils; + const typename U::edge_type& e = edge(join_tree, from, to); + typename V::edge_type& old_s = edge(bn_join_tree, from, to); + + typedef typename V::edge_type joint_prob_table; + + joint_prob_table new_s; + bn_join_tree.node(from).data.marginalize(e, new_s); + + joint_probability_table temp(new_s); + // divide new_s by old_s and store the result in temp. + // if old_s is empty then that is the same as if it was all 1s + // so we don't have to do this if that is the case. + if (old_s.size() > 0) + { + temp.reset(); + old_s.reset(); + while (temp.move_next()) + { + old_s.move_next(); + if (old_s.element().value() != 0) + temp.element().value() /= old_s.element().value(); + } + } + + // now multiply temp by d and store the results in d + joint_probability_table& d = bn_join_tree.node(to).data; + d.reset(); + while (d.move_next()) + { + assignment a; + const assignment& asrc = d.element().key(); + asrc.reset(); + while (asrc.move_next()) + { + if (e.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + d.element().value() *= temp.probability(a); + + } + + // store new_s in old_s + new_s.swap(old_s); + + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void create_bayesian_network_join_tree ( + const T& bn, + const U& join_tree, + V& bn_join_tree + ) + /*! + requires + - bn is a proper bayesian network + - join_tree is the join tree for that bayesian network + ensures + - bn_join_tree == the output of the join tree algorithm for bayesian network inference. + So each node in this graph contains a joint_probability_table for the clique + in the corresponding node in the join_tree graph. + !*/ + { + using namespace bayes_node_utils; + bn_join_tree.clear(); + copy_graph_structure(join_tree, bn_join_tree); + + // we need to keep track of which node is "in" each clique for the purposes of + // initializing the tables in each clique. So this vector will be used to do that + // and a value of join_tree.number_of_nodes() means that the node with + // that index is unassigned. + std::vector node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes()); + + // populate evidence with all the evidence node indices and their values + dlib::map::kernel_1b_c evidence; + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + if (node_is_evidence(bn, i)) + { + unsigned long idx = i; + unsigned long value = node_value(bn, i); + evidence.add(idx,value); + } + } + + + // initialize the bn join tree + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + bool contains_evidence = false; + std::vector indices; + assignment value; + + // loop over all the nodes in this clique in the join tree. In this loop + // we are making an assignment with all the values of the nodes it represents set to 0 + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + indices.push_back(idx); + value.add(idx); + + if (evidence.is_in_domain(join_tree.node(i).data.element())) + contains_evidence = true; + } + + // now loop over all possible combinations of values that the nodes this + // clique in the join tree can take on. We do this by counting by one through all + // legal values + bool more_assignments = true; + while (more_assignments) + { + bn_join_tree.node(i).data.set_probability(value,1); + + // account for any evidence + if (contains_evidence) + { + // loop over all the nodes in this cluster + for (unsigned long j = 0; j < indices.size(); ++j) + { + // if the current node is an evidence node + if (evidence.is_in_domain(indices[j])) + { + const unsigned long idx = indices[j]; + const unsigned long evidence_value = evidence[idx]; + if (value[idx] != evidence_value) + bn_join_tree.node(i).data.set_probability(value , 0); + } + } + } + + + // now check if any of the nodes in this cluster also have their parents in this cluster + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + // if this clique contains all the parents of this node and also hasn't + // been assigned to another clique + if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) && + (i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) ) + { + // note that this node is now assigned to this clique + node_assigned_to[idx] = i; + // node idx has all its parents in the cluster + assignment parent_values; + for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j) + { + const unsigned long pidx = bn.node(idx).parent(j).index(); + parent_values.add(pidx, value[pidx]); + } + + double temp = bn_join_tree.node(i).data.probability(value); + bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values)); + + } + } + + + // now advance the value variable to its next possible state if there is one + more_assignments = false; + value.reset(); + while (value.move_next()) + { + value.element().value() += 1; + // if overflow + if (value.element().value() == node_num_values(bn, value.element().key())) + { + value.element().value() = 0; + } + else + { + more_assignments = true; + break; + } + } + + } // end while (more_assignments) + } + + + + + // the tree is now initialized. Now all we need to do is perform the propagation and + // we are done + dlib::array::compare_1b_c> remaining_msg_to_send; + dlib::array::compare_1b_c> remaining_msg_to_receive; + remaining_msg_to_receive.resize(join_tree.number_of_nodes()); + remaining_msg_to_send.resize(join_tree.number_of_nodes()); + for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i) + { + for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) + { + const unsigned long idx = join_tree.node(i).neighbor(j).index(); + unsigned long temp; + temp = idx; remaining_msg_to_receive[i].add(temp); + temp = idx; remaining_msg_to_send[i].add(temp); + } + } + + // now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received + // a message from. + // we will consider node 0 to be the root node. + + + bool message_sent = true; + while (message_sent) + { + message_sent = false; + for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i) + { + // if node i hasn't sent any messages but has received all but one then send a message to the one + // node who hasn't sent i a message + if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1) + { + unsigned long to; + // get the last remaining thing from this set + remaining_msg_to_receive[i].remove_any(to); + + // send the message + pass_join_tree_message(join_tree, bn_join_tree, i, to); + + // record that we sent this message + remaining_msg_to_send[i].destroy(to); + remaining_msg_to_receive[to].destroy(i); + + // put to back in since we still need to receive it + remaining_msg_to_receive[i].add(to); + message_sent = true; + } + else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0) + { + unsigned long to; + remaining_msg_to_send[i].remove_any(to); + remaining_msg_to_receive[to].destroy(i); + pass_join_tree_message(join_tree, bn_join_tree, i, to); + message_sent = true; + } + } + + if (remaining_msg_to_receive[0].size() == 0) + { + // send a message to all of the root nodes neighbors unless we have already sent out he messages + while (remaining_msg_to_send[0].size() > 0) + { + unsigned long to; + remaining_msg_to_send[0].remove_any(to); + remaining_msg_to_receive[to].destroy(0); + pass_join_tree_message(join_tree, bn_join_tree, 0, to); + message_sent = true; + } + } + + + } + + } + + }; + } + + class bayesian_network_join_tree : noncopyable + { + /*! + use the pimpl idiom to push the template arguments from the class level to the + constructor level + !*/ + + public: + + template < + typename T, + typename U + > + bayesian_network_join_tree ( + const T& bn, + const U& join_tree + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( bn.number_of_nodes() > 0 , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + + DLIB_ASSERT( is_join_tree(bn, join_tree) == true , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid join tree for the supplied bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_is_connected(bn) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network. " + << "\n\tYou must finish filling out the conditional_probability_table of node " << i + << "\n\tthis: " << this + ); + } +#endif + + impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl(bn, join_tree)); + num_nodes = bn.number_of_nodes(); + } + + const matrix probability( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( idx < number_of_nodes() , + "\tconst matrix bayesian_network_join_tree::probability(idx)" + << "\n\tYou have specified an invalid node index" + << "\n\tidx: " << idx + << "\n\tnumber_of_nodes(): " << number_of_nodes() + << "\n\tthis: " << this + ); + + return impl->probability(idx); + } + + unsigned long number_of_nodes ( + ) const { return num_nodes; } + + void swap ( + bayesian_network_join_tree& item + ) + { + exchange(num_nodes, item.num_nodes); + impl.swap(item.impl); + } + + private: + + std::unique_ptr impl; + unsigned long num_nodes; + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_BAYES_UTILs_ + diff --git a/dlib/bayes_utils/bayes_utils_abstract.h b/dlib/bayes_utils/bayes_utils_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b19e6e1da1844ccc9cf09a6139866822d1971a0f --- /dev/null +++ b/dlib/bayes_utils/bayes_utils_abstract.h @@ -0,0 +1,1042 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BAYES_UTILs_ABSTRACT_ +#ifdef DLIB_BAYES_UTILs_ABSTRACT_ + +#include "../algs.h" +#include "../noncopyable.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../serialize.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the assignment in + ascending order according to index values. (i.e. the elements are + enumerated in sorted order according to the value of their keys) + + WHAT THIS OBJECT REPRESENTS + This object models an assignment of random variables to particular values. + It is used with the joint_probability_table and conditional_probability_table + objects to represent assignments of various random variables to actual values. + + So for example, if you had a joint_probability_table that represented the + following table: + P(A = 0, B = 0) = 0.2 + P(A = 0, B = 1) = 0.3 + P(A = 1, B = 0) = 0.1 + P(A = 1, B = 1) = 0.4 + + Also lets define an enum so we have concrete index numbers for A and B + enum { A = 0, B = 1}; + + Then you could query the value of P(A=1, B=0) as follows: + assignment a; + a.set(A, 1); + a.set(B, 0); + // and now it is the case that: + table.probability(a) == 0.1 + a[A] == 1 + a[B] == 0 + + + Also note that when enumerating the elements of an assignment object + the key() refers to the index and the value() refers to the value at that + index. For example: + + // assume a is an assignment object + a.reset(); + while (a.move_next()) + { + // in this loop it is always the case that: + // a[a.element().key()] == a.element().value() + } + !*/ + + public: + + assignment( + ); + /*! + ensures + - this object is properly initialized + !*/ + + assignment( + const assignment& a + ); + /*! + ensures + - #*this is a copy of a + !*/ + + assignment& operator = ( + const assignment& rhs + ); + /*! + ensures + - #*this is a copy of rhs + - returns *this + !*/ + + void clear( + ); + /*! + ensures + - this object has been returned to its initial value + !*/ + + bool operator < ( + const assignment& item + ) const; + /*! + ensures + - The exact functioning of this operator is undefined. The only guarantee + is that it establishes a total ordering on all possible assignment objects. + In other words, this operator makes it so that you can use assignment + objects in the associative containers but otherwise isn't of any + particular use. + !*/ + + bool has_index ( + unsigned long idx + ) const; + /*! + ensures + - if (this assignment object has an entry for index idx) then + - returns true + - else + - returns false + !*/ + + void add ( + unsigned long idx, + unsigned long value = 0 + ); + /*! + requires + - has_index(idx) == false + ensures + - #has_index(idx) == true + - #(*this)[idx] == value + !*/ + + void remove ( + unsigned long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - #has_index(idx) == false + !*/ + + unsigned long& operator[] ( + const long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - returns a reference to the value associated with index idx + !*/ + + const unsigned long& operator[] ( + const long idx + ) const; + /*! + requires + - has_index(idx) == true + ensures + - returns a const reference to the value associated with index idx + !*/ + + void swap ( + assignment& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + assignment& a, + assignment& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + std::ostream& operator << ( + std::ostream& out, + const assignment& a + ); + /*! + ensures + - writes a to the given output stream in the following format: + (index1:value1, index2:value2, ..., indexN:valueN) + !*/ + + void serialize ( + const assignment& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + assignment& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ + + class joint_probability_table : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the probability table + in no particular order but they will all be visited. + + WHAT THIS OBJECT REPRESENTS + This object models a joint probability table. That is, it models + the function p(X). So this object models the probability of a particular + set of variables (referred to as X). + !*/ + + public: + + joint_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + joint_probability_table ( + const joint_probability_table& t + ); + /*! + ensures + - this object is a copy of t + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ); + /*! + ensures + - this object is a copy of rhs + - returns a reference to *this + !*/ + + bool has_entry_for ( + const assignment& a + ) const; + /*! + ensures + - if (this joint_probability_table has an entry for p(X = a)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - #has_entry_for(a) == true + !*/ + + void add_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - else + - #probability(a) == min(probability(a) + p, 1.0) + (i.e. does a saturating add) + - #has_entry_for(a) == true + !*/ + + const double probability ( + const assignment& a + ) const; + /*! + ensures + - returns the probability p(X == a) + !*/ + + template < + typename T + > + void marginalize ( + const T& vars, + joint_probability_table& output_table + ) const; + /*! + requires + - T is an implementation of set/set_kernel_abstract.h + ensures + - marginalizes *this by summing over all variables not in vars. The + result is stored in output_table. + !*/ + + void marginalize ( + const unsigned long var, + joint_probability_table& output_table + ) const; + /*! + ensures + - is identical to calling the above marginalize() function with a set + that contains only var. Or in other words, performs a marginalization + with just one variable var. So that output_table will contain a table giving + the marginal probability of var all by itself. + !*/ + + void normalize ( + ); + /*! + ensures + - let sum == the sum of all the probabilities in this table + - after normalize() has finished it will be the case that the sum of all + the entries in this table is 1.0. This is accomplished by dividing all + the entries by the sum described above. + !*/ + + void swap ( + joint_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const joint_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + joint_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - num_values() == 0 + - has_value_for(x, y) == false for all values of x and y + + WHAT THIS OBJECT REPRESENTS + This object models a conditional probability table. That is, it models + the function p( X | parents). So this object models the conditional + probability of a particular variable (referred to as X) given another set + of variables (referred to as parents). + !*/ + + public: + + conditional_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + void empty_table ( + ); + /*! + ensures + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it but doesn't + change the value of num_values()) + !*/ + + void set_num_values ( + unsigned long num + ); + /*! + ensures + - #num_values() == num + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it) + !*/ + + unsigned long num_values ( + ) const; + /*! + ensures + - This object models the probability table p(X | parents). This + function returns the number of values X can take on. + !*/ + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const; + /*! + ensures + - if (this conditional_probability_table has an entry for p(X = value, parents = ps)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ); + /*! + requires + - value < num_values() + - 0 <= p <= 1 + ensures + - #probability(ps, value) == p + - #has_entry_for(value, ps) == true + !*/ + + double probability( + unsigned long value, + const assignment& ps + ) const; + /*! + requires + - value < num_values() + - has_entry_for(value, ps) == true + ensures + - returns the probability p( X = value | parents = ps). + !*/ + + void swap ( + conditional_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const conditional_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + conditional_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + /*! + INITIAL VALUE + - is_evidence() == false + - value() == 0 + - table().num_values() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a node in a bayesian network. It is + intended to be used inside the dlib::directed_graph object to + represent bayesian networks. + !*/ + + public: + bayes_node ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + unsigned long value ( + ) const; + /*! + ensures + - returns the current value of this node + !*/ + + void set_value ( + unsigned long new_value + ); + /*! + requires + - new_value < table().num_values() + ensures + - #value() == new_value + !*/ + + conditional_probability_table& table ( + ); + /*! + ensures + - returns a reference to the conditional_probability_table associated with this node + !*/ + + const conditional_probability_table& table ( + ) const; + /*! + ensures + - returns a const reference to the conditional_probability_table associated with this + node. + !*/ + + bool is_evidence ( + ) const; + /*! + ensures + - if (this is an evidence node) then + - returns true + - else + - returns false + !*/ + + void set_as_nonevidence ( + ); + /*! + ensures + - #is_evidence() == false + !*/ + + void set_as_evidence ( + ); + /*! + ensures + - #is_evidence() == true + !*/ + + void swap ( + bayes_node& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const bayes_node& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bayes_node& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + /* + The following group of functions are convenience functions for manipulating + bayes_node objects while they are inside a directed_graph. These functions + also have additional requires clauses that, in debug mode, will protect you + from attempts to manipulate a bayesian network in an inappropriate way. + */ + + namespace bayes_node_utils + { + + template < + typename T + > + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - val < node_num_values(bn, n) + ensures + - #bn.node(n).data.value() = val + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_value ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.value() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + bool node_is_evidence ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.is_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_evidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_nonevidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - #bn.node(n).data.table().num_values() == num + (i.e. sets the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_num_values ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.table().num_values() + (i.e. returns the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - returns bn.node(n).data.table().probability(value, parents) + (i.e. returns the probability of node n having the given value when + its parents have the given assignment) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double set_node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - 0 <= p <= 1 + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - #bn.node(n).data.table().probability(value, parents) == p + (i.e. sets the probability of node n having the given value when + its parents have the given assignment to the probability p) + !*/ + + // ------------------------------------------------------------------------------------ + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns an assignment A such that: + - A.size() == bn.node(n).number_of_parents() + - if (P is a parent of bn.node(n)) then + - A.has_index(P) + - A[P] == 0 + - I.e. this function returns an assignment that contains all + the parents of the given node. Also, all the values of each + parent in the assignment is set to zero. + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& A + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - A.size() == bn.node(n).number_of_parents() + - if (A.has_index(x)) then + - bn.has_edge(x, n) + - A[x] < node_num_values(bn,x) + ensures + - The behavior of this function is defined by the following code: + assignment a(node_first_parent_assignment(bn,n); + do { + // this loop loops over all possible parent assignments + // of the node bn.node(n). Each time through the loop variable a + // will be the next assignment. + } while (node_next_parent_assignment(bn,n,a)) + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - if (the conditional_probability_table bn.node(n).data.table() is + fully filled out for this node) then + - returns true + - This means that each parent assignment for the given node + along with all possible values of this node shows up in the + table. + - It also means that all the probabilities conditioned on the + same parent assignment sum to 1.0 + - else + - returns false + !*/ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + /*! + INITIAL VALUE + This object has no state + + WHAT THIS OBJECT REPRESENTS + This object performs Markov Chain Monte Carlo sampling of a bayesian + network using the Gibbs sampling technique. + + Note that this object is limited to only bayesian networks that + don't contain deterministic nodes. That is, incorrect results may + be computed if this object is used when the bayesian network contains + any nodes that have a probability of 1 in their conditional probability + tables for any event. So don't use this object for networks with + deterministic nodes. + !*/ + public: + + bayesian_network_gibbs_sampler ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename T + > + void sample_graph ( + T& bn + ) + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + ensures + - modifies randomly (via the Gibbs sampling technique) samples all the nodes + in the network and updates their values with the newly sampled values + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_join_tree : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of the join tree algorithm + for inference in bayesian networks. It doesn't have any mutable state. + To you use you just give it a directed_graph that contains a bayesian + network and a graph object that contains that networks corresponding + join tree. Then you may query this object to determine the probabilities + of any variables in the original bayesian network. + !*/ + + public: + + template < + typename bn_type, + typename join_tree_type + > + bayesian_network_join_tree ( + const bn_type& bn, + const join_tree_type& join_tree + ); + /*! + requires + - bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h + - bn_type::type == bayes_node + - join_tree_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type::type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - is_join_tree(bn, join_tree) == true + - bn == a valid bayesian network with all its conditional probability tables + filled out + - for all valid n: + - node_cpt_filled_out(bn,n) == true + - graph_contains_length_one_cycle(bn) == false + - graph_is_connected(bn) == true + - bn.number_of_nodes() > 0 + ensures + - this object is properly initialized + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the bayesian network that this + object was instantiated from. + !*/ + + const matrix probability( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the probability distribution for the node with index idx that was in the bayesian + network that *this was instantiated from. Let D represent this distribution, then: + - D.nc() == the number of values the node idx ranges over + - D.nr() == 1 + - D(i) == the probability of node idx taking on the value i + !*/ + + void swap ( + bayesian_network_join_tree& item + ); + /*! + ensures + - swaps *this with item + !*/ + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BAYES_UTILs_ABSTRACT_ + + diff --git a/dlib/bigint.h b/dlib/bigint.h new file mode 100644 index 0000000000000000000000000000000000000000..73496689ac9129e51a3f1ec0268cda4a342523c1 --- /dev/null +++ b/dlib/bigint.h @@ -0,0 +1,43 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINt_ +#define DLIB_BIGINt_ + +#include "bigint/bigint_kernel_1.h" +#include "bigint/bigint_kernel_2.h" +#include "bigint/bigint_kernel_c.h" + + + + +namespace dlib +{ + + + class bigint + { + bigint() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bigint_kernel_1 + kernel_1a; + typedef bigint_kernel_c + kernel_1a_c; + + // kernel_2a + typedef bigint_kernel_2 + kernel_2a; + typedef bigint_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_BIGINt_ + diff --git a/dlib/bigint/bigint_kernel_1.cpp b/dlib/bigint/bigint_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..feef761c227996c1dd04e90087c1c04a86c65639 --- /dev/null +++ b/dlib/bigint/bigint_kernel_1.cpp @@ -0,0 +1,1720 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_1_CPp_ +#define DLIB_BIGINT_KERNEL_1_CPp_ +#include "bigint_kernel_1.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + const bigint_kernel_1& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + ~bigint_kernel_1 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator+ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator- ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator* ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + const bigint_kernel_1& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator/ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + const bigint_kernel_1& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator% ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_1(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + const bigint_kernel_1& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator < ( + const bigint_kernel_1& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator == ( + const bigint_kernel_1& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + const bigint_kernel_1& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_1& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_1 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_1& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_1 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + // this is where we actually copy lhs and rhs + data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_1_CPp_ + diff --git a/dlib/bigint/bigint_kernel_1.h b/dlib/bigint/bigint_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..32463c3b26f017d9dabdada92290945251fe2434 --- /dev/null +++ b/dlib/bigint/bigint_kernel_1.h @@ -0,0 +1,543 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_1_ +#define DLIB_BIGINT_KERNEl_1_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include + +namespace dlib +{ + + + class bigint_kernel_1 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_1 objects which refer + to this data_record + + + + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_1(uint32) + explicit bigint_kernel_1 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + + public: + + bigint_kernel_1 ( + ); + + bigint_kernel_1 ( + uint32 value + ); + + bigint_kernel_1 ( + const bigint_kernel_1& item + ); + + virtual ~bigint_kernel_1 ( + ); + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator+= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator-= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator*= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator/= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator%= ( + const bigint_kernel_1& rhs + ); + + bool operator < ( + const bigint_kernel_1& rhs + ) const; + + bool operator == ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator= ( + const bigint_kernel_1& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_1& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator++ ( + ); + + const bigint_kernel_1 operator++ ( + int + ); + + bigint_kernel_1& operator-- ( + ); + + const bigint_kernel_1 operator-- ( + int + ); + + friend const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_1& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_1& a, + bigint_kernel_1& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_1& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_1& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; } + inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); } +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_1.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_1_ + diff --git a/dlib/bigint/bigint_kernel_2.cpp b/dlib/bigint/bigint_kernel_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..694b1ae1f752146f8539bb5f9e0735e2c5e701d3 --- /dev/null +++ b/dlib/bigint/bigint_kernel_2.cpp @@ -0,0 +1,1945 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_2_CPp_ +#define DLIB_BIGINT_KERNEL_2_CPp_ +#include "bigint_kernel_2.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + const bigint_kernel_2& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + ~bigint_kernel_2 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator+ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator- ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator* ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + const bigint_kernel_2& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator/ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + const bigint_kernel_2& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator% ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_2(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + const bigint_kernel_2& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator < ( + const bigint_kernel_2& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator == ( + const bigint_kernel_2& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + const bigint_kernel_2& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_2& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_2 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_2& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_2 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // if one of the numbers is small then use this simple but O(n^2) algorithm + if (std::min(lhs->digits_used, rhs->digits_used) < 10) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + + // copy the larger(approximately) of lhs and rhs into b + data_record b(*bb,aa->digits_used+slack); + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + else // else if both lhs and rhs are large then use the more complex + // O(n*logn) algorithm + { + uint32 size = 1; + // make size a power of 2 + while (size < (lhs->digits_used + rhs->digits_used)*2) + { + size *= 2; + } + + // allocate some temporary space so we can do the FFT + ct* a = new ct[size]; + ct* b; try {b = new ct[size]; } catch (...) { delete [] a; throw; } + + // load lhs into the a array. We are breaking the input number into + // 8bit chunks for the purpose of using this fft algorithm. The reason + // for this is so that we have smaller numbers coming out of the final + // ifft. This helps avoid overflow. + for (uint32 i = 0; i < lhs->digits_used; ++i) + { + a[i*2] = ct((t)(lhs->number[i]&0xFF),0); + a[i*2+1] = ct((t)(lhs->number[i]>>8),0); + } + for (uint32 i = lhs->digits_used*2; i < size; ++i) + { + a[i] = 0; + } + + // load rhs into the b array + for (uint32 i = 0; i < rhs->digits_used; ++i) + { + b[i*2] = ct((t)(rhs->number[i]&0xFF),0); + b[i*2+1] = ct((t)(rhs->number[i]>>8),0); + } + for (uint32 i = rhs->digits_used*2; i < size; ++i) + { + b[i] = 0; + } + + // perform the forward fft of a and b + fft(a,size); + fft(b,size); + + const double l = 1.0/size; + + // do the pointwise multiply of a and b and also apply the scale + // factor in this loop too. + for (unsigned long i = 0; i < size; ++i) + { + a[i] = l*a[i]*b[i]; + } + + // Now compute the inverse fft of the pointwise multiplication of a and b. + // This is basically the result. We just have to take care of any carries + // that should happen. + ifft(a,size); + + // loop over the result and propagate any carries that need to take place. + // We will also be moving the resulting numbers into result->number at + // the same time. + uint64 carry = 0; + result->digits_used = 0; + int zeros = 0; + const uint32 len = lhs->digits_used + rhs->digits_used; + for (unsigned long i = 0; i < len; ++i) + { + uint64 num1 = static_cast(std::round(a[i*2].real())); + num1 += carry; + carry = 0; + if (num1 > 255) + { + carry = num1 >> 8; + num1 = (num1&0xFF); + } + + uint64 num2 = static_cast(std::round(a[i*2+1].real())); + num2 += carry; + carry = 0; + if (num2 > 255) + { + carry = num2 >> 8; + num2 = (num2&0xFF); + } + + // put the new number into its final place + num1 = (num2<<8) | num1; + result->number[i] = static_cast(num1); + + // keep track of the number of leading zeros + if (num1 == 0) + ++zeros; + else + zeros = 0; + ++(result->digits_used); + } + + // adjust digits_used so that it reflects the actual number + // of non-zero digits in our representation. + result->digits_used -= zeros; + + // if the result was zero then adjust the result accordingly + if (result->digits_used == 0) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + } + + // free all the temporary buffers + delete [] a; + delete [] b; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + fft ( + ct* data, + unsigned long len + ) const + { + const t pi2 = -2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = 1; + for (unsigned long step = half; step != 0; step >>= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + a = data[a_idx] + data[b_idx]; + b = (data[a_idx] - data[b_idx])*twiddle_factors[k*skip]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip *= 2; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + ifft( + ct* data, + unsigned long len + ) const + { + const t pi2 = 2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the inverse decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = half; + for (unsigned long step = 1; step <= half; step <<= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + data[b_idx] *= twiddle_factors[k*skip]; + a = data[a_idx] + data[b_idx]; + b = data[a_idx] - data[b_idx]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip /= 2; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_2_CPp_ + diff --git a/dlib/bigint/bigint_kernel_2.h b/dlib/bigint/bigint_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..4c771cf372896aac8113a914f8d4627c6a6f5d98 --- /dev/null +++ b/dlib/bigint/bigint_kernel_2.h @@ -0,0 +1,569 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_2_ +#define DLIB_BIGINT_KERNEl_2_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include +#include +#include +#include + +namespace dlib +{ + + class bigint_kernel_2 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_2 objects which refer + to this data_record + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_2(uint32) + explicit bigint_kernel_2 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + public: + + bigint_kernel_2 ( + ); + + bigint_kernel_2 ( + uint32 value + ); + + bigint_kernel_2 ( + const bigint_kernel_2& item + ); + + virtual ~bigint_kernel_2 ( + ); + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator+= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator-= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator*= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator/= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator%= ( + const bigint_kernel_2& rhs + ); + + bool operator < ( + const bigint_kernel_2& rhs + ) const; + + bool operator == ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator= ( + const bigint_kernel_2& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_2& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator++ ( + ); + + const bigint_kernel_2 operator++ ( + int + ); + + bigint_kernel_2& operator-- ( + ); + + const bigint_kernel_2 operator-- ( + int + ); + + friend const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_2& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + typedef double t; + typedef std::complex ct; + + void fft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the FT decimation in frequency of data + !*/ + + void ifft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the inverse decimation in frequency of data. + (i.e. the inverse of what fft(data,len,-1) does to data) + !*/ + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_2& a, + bigint_kernel_2& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_2& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_2& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; } + inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); } + +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_2.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_2_ + diff --git a/dlib/bigint/bigint_kernel_abstract.h b/dlib/bigint/bigint_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..644129c4be99b9bb82164aff1cf6870dc8674080 --- /dev/null +++ b/dlib/bigint/bigint_kernel_abstract.h @@ -0,0 +1,670 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIGINT_KERNEl_ABSTRACT_ +#ifdef DLIB_BIGINT_KERNEl_ABSTRACT_ + +#include +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" + +namespace dlib +{ + + class bigint + { + /*! + INITIAL VALUE + *this == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents an arbitrary precision unsigned integer + + the following operators are supported: + operator + + operator += + operator - + operator -= + operator * + operator *= + operator / + operator /= + operator % + operator %= + operator == + operator < + operator = + operator << (for writing to ostreams) + operator >> (for reading from istreams) + operator++ // pre increment + operator++(int) // post increment + operator-- // pre decrement + operator--(int) // post decrement + + + the other comparison operators(>, !=, <=, and >=) are + available and come from the templates in dlib::relational_operators + + THREAD SAFETY + bigint may be reference counted so it is very unthread safe. + use with care in a multithreaded program + + !*/ + + public: + + bigint ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + uint32 value + ); + /*! + requires + - value <= (2^32)-1 + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + const bigint& item + ); + /*! + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + virtual ~bigint ( + ); + /*! + ensures + - all resources associated with #*this have been released + !*/ + + const bigint operator+ ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of adding rhs to *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this + rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator- ( + const bigint& rhs + ) const; + /*! + requires + - *this >= rhs + ensures + - returns the result of subtracting rhs from *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + const bigint& rhs + ); + /*! + requires + - *this >= rhs + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator* ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of multiplying *this and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator/ ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of dividing *this by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator% ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of *this mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bool operator < ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this is less than rhs + - returns false otherwise + !*/ + + bool operator == ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + const bigint& rhs + ); + /*! + ensures + - #*this == rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint& rhs + ); + /*! + ensures + - the number in *this has been written to #out as a base ten number + throws + - std::bad_alloc + if this function throws then it has no effect (nothing + is written to out) + !*/ + + friend std::istream& operator>> ( + std::istream& in, + bigint& rhs + ); + /*! + ensures + - reads a number from in and puts it into #*this + - if (there is no positive base ten number on the input stream ) then + - #in.fail() == true + throws + - std::bad_alloc + if this function throws the value in rhs is undefined and some + characters may have been read from in. rhs is still usable though, + its value is just unknown. + !*/ + + + bigint& operator++ ( + ); + /*! + ensures + - #*this == *this + 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator++ ( + int + ); + /*! + ensures + - #*this == *this + 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-- ( + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator-- ( + int + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + void swap ( + bigint& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + // ------------------------------------------------------------------ + // ---- The following functions are identical to the above ----- + // ---- but take uint16 as one of their arguments. They --- + // ---- exist only to allow for a more efficient implementation --- + // ------------------------------------------------------------------ + + + friend const bigint operator+ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator+ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this + rhs + - returns #this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs >= rhs + - lhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - lhs >= rhs + - rhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + uint16 rhs + ); + /*! + requires + - *this >= rhs + - rhs <= 65535 + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend bool operator < ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator < ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator == ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + friend bool operator == ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + }; + + inline void swap ( + bigint& a, + bigint& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + void serialize ( + const bigint& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bigint& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + inline bool operator> (const bigint& a, const bigint& b) { return b < a; } + inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); } + inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); } + inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); } +} + +#endif // DLIB_BIGINT_KERNEl_ABSTRACT_ + diff --git a/dlib/bigint/bigint_kernel_c.h b/dlib/bigint/bigint_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..47089730af5472af32ddc40f43a8ff9df3c8033d --- /dev/null +++ b/dlib/bigint/bigint_kernel_c.h @@ -0,0 +1,1140 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_C_ +#define DLIB_BIGINT_KERNEl_C_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + class bigint_kernel_c + { + bigint_base data; + + explicit bigint_kernel_c ( + const bigint_base& item + ) : data(item) {} + + public: + + + bigint_kernel_c ( + ); + + bigint_kernel_c ( + uint32 value + ); + + bigint_kernel_c ( + const bigint_kernel_c& item + ); + + ~bigint_kernel_c ( + ); + + const bigint_kernel_c operator+ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator+= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator- ( + const bigint_kernel_c& rhs + ) const; + bigint_kernel_c& operator-= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator* ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator*= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator/ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator/= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator% ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator%= ( + const bigint_kernel_c& rhs + ); + + bool operator < ( + const bigint_kernel_c& rhs + ) const; + + bool operator == ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator= ( + const bigint_kernel_c& rhs + ); + + template + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ); + + template + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator++ ( + ); + + const bigint_kernel_c operator++ ( + int + ); + + bigint_kernel_c& operator-- ( + ); + + const bigint_kernel_c operator-- ( + int + ); + + template + friend const bigint_kernel_c operator+ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator+= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator- ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator-= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator* ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator*= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator/ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator/= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator%= ( + uint16 rhs + ); + + template + friend bool operator < ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend bool operator < ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_c& item + ) { data.swap(item.data); } + + }; + + template < + typename bigint_base + > + void swap ( + bigint_kernel_c& a, + bigint_kernel_c& b + ) { a.swap(b); } + + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + inline void serialize ( + const bigint_kernel_c& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + template < + typename bigint_base + > + inline void deserialize ( + bigint_kernel_c& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + uint32 value + ) : + data(value) + { + // make sure requires clause is not broken + DLIB_CASSERT( value <= 0xFFFFFFFF , + "\tbigint::bigint(uint16)" + << "\n\t value must be <= (2^32)-1" + << "\n\tthis: " << this + << "\n\tvalue: " << value + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + const bigint_kernel_c& item + ) : + data(item.data) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + ~bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator+ ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data + rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + const bigint_kernel_c& rhs + ) + { + data += rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator- ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tconst bigint bigint::operator-(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(data-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tbigint& bigint::operator-=(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator* ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data * rhs.data ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + const bigint_kernel_c& rhs + ) + { + data *= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator/ ( + const bigint_kernel_c& rhs + ) const + { + //make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator/(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator/=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data /= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator% ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator%(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator%=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data %= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator < ( + const bigint_kernel_c& rhs + ) const + { + return data < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator == ( + const bigint_kernel_c& rhs + ) const + { + return data == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + const bigint_kernel_c& rhs + ) + { + data = rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ) + { + out << rhs.data; + return out; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ) + { + in >> rhs.data; + return in; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator++ ( + ) + { + ++data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator++ ( + int + ) + { + return bigint_kernel_c(data++); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-- ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tbigint& bigint::operator--()" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + --data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator-- ( + int + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tconst bigint bigint::operator--(int)" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data--); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator+(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(static_cast(lhs)+rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator+(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data+static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint& bigint::operator+=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data += rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(static_cast(lhs) < rhs) && lhs <= 65535, + "\tconst bigint operator-(uint16,const bigint&)" + << "\n\t lhs must be greater than or equal to rhs and lhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(static_cast(lhs)-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(lhs < static_cast(rhs)) && rhs <= 65535, + "\tconst bigint operator-(const bigint&,uint16)" + << "\n\t lhs must be greater than or equal to rhs and rhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data-static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < static_cast(rhs)) && rhs <= 65535, + "\tbigint& bigint::operator-=(uint16)" + << "\n\t *this must not be less than rhs and rhs <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator*(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs*rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator*(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data*rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\t bigint bigint::operator*=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data *= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && lhs <= 65535, + "\tconst bigint operator/(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator/(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data/static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + uint16 rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(rhs) <= 65535, + "\tbigint& bigint::operator/=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data /= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(lhs) <= 65535, + "\tconst bigint operator%(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator%(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data%static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + uint16 r + ) + { + + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tbigint& bigint::operator%=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data %= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator<(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator<(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data < static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator==(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data == static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator==(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint bigint::operator=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\t*this: " << *this + << "\n\tthis: " << this + << "\n\tlhs: " << rhs + ); + + data = static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < typename bigint_base > + inline bool operator> (const bigint_kernel_c& a, const bigint_kernel_c& b) { return b < a; } + template < typename bigint_base > + inline bool operator!= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a == b); } + template < typename bigint_base > + inline bool operator<= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(b < a); } + template < typename bigint_base > + inline bool operator>= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a < b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIGINT_KERNEl_C_ + diff --git a/dlib/binary_search_tree.h b/dlib/binary_search_tree.h new file mode 100644 index 0000000000000000000000000000000000000000..5273e8ce9230544f31d29c913f4b75518474a525 --- /dev/null +++ b/dlib/binary_search_tree.h @@ -0,0 +1,50 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREe_ +#define DLIB_BINARY_SEARCH_TREe_ + + +#include "binary_search_tree/binary_search_tree_kernel_1.h" +#include "binary_search_tree/binary_search_tree_kernel_2.h" +#include "binary_search_tree/binary_search_tree_kernel_c.h" + + +#include "algs.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree + { + binary_search_tree() {} + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef binary_search_tree_kernel_1 + kernel_1a; + typedef binary_search_tree_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef binary_search_tree_kernel_2 + kernel_2a; + typedef binary_search_tree_kernel_c + kernel_2a_c; + + }; +} + +#endif // DLIB_BINARY_SEARCH_TREe_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_1.h b/dlib/binary_search_tree/binary_search_tree_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..f0a0cfbd820b93431dd092d5d4667ffe306509ed --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_1.h @@ -0,0 +1,2064 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_1_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_1 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + tree_size == 0 + tree_root == 0 + tree_height == 0 + at_start_ == true + current_element == 0 + stack == array of 50 node pointers + stack_pos == 0 + + + CONVENTION + tree_size == size() + tree_height == height() + + stack[stack_pos-1] == pop() + + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + if (current_element != 0 && current_element != tree_root) then + stack[stack_pos-1] == the parent of the node pointed to by current_element + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == 0 + + + for all nodes: + { + left points to the left subtree or 0 if there is no left subtree and + right points to the right subtree or 0 if there is no right subtree and + all elements in a left subtree are <= the root and + all elements in a right subtree are >= the root and + d is the item in the domain of *this contained in the node + r is the item in the range of *this contained in the node + balance: + balance == 0 if both subtrees have the same height + balance == -1 if the left subtree has a height that is greater + than the height of the right subtree by 1 + balance == 1 if the right subtree has a height that is greater + than the height of the left subtree by 1 + for all trees: + the height of the left and right subtrees differ by at most one + } + + !*/ + + class node + { + public: + node* left; + node* right; + domain d; + range r; + signed char balance; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_1( + ) : + tree_size(0), + tree_root(0), + current_element(0), + tree_height(0), + at_start_(true), + stack_pos(0), + stack(ppool.allocate_array(50)) + { + } + + virtual ~binary_search_tree_kernel_1( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& item + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& item + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_1& item + ); + + // function from the asc_pair_remover interface + void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + + inline void rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == 0 or 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + !*/ + + inline void rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 0 or -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + + !*/ + + inline void double_rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + inline void double_rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + bool remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the biggest node in t has been removed + - the biggest node domain element in t has been put into #d + - the biggest node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the least node in t has been removed + - the least node domain element in t has been put into #d + - the least node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool add_to_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t == reference to the pointer in t's parent node that points to t + ensures + - the mapping (d --> r) has been added to #t + - #d and #r have initial values for their types + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has grown by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - #d_copy is equivalent to d + - an element in t equivalent to d has been removed and swapped + into #d_copy and its associated range object has been + swapped into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& item + ); + /*! + requires + - return_reference(t,item) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - an element in t equivalent to item has been removed + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + + inline bool keep_node_balanced ( + node*& t + ); + /*! + requires + - t != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - if (t->balance is < 1 or > 1) then + - keep_node_balanced() will ensure that #t->balance == 0, -1, or 1 + - #t is still a binary search tree + - returns true if it made the tree one height shorter + - returns false if it didn't change the height + !*/ + + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or 0 + ensures + - if (tree_root == 0) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + void delete_tree ( + node* t + ); + /*! + requires + - t != 0 + ensures + - deallocates the node pointed to by t and all of t's left and right children + !*/ + + + void push ( + node* n + ) const { stack[stack_pos] = n; ++stack_pos; } + /*! + ensures + - pushes n onto the stack + !*/ + + + node* pop ( + ) const { --stack_pos; return stack[stack_pos]; } + /*! + ensures + - pops the top of the stack and returns it + !*/ + + + + bool fix_stack ( + node* t, + unsigned char depth = 0 + ); + /*! + requires + - current_element != 0 + - depth == 0 + - t == tree_root + ensures + - makes the stack contain the correct set of parent pointers. + also adjusts stack_pos so it is correct. + - #t is still a binary search tree + !*/ + + bool remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos = 1 + ); + /*! + requires + - t == tree_root + - cur_stack_pos == 1 + - current_element != 0 + ensures + - removes the data in the node given by current_element and swaps it into + #d and #r. + - #t is still a binary search tree + - the enumerator is advances on to the next element but its stack is + potentially corrupted. so you must call fix_stack(tree_root) to fix + it. + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + + // data members + + mutable mpair p; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + typename mem_manager::template rebind::other pool; + typename mem_manager::template rebind::other ppool; + short tree_height; + mutable bool at_start_; + mutable unsigned char stack_pos; + mutable node** stack; + compare comp; + + // restricted functions + binary_search_tree_kernel_1(binary_search_tree_kernel_1&); + binary_search_tree_kernel_1& operator=(binary_search_tree_kernel_1&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_1& a, + binary_search_tree_kernel_1& b + ) { a.swap(b); } + + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_1"); + } + } + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_1:: + ~binary_search_tree_kernel_1 ( + ) + { + ppool.deallocate_array(stack); + if (tree_size != 0) + { + delete_tree(tree_root); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = 0; + tree_size = 0; + tree_height = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_1:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_1:: + height ( + ) const + { + return tree_height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + add ( + domain& d, + range& r + ) + { + tree_height += add_to_tree(tree_root,d,r); + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + tree_height -= remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + destroy ( + const domain& item + ) + { + tree_height -= remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_any ( + domain& d, + range& r + ) + { + tree_height -= remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) const + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + swap ( + binary_search_tree_kernel_1& item + ) + { + pool.swap(item.pool); + ppool.swap(item.ppool); + exchange(p,item.p); + exchange(stack,item.stack); + exchange(stack_pos,item.stack_pos); + exchange(comp,item.comp); + + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + short tree_height_temp = item.tree_height; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.tree_height = tree_height; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + tree_height = tree_height_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_last_in_order ( + domain& d, + range& r + ) + { + tree_height -= remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_current_element ( + domain& d, + range& r + ) + { + tree_height -= remove_current_element_from_tree(tree_root,d,r); + --tree_size; + + // fix the enumerator stack if we need to + if (current_element) + fix_stack(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + bool went_left = false; + while (t != 0) + { + if ( comp(d , t->d) ) + { + push(t); + // if item is on the left then look in left + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + push(t); + // if item is on the right then look in right + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (stack_pos > 0) + { + current_element = pop(); + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + stack_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_1:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_1:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != 0) + { + push(current_element); + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + node* temp; + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != 0) + { + // go right and down + temp = current_element; + push(current_element); + current_element = temp->right; + went_up = false; + } + else + { + // go up to the parent if we can + if (current_element == tree_root) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + went_up = true; + node* parent = pop(); + + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + if (current_element == tree_root) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + // we should go up + node* parent = pop(); + from_left = (parent->left == current_element); + current_element = parent; + } + } + else + { + // we just went down to a child node + if (current_element->left != 0) + { + // go left + went_up = false; + temp = current_element; + push(current_element); + current_element = temp->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + delete_tree ( + node* t + ) + { + if (t->left != 0) + delete_tree(t->left); + if (t->right != 0) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_left ( + node*& t + ) + { + + // set the new balance numbers + if (t->right->balance == 1) + { + t->balance = 0; + t->right->balance = 0; + } + else + { + t->balance = 1; + t->right->balance = -1; + } + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + temp->left = t; + t = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_right ( + node*& t + ) + { + // set the new balance numbers + if (t->left->balance == -1) + { + t->balance = 0; + t->left->balance = 0; + } + else + { + t->balance = -1; + t->left->balance = 1; + } + + // preform the rotation + node* temp = t->left; + t->left = temp->right; + temp->right = t; + t = temp; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_right ( + node*& t + ) + { + + node* temp = t; + t = t->left->right; + + temp->left->right = t->left; + t->left = temp->left; + + temp->left = t->right; + t->right = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_left ( + node*& t + ) + { + node* temp = t; + t = t->right->left; + + temp->right->left = t->right; + t->right = temp->right; + + temp->right = t->left; + t->left = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the right tree is an empty tree + if ( tree.right == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.left; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going right + + // if remove made the tree one height shorter + if ( remove_biggest_element_in_tree(tree.right,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == 1) + { + --tree.balance; + return true; + } + else + { + --tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the left tree is an empty tree + if ( tree.left == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.right; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going left + + // if remove made the tree one height shorter + if ( remove_least_element_in_tree(tree.left,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == -1) + { + ++tree.balance; + return true; + } + else + { + ++tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + add_to_tree ( + node*& t, + domain& d, + range& r + ) + { + + // if found place to add + if (t == 0) + { + // create a node to add new item into + t = pool.allocate(); + + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + + // set left and right pointers to NULL to indicate that there are no + // left or right subtrees + tree.left = 0; + tree.right = 0; + tree.balance = 0; + + // put d and r into t + exchange(tree.d,d); + exchange(tree.r,r); + + // indicate that the height of this tree has increased + return true; + } + else // keep looking for a place to add the new item + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + signed char old_balance = tree.balance; + + // add the new item to whatever subtree it should go into + if (comp( d , tree.d) ) + tree.balance -= add_to_tree(tree.left,d,r); + else + tree.balance += add_to_tree(tree.right,d,r); + + + // if the tree was balanced to start with + if (old_balance == 0) + { + // if its not balanced anymore then it grew in height + if (tree.balance != 0) + return true; + else + return false; + } + else + { + // if the tree is now balanced then it didn't grow + if (tree.balance == 0) + { + return false; + } + else + { + // if the tree needs to be balanced + if (tree.balance != old_balance) + { + return !keep_node_balanced(t); + } + // if there has been no change in the heights + else + { + return false; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + fix_stack ( + node* t, + unsigned char depth + ) + { + // if we found the node we were looking for + if (t == current_element) + { + stack_pos = depth; + return true; + } + else if (t == 0) + { + return false; + } + + if (!( comp(t->d , current_element->d))) + { + // go left + if (fix_stack(t->left,depth+1)) + { + stack[depth] = t; + return true; + } + } + if (!(comp(current_element->d , t->d))) + { + // go right + if (fix_stack(t->right,depth+1)) + { + stack[depth] = t; + return true; + } + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if we found the node we were looking for + if (t == current_element) + { + + // swap nodes domain and range elements into d_copy and r + exchange(d,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // in this case the next current element is going to get swapped back + // into this t node. + current_element = t; + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.left) || + tree.left == current_element ) + { + // go left + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.right) || + tree.right == current_element ) + { + // go right + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + + // this return should never happen but do it anyway to suppress compiler warnings + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // swap nodes domain and range elements into d_copy and r + exchange(d_copy,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != 0) + { + + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != 0) + { + + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + keep_node_balanced ( + node*& t + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if tree does not need to be balanced then return false + if (tree.balance == 0) + return false; + + + // if tree needs to be rotated left + if (tree.balance == 2) + { + if (tree.right->balance >= 0) + rotate_left(t); + else + double_rotate_left(t); + } + // else if the tree needs to be rotated right + else if (tree.balance == -2) + { + if (tree.left->balance <= 0) + rotate_right(t); + else + double_rotate_right(t); + } + + + if (t->balance == 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != 0) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_2.h b/dlib/binary_search_tree/binary_search_tree_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..83866f27188769e9fc0aff1e7a9a91cd906e17e1 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_2.h @@ -0,0 +1,1897 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_2_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_2 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + NIL == pointer to a node that represents a leaf + tree_size == 0 + tree_root == NIL + at_start == true + current_element == 0 + + + CONVENTION + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + + + tree_size == size() + + NIL == pointer to a node that represents a leaf + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == NIL + + tree_root->color == black + Every leaf is black and all leafs are the NIL node. + The number of black nodes in any path from the root to a leaf is the + same. + + for all nodes: + { + - left points to the left subtree or NIL if there is no left subtree + - right points to the right subtree or NIL if there is no right + subtree + - parent points to the parent node or NIL if the node is the root + - ordering of nodes is determined by comparing each node's d member + - all elements in a left subtree are <= the node + - all elements in a right subtree are >= the node + - color == red or black + - if (color == red) + - the node's children are black + } + + !*/ + + class node + { + public: + node* left; + node* right; + node* parent; + domain d; + range r; + char color; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + const static char red = 0; + const static char black = 1; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_2( + ) : + NIL(pool.allocate()), + tree_size(0), + tree_root(NIL), + current_element(0), + at_start_(true) + { + NIL->color = black; + NIL->left = 0; + NIL->right = 0; + NIL->parent = 0; + } + + virtual ~binary_search_tree_kernel_2( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& d + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void remove_any ( + domain& d, + range& r + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_2& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + inline void rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + ensures + - performs a left rotation around t and its right child + !*/ + + inline void rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + ensures + - performs a right rotation around t and its left child + !*/ + + inline void double_rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + - t->left->right != NIL + - double_rotate_right() is only called in fix_after_add() + ensures + - performs a left rotation around t->left + - then performs a right rotation around t + !*/ + + inline void double_rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + - t->right->left != NIL + - double_rotate_left() is only called in fix_after_add() + ensures + - performs a right rotation around t->right + - then performs a left rotation around t + !*/ + + void remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the biggest node in t has been removed + - the biggest node element in t has been put into #d and #r + - #t is still a binary search tree + !*/ + + bool remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the least node in t has been removed + - the least node element in t has been put into #d and #r + - #t is still a binary search tree + - if (the node that was removed was the one pointed to by current_element) then + - returns true + - else + - returns false + !*/ + + void add_to_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL + ensures + - d and r are now in #t + - there is a mapping from d to r in #t + - #d and #r have initial values for their types + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - #d_copy is equivalent to d + - the first element in t equivalent to d that is encountered when searching down the tree + from t has been removed and swapped into #d_copy. Also, the associated range element + has been removed and swapped into #r. + - if (the node that got removed wasn't current_element) then + - adjusts the current_element pointer if the data in the node that it points to gets moved. + - else + - the value of current_element is now invalid + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - an element in t equivalent to d has been removed + - #t is still a binary search tree + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + void fix_after_add ( + node* t + ); + /*! + requires + - t == pointer to the node just added + - t->color == red + - t->parent != NIL (t must not be the root) + - fix_after_add() is only called after a new node has been added + to t + ensures + - fixes any deviations from the CONVENTION caused by adding a node + !*/ + + void fix_after_remove ( + node* t + ); + /*! + requires + - t == pointer to the only child of the node that was spliced out + - fix_after_remove() is only called after a node has been removed + from t + - the color of the spliced out node was black + ensures + - fixes any deviations from the CONVENTION causes by removing a node + !*/ + + + short tree_height ( + node* t + ) const; + /*! + ensures + - returns the number of nodes in the longest path from the root of the + tree to a leaf + !*/ + + void delete_tree ( + node* t + ); + /*! + requires + - t == root of binary search tree + - t != NIL + ensures + - deletes all nodes in t except for NIL + !*/ + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or NIL + ensures + - if (tree_root == NIL) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + + // data members + typename mem_manager::template rebind::other pool; + node* NIL; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + mutable bool at_start_; + mutable mpair p; + compare comp; + + + + // restricted functions + binary_search_tree_kernel_2(binary_search_tree_kernel_2&); + binary_search_tree_kernel_2& operator=(binary_search_tree_kernel_2&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_2& a, + binary_search_tree_kernel_2& b + ) { a.swap(b); } + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_2"); + } + } + + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_2:: + ~binary_search_tree_kernel_2 ( + ) + { + if (tree_root != NIL) + delete_tree(tree_root); + pool.deallocate(NIL); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = NIL; + tree_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_2:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + height ( + ) const + { + return tree_height(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add ( + domain& d, + range& r + ) + { + if (tree_size == 0) + { + tree_root = pool.allocate(); + tree_root->color = black; + tree_root->left = NIL; + tree_root->right = NIL; + tree_root->parent = NIL; + exchange(tree_root->d,d); + exchange(tree_root->r,r); + } + else + { + add_to_tree(tree_root,d,r); + } + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + destroy ( + const domain& item + ) + { + remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_any ( + domain& d, + range& r + ) + { + remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) const + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + swap ( + binary_search_tree_kernel_2& item + ) + { + pool.swap(item.pool); + + exchange(p,item.p); + exchange(comp,item.comp); + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + node* const NIL_temp = item.NIL; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.NIL = NIL; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + NIL = NIL_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_last_in_order ( + domain& d, + range& r + ) + { + remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_current_element ( + domain& d, + range& r + ) + { + node* t = current_element; + move_next(); + remove_from_tree(t,t->d,d,r); + --tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + node* parent = NIL; + bool went_left = false; + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + parent = t; + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + parent = t; + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (parent != NIL) + { + current_element = parent; + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_2:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_2:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != NIL) + { + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != NIL) + { + // go right and down + current_element = current_element->right; + went_up = false; + } + else + { + went_up = true; + node* parent = current_element->parent; + if (parent == NIL) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + // we should go up + node* parent = current_element->parent; + from_left = (parent->left == current_element); + current_element = parent; + if (current_element == NIL) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + } + } + else + { + // we just went down to a child node + if (current_element->left != NIL) + { + // go left + went_up = false; + current_element = current_element->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + delete_tree ( + node* t + ) + { + if (t->left != NIL) + delete_tree(t->left); + if (t->right != NIL) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_left ( + node* t + ) + { + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + if (temp->left != NIL) + temp->left->parent = t; + temp->left = t; + temp->parent = t->parent; + + + if (t == tree_root) + tree_root = temp; + else + { + // if t was on the left + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_right ( + node* t + ) + { + // perform the rotation + node* temp = t->left; + t->left = temp->right; + if (temp->right != NIL) + temp->right->parent = t; + temp->right = t; + temp->parent = t->parent; + + if (t == tree_root) + tree_root = temp; + else + { + // if t is a left child + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_right ( + node* t + ) + { + + // preform the rotation + node& temp = *(t->left->right); + t->left = temp.right; + temp.right->parent = t; + temp.left->parent = temp.parent; + temp.parent->right = temp.left; + temp.parent->parent = &temp; + temp.right = t; + temp.left = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_left ( + node* t + ) + { + + + // preform the rotation + node& temp = *(t->right->left); + t->right = temp.left; + temp.left->parent = t; + temp.right->parent = temp.parent; + temp.parent->left = temp.right; + temp.parent->parent = &temp; + temp.left = t; + temp.right = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->right; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a right or left child + if (t->parent->right == t) + child = t->parent->right = t->left; + else + child = t->parent->left = t->left; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->right; + } while (next != NIL); + // t is a right child + child = t->parent->right = t->left; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole right by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + // free the memory for this removed node + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->left; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a left or right child + if (t->parent->left == t) + child = t->parent->left = t->right; + else + child = t->parent->right = t->right; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->left; + } while (next != NIL); + // t is a left child + child = t->parent->left = t->right; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole left by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + bool rvalue = (t == current_element); + // free the memory for this removed node + pool.deallocate(t); + return rvalue; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add_to_tree ( + node* t, + domain& d, + range& r + ) + { + // parent of the current node + node* parent; + + // find a place to add node + while (true) + { + parent = t; + // if item should be put on the left then go left + if (comp(d , t->d)) + { + t = t->left; + if (t == NIL) + { + t = parent->left = pool.allocate(); + break; + } + } + // if item should be put on the right then go right + else + { + t = t->right; + if (t == NIL) + { + t = parent->right = pool.allocate(); + break; + } + } + } + + // t is now the node where we will add item and + // parent is the parent of t + + t->parent = parent; + t->left = NIL; + t->right = NIL; + t->color = red; + exchange(t->d,d); + exchange(t->r,r); + + + // keep the red-black properties true + fix_after_add(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + // swap out the item into d_copy and r + exchange(d_copy,t->d); + exchange(r,t->r); + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + // if t is on the left + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + if (remove_least_element_in_tree(t->right,t->d,t->r)) + { + // the node removed was the one pointed to by current_element so we + // need to update it so that it points to the right spot. + current_element = t; + } + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + remove_least_element_in_tree(t->right,t->d,t->r); + + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != NIL) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_add ( + node* t + ) + { + + while (t->parent->color == red) + { + node& grandparent = *(t->parent->parent); + + // if both t's parent and its sibling are red + if (grandparent.left->color == grandparent.right->color) + { + grandparent.color = red; + grandparent.left->color = black; + grandparent.right->color = black; + t = &grandparent; + } + else + { + // if t is a left child + if (t == t->parent->left) + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + grandparent.color = red; + grandparent.left->color = black; + rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + t->color = black; + grandparent.color = red; + double_rotate_left(&grandparent); + } + } + // if t is a right child + else + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + t->color = black; + grandparent.color = red; + double_rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + grandparent.color = red; + grandparent.right->color = black; + rotate_left(&grandparent); + } + } + break; + } + } + tree_root->color = black; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_remove ( + node* t + ) + { + + while (t != tree_root && t->color == black) + { + if (t->parent->left == t) + { + node* sibling = t->parent->right; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_left(t->parent); + sibling = t->parent->right; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->right->color == black) + { + sibling->left->color = black; + sibling->color = red; + rotate_right(sibling); + sibling = t->parent->right; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->right->color = black; + rotate_left(t->parent); + t = tree_root; + + } + + + } + else + { + + node* sibling = t->parent->left; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_right(t->parent); + sibling = t->parent->left; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->left->color == black) + { + sibling->right->color = black; + sibling->color = red; + rotate_left(sibling); + sibling = t->parent->left; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->left->color = black; + rotate_right(t->parent); + t = tree_root; + + } + + + } + + } + t->color = black; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + tree_height ( + node* t + ) const + { + if (t == NIL) + return 0; + + short height1 = tree_height(t->left); + short height2 = tree_height(t->right); + if (height1 > height2) + return height1 + 1; + else + return height2 + 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != NIL) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h b/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..2abfe7e3955315cffbd90ef3ed897580ebfa73e4 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h @@ -0,0 +1,311 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ +#ifdef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree : public enumerable >, + public asc_pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), count(), height(), and operator[] functions + do not invalidate pointers or references to internal data. + + position_enumerator() invalidates pointers or references to + data returned by element() and only by element() (i.e. pointers and + references returned by operator[] are still valid). + + All other functions have no such guarantees. + + INITIAL VALUE + size() == 0 + height() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the domain (and each associated + range element) elements in ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + this object represents a data dictionary that is built on top of some + kind of binary search tree. It maps objects of type domain to objects + of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE: + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~binary_search_tree( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + short height ( + ) const; + /*! + ensures + - returns the number of elements in the longest path from the root + of the tree to a leaf + !*/ + + unsigned long count ( + const domain& d + ) const; + /*! + ensures + - returns the number of elements in the domain of *this that are + equivalent to d + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + ensures + - adds a mapping between d and r to *this + - if (count(d) == 0) then + - #*(*this)[d] == r + - else + - #(*this)[d] != 0 + - #d and #r have initial values for their types + - #count(d) == count(d) + 1 + - #at_start() == true + - #size() == size() + 1 + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - (*this)[d] != 0 + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + ensures + - some element in the domain of *this that is equivalent to d has + been removed and swapped into #d_copy. Additionally, its + associated range element has been removed and swapped into #r. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - (*this)[d] != 0 + ensures + - an element in the domain of *this equivalent to d has been removed. + The element in the range of *this associated with d has also been + removed. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_last_in_order ( + domain& d, + range& r + ); + /*! + requires + - size() > 0 + ensures + - the last/biggest (according to the compare functor) element in the domain of *this has + been removed and swapped into #d. The element in the range of *this + associated with #d has also been removed and swapped into #r. + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_current_element ( + domain& d, + range& r + ); + /*! + requires + - current_element_valid() == true + ensures + - the current element given by element() has been removed and swapped into d and r. + - #d == element().key() + - #r == element().value() + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - moves the enumerator to the next element. If element() was the last + element in enumeration order then #current_element_valid() == false + and #at_start() == false. + !*/ + + void position_enumerator ( + const domain& d + ) const; + /*! + ensures + - #at_start() == false + - if (count(d) > 0) then + - #element().key() == d + - else if (there are any items in the domain of *this that are bigger than + d according to the compare functor) then + - #element().key() == the smallest item in the domain of *this that is + bigger than d according to the compare functor. + - else + - #current_element_valid() == false + !*/ + + const range* operator[] ( + const domain& d + ) const; + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + range* operator[] ( + const domain& d + ); + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + void swap ( + binary_search_tree& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + binary_search_tree(binary_search_tree&); + binary_search_tree& operator=(binary_search_tree&); + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree& a, + binary_search_tree& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_c.h b/dlib/binary_search_tree/binary_search_tree_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..0dc15396193f699b484244697d23b75adb4aa425 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_c.h @@ -0,0 +1,235 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_C_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + +#include "../interfaces/map_pair.h" +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename bst_base + > + class binary_search_tree_kernel_c : public bst_base + { + typedef typename bst_base::domain_type domain; + typedef typename bst_base::range_type range; + + public: + + binary_search_tree_kernel_c () {} + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void add ( + domain& d, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + const map_pair& element( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& binary_search_tree::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + map_pair& element( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& binary_search_tree::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + + }; + + + template < + typename bst_base + > + inline void swap ( + binary_search_tree_kernel_c& a, + binary_search_tree_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + add ( + domain& d, + range& r + ) + { + DLIB_CASSERT( static_cast(&d) != static_cast(&r), + "\tvoid binary_search_tree::add" + << "\n\tyou can't call add() and give the same object to both parameters." + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + << "\n\tsize(): " << this->size() + ); + + bst_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + destroy ( + const domain& d + ) + { + DLIB_CASSERT(this->operator[](d) != 0, + "\tvoid binary_search_tree::destroy" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + ); + + bst_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + DLIB_CASSERT(this->operator[](d) != 0 && + (static_cast(&d) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)), + "\tvoid binary_search_tree::remove" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&d_copy: " << &d_copy + << "\n\t&r: " << &r + ); + + bst_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_any( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() != 0 && + (static_cast(&d) != static_cast(&r)), + "\tvoid binary_search_tree::remove_any" + << "\n\ttree must not be empty if something is going to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + ); + + bst_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_last_in_order ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() > 0, + "\tvoid binary_search_tree::remove_last_in_order()" + << "\n\tyou can't remove an element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_last_in_order(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_current_element ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tvoid binary_search_tree::remove_current_element()" + << "\n\tyou can't remove the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_current_element(d,r); + } + + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + diff --git a/dlib/bit_stream.h b/dlib/bit_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8885f35157e41d5e9c19c9204215a3763ac848e0 --- /dev/null +++ b/dlib/bit_stream.h @@ -0,0 +1,42 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAm_ +#define DLIB_BIT_STREAm_ + +#include "bit_stream/bit_stream_kernel_1.h" +#include "bit_stream/bit_stream_kernel_c.h" + +#include "bit_stream/bit_stream_multi_1.h" +#include "bit_stream/bit_stream_multi_c.h" + +namespace dlib +{ + + + class bit_stream + { + bit_stream() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bit_stream_kernel_1 + kernel_1a; + typedef bit_stream_kernel_c + kernel_1a_c; + + //---------- extensions ------------ + + + // multi_1 extend kernel_1a + typedef bit_stream_multi_1 + multi_1a; + typedef bit_stream_multi_c > + multi_1a_c; + + }; +} + +#endif // DLIB_BIT_STREAm_ + diff --git a/dlib/bit_stream/bit_stream_kernel_1.cpp b/dlib/bit_stream/bit_stream_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f49db14d59bbd4fa3da223aabf2b67f2bc34490e --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_1.cpp @@ -0,0 +1,200 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_ +#define DLIB_BIT_STREAM_KERNEL_1_CPp_ + + +#include "bit_stream_kernel_1.h" +#include "../algs.h" + +#include + +namespace dlib +{ + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + clear ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_input_stream ( + std::istream& is + ) + { + isp = &is; + read_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_output_stream ( + std::ostream& os + ) + { + osp = &os; + write_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + close ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_write_mode ( + ) const + { + return write_mode; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_read_mode ( + ) const + { + return read_mode; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + write ( + int bit + ) + { + // flush buffer if necessary + if (buffer_size == 8) + { + buffer <<= 8 - buffer_size; + if (osp->rdbuf()->sputn(reinterpret_cast(&buffer),1) == 0) + { + throw std::ios_base::failure("error occurred in the bit_stream object"); + } + + buffer_size = 0; + } + + ++buffer_size; + buffer <<= 1; + buffer += static_cast(bit); + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + read ( + int& bit + ) + { + // get new byte if necessary + if (buffer_size == 0) + { + if (isp->rdbuf()->sgetn(reinterpret_cast(&buffer), 1) == 0) + { + // if we didn't read anything then return false + return false; + } + + buffer_size = 8; + } + + // put the most significant bit from buffer into bit + bit = static_cast(buffer >> 7); + + // shift out the bit that was just read + buffer <<= 1; + --buffer_size; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + swap ( + bit_stream_kernel_1& item + ) + { + + std::istream* isp_temp = item.isp; + std::ostream* osp_temp = item.osp; + bool write_mode_temp = item.write_mode; + bool read_mode_temp = item.read_mode; + unsigned char buffer_temp = item.buffer; + unsigned short buffer_size_temp = item.buffer_size; + + item.isp = isp; + item.osp = osp; + item.write_mode = write_mode; + item.read_mode = read_mode; + item.buffer = buffer; + item.buffer_size = buffer_size; + + + isp = isp_temp; + osp = osp_temp; + write_mode = write_mode_temp; + read_mode = read_mode_temp; + buffer = buffer_temp; + buffer_size = buffer_size_temp; + + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIT_STREAM_KERNEL_1_CPp_ + diff --git a/dlib/bit_stream/bit_stream_kernel_1.h b/dlib/bit_stream/bit_stream_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..801e93e0a27facbc6a00b59c9241452a4486d62a --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_1.h @@ -0,0 +1,120 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_1_ +#define DLIB_BIT_STREAM_KERNEl_1_ + +#include "bit_stream_kernel_abstract.h" +#include + +namespace dlib +{ + + class bit_stream_kernel_1 + { + + /*! + INITIAL VALUE + write_mode == false + read_mode == false + + CONVENTION + write_mode == is_in_write_mode() + read_mode == is_in_read_mode() + + if (write_mode) + { + osp == pointer to an ostream object + buffer == the low order bits of buffer are the bits to be + written + buffer_size == the number of low order bits in buffer that are + bits that should be written + the lowest order bit is the last bit entered by the user + } + + if (read_mode) + { + isp == pointer to an istream object + buffer == the high order bits of buffer are the bits + waiting to be read by the user + buffer_size == the number of high order bits in buffer that + are bits that are waiting to be read + the highest order bit is the next bit to give to the user + } + !*/ + + + public: + + bit_stream_kernel_1 ( + ) : + write_mode(false), + read_mode(false) + {} + + virtual ~bit_stream_kernel_1 ( + ) + {} + + void clear ( + ); + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + inline bool is_in_write_mode ( + ) const; + + inline bool is_in_read_mode ( + ) const; + + inline void write ( + int bit + ); + + bool read ( + int& bit + ); + + void swap ( + bit_stream_kernel_1& item + ); + + private: + + // member data + std::istream* isp; + std::ostream* osp; + bool write_mode; + bool read_mode; + unsigned char buffer; + unsigned short buffer_size; + + // restricted functions + bit_stream_kernel_1(bit_stream_kernel_1&); // copy constructor + bit_stream_kernel_1& operator=(bit_stream_kernel_1&); // assignment operator + + }; + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bit_stream_kernel_1.cpp" +#endif + +#endif // DLIB_BIT_STREAM_KERNEl_1_ + diff --git a/dlib/bit_stream/bit_stream_kernel_abstract.h b/dlib/bit_stream/bit_stream_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..00c2ae3b94d2594170d14d48c8390a261c51c8fd --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_abstract.h @@ -0,0 +1,185 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + +#include + +namespace dlib +{ + + class bit_stream + { + + /*! + INITIAL VALUE + is_in_write_mode() == false + is_in_read_mode() == false + + WHAT THIS OBJECT REPRESENTS + this object is a middle man between a user and the iostream classes. + it allows single bits to be read/written easily to/from + the iostream classes + + BUFFERING: + This object will only read/write single bytes at a time from/to the + iostream objects. Any buffered bits still in the bit_stream object + when it is closed or destructed are lost if it is in read mode. If + it is in write mode then any remaining bits are guaranteed to be + written to the output stream by the time it is closed or destructed. + !*/ + + + public: + + bit_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~bit_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + + void set_input_stream ( + std::istream& is + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - is is ready to give input + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == true + - #*this will now be reading from is + throws + - std::bad_alloc + !*/ + + void set_output_stream ( + std::ostream& os + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - os is ready to take output + ensures + - #is_in_write_mode() == true + - #is_in_read_mode() == false + - #*this will now write to os + throws + - std::bad_alloc + !*/ + + + + void close ( + ); + /*! + requires + - is_in_write_mode() == true || is_in_read_mode() == true + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == false + !*/ + + bool is_in_write_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an output stream object + - returns false otherwise + !*/ + + bool is_in_read_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an input stream object + - returns false otherwise + !*/ + + void write ( + int bit + ); + /*! + requires + - is_in_write_mode() == true + - bit == 0 || bit == 1 + ensures + - bit will be written to the ostream object associated with *this + throws + - std::ios_base::failure + if (there was a problem writing to the output stream) then + this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool read ( + int& bit + ); + /*! + requires + - is_in_read_mode() == true + ensures + - the next bit has been read and placed into #bit + - returns true if the read was successful, else false + (ex. false if EOF has been reached) + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void swap ( + bit_stream& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + bit_stream(bit_stream&); // copy constructor + bit_stream& operator=(bit_stream&); // assignment operator + + }; + + inline void swap ( + bit_stream& a, + bit_stream& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + diff --git a/dlib/bit_stream/bit_stream_kernel_c.h b/dlib/bit_stream/bit_stream_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..1d52bff200f4aae532856cbe84e8f01208e2f7e1 --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_c.h @@ -0,0 +1,172 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_C_ +#define DLIB_BIT_STREAM_KERNEl_C_ + +#include "bit_stream_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename bit_stream_base // implements bit_stream/bit_stream_kernel_abstract.h + > + class bit_stream_kernel_c : public bit_stream_base + { + public: + + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + void write ( + int bit + ); + + bool read ( + int& bit + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_kernel_c& a, + bit_stream_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_input_stream ( + std::istream& is + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_intput_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_input_stream(is); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_output_stream ( + std::ostream& os + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_output_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_output_stream(os); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + close ( + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) || ( this->is_in_read_mode() == true ), + "\tvoid bit_stream::close" + << "\n\tyou can't close a bit_stream that isn't open" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::close(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + write ( + int bit + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) && ( bit == 0 || bit == 1 ), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and bit must be either 1 or 0" + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tbit == " << bit + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::write(bit); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + bool bit_stream_kernel_c:: + read ( + int& bit + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true ), + "\tbool bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode" + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::read(bit); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_KERNEl_C_ + diff --git a/dlib/bit_stream/bit_stream_multi_1.h b/dlib/bit_stream/bit_stream_multi_1.h new file mode 100644 index 0000000000000000000000000000000000000000..bf1cc0357d86606416ce105c5cbf82c3f2f32c82 --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_1.h @@ -0,0 +1,103 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_1_ +#define DLIB_BIT_STREAM_MULTi_1_ + +#include "bit_stream_multi_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi_1 : public bit_stream_base + { + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_1& a, + bit_stream_multi_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_1:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // move the first bit into the most significant position + data <<= 32 - num_to_write; + + for (int i = 0; i < num_to_write; ++i) + { + // write the first bit from data + this->write(static_cast(data >> 31)); + + // shift the next bit into position + data <<= 1; + + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_1:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + int bit, i; + data = 0; + for (i = 0; i < num_to_read; ++i) + { + + // get a bit + if (this->read(bit) == false) + break; + + // shift data to make room for this new bit + data <<= 1; + + // put bit into the least significant position in data + data += static_cast(bit); + + } + + return i; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_1_ + diff --git a/dlib/bit_stream/bit_stream_multi_abstract.h b/dlib/bit_stream/bit_stream_multi_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..061af94f49e15a239303b57e40bedae7329912f0 --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_abstract.h @@ -0,0 +1,77 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_MULTi_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_MULTi_ABSTRACT_ + +#include "bit_stream_kernel_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi : public bit_stream_base + { + + /*! + REQUIREMENTS ON BIT_STREAM_BASE + it is an implementation of bit_stream/bit_stream_kernel_abstract.h + + + WHAT THIS EXTENSION DOES FOR BIT_STREAM + this gives a bit_stream object the ability to read/write multible bits + at a time + !*/ + + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + /*! + requires + - is_in_write_mode() == true + - 0 <= num_to_write <= 32 + ensures + - num_to_write low order bits from data will be written to the ostream + - object associated with *this + example: if data is 10010 then the bits will be written in the + order 1,0,0,1,0 + !*/ + + + int multi_read ( + unsigned long& data, + int num_to_read + ); + /*! + requires + - is_in_read_mode() == true + - 0 <= num_to_read <= 32 + ensures + - tries to read num_to_read bits into the low order end of #data + example: if the incoming bits were 10010 then data would end + up with 10010 as its low order bits + - all of the bits in #data not filled in by multi_read() are zero + - returns the number of bits actually read into #data + !*/ + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi& a, + bit_stream_multi& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_MULTi_ABSTRACT_ + diff --git a/dlib/bit_stream/bit_stream_multi_c.h b/dlib/bit_stream/bit_stream_multi_c.h new file mode 100644 index 0000000000000000000000000000000000000000..de80c63280a9168d4152bd6b9ccbabe9bc3ab1da --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_c.h @@ -0,0 +1,101 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_C_ +#define DLIB_BIT_STREAM_MULTi_C_ + +#include "bit_stream_multi_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + template < + typename bit_stream_base // implements bit_stream/bit_stream_multi_abstract.h + > + class bit_stream_multi_c : public bit_stream_base + { + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_c& a, + bit_stream_multi_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_c:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->is_in_write_mode() == true) && (num_to_write >= 0 && num_to_write <=32), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and" + << "\n\tnum_to_write must be between 0 and 32 inclusive" + << "\n\tnum_to_write == " << num_to_write + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::multi_write(data,num_to_write); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_c:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true && ( num_to_read >= 0 && num_to_read <=32 ) ), + "\tvoid bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode and" + << "\n\tnum_to_read must be between 0 and 32 inclusive" + << "\n\tnum_to_read == " << num_to_read + << "\n\tis_in_read_mode() == " << this->is_in_read_mode() + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::multi_read(data,num_to_read); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_C_ + diff --git a/dlib/bits/c++config.h b/dlib/bits/c++config.h new file mode 100644 index 0000000000000000000000000000000000000000..6139ba8238c1db00fe4f10178abcb881715a2ef9 --- /dev/null +++ b/dlib/bits/c++config.h @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/dlib/bound_function_pointer.h b/dlib/bound_function_pointer.h new file mode 100644 index 0000000000000000000000000000000000000000..a482919c6e7db5aebbafa0d4846d6c55218a1810 --- /dev/null +++ b/dlib/bound_function_pointer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTEr_ +#define DLIB_BOUND_FUNCTION_POINTEr_ + +#include "bound_function_pointer/bound_function_pointer_kernel_1.h" + +#endif // DLIB_BOUND_FUNCTION_POINTEr_ + + diff --git a/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h b/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..2984fbc93cf881682940dbd02ccd24e478000fbf --- /dev/null +++ b/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h @@ -0,0 +1,774 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ +#define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + +#include "../algs.h" +#include "../member_function_pointer.h" +#include "bound_function_pointer_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace bfp1_helpers + { + template struct strip { typedef T type; }; + template struct strip { typedef T type; }; + + // ------------------------------------------------------------------------------------ + + class bound_function_helper_base_base + { + public: + virtual ~bound_function_helper_base_base(){} + virtual void call() const = 0; + virtual bool is_set() const = 0; + virtual void clone(void* ptr) const = 0; + }; + + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_base : public bound_function_helper_base_base + { + public: + bound_function_helper_base():arg1(0), arg2(0), arg3(0), arg4(0) {} + + typename strip::type* arg1; + typename strip::type* arg2; + typename strip::type* arg3; + typename strip::type* arg4; + + + member_function_pointer mfp; + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + void (*fp)(T1, T2, T3, T4); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(); + } + + typename strip::type* fp; + }; + + template <> + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(); + else if (fp) fp(); + } + + void (*fp)(); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1); + else if (fp) fp(*this->arg1); + } + + void (*fp)(T1); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2); + else if (fp) fp(*this->arg1, *this->arg2); + } + + void (*fp)(T1, T2); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3); + } + + void (*fp)(T1, T2, T3); + }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_T : public T + { + public: + bound_function_helper_T(){ this->fp = 0;} + + bool is_set() const + { + return this->fp != 0 || this->mfp.is_set(); + } + + template + void safe_clone(stack_based_memory_block& buf) + { + // This is here just to validate the assumption that our block of memory we have made + // in bf_memory is the right size to store the data for this object. If you + // get a compiler error on this line then email me :) + COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size); + clone(buf.get()); + } + + void clone (void* ptr) const + { + bound_function_helper_T* p = new(ptr) bound_function_helper_T(); + p->arg1 = this->arg1; + p->arg2 = this->arg2; + p->arg3 = this->arg3; + p->arg4 = this->arg4; + p->fp = this->fp; + p->mfp = this->mfp; + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + typedef bfp1_helpers::bound_function_helper_T > bf_null_type; + + public: + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef bound_function_pointer kernel_1a; + typedef bound_function_pointer kernel_1a_c; + + + bound_function_pointer ( + ) { bf_null_type().safe_clone(bf_memory); } + + bound_function_pointer ( + const bound_function_pointer& item + ) { item.bf()->clone(bf_memory.get()); } + + ~bound_function_pointer() + { destroy_bf_memory(); } + + bound_function_pointer& operator= ( + const bound_function_pointer& item + ) { bound_function_pointer(item).swap(*this); return *this; } + + void clear ( + ) { bound_function_pointer().swap(*this); } + + bool is_set ( + ) const + { + return bf()->is_set(); + } + + void swap ( + bound_function_pointer& item + ) + { + // make a temp copy of item + bound_function_pointer temp(item); + + // destory the stuff in item + item.destroy_bf_memory(); + // copy *this into item + bf()->clone(item.bf_memory.get()); + + // destory the stuff in this + destroy_bf_memory(); + // copy temp into *this + temp.bf()->clone(bf_memory.get()); + } + + void operator() ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_set() == true , + "\tvoid bound_function_pointer::operator()" + << "\n\tYou must call set() before you can use this function" + << "\n\tthis: " << this + ); + + bf()->call(); + } + + private: + struct dummy{ void nonnull() {}}; + typedef void (dummy::*safe_bool)(); + + public: + operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; } + bool operator!() const { return !is_set(); } + + // ------------------------------------------- + // set function object overloads + // ------------------------------------------- + + template + void set ( + F& function_object + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set mfp overloads + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)()const + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2)const, + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3, T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set fp overloads + // ------------------------------------------- + + void set ( + void (*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + private: + + stack_based_memory_block bf_memory; + + void destroy_bf_memory ( + ) + { + // Honestly, this probably doesn't even do anything but I'm putting + // it here just for good measure. + bf()->~bound_function_helper_base_base(); + } + + bfp1_helpers::bound_function_helper_base_base* bf () + { return static_cast(bf_memory.get()); } + + const bfp1_helpers::bound_function_helper_base_base* bf () const + { return static_cast(bf_memory.get()); } + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + diff --git a/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h b/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b5356d6e0015601d2ee5791771221ccb64442a21 --- /dev/null +++ b/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h @@ -0,0 +1,456 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ +#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + /*! + INITIAL VALUE + is_set() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a function with all its arguments bound to + specific objects. For example: + + void test(int& var) { var = var+1; } + + bound_function_pointer funct; + + int a = 4; + funct.set(test,a); // bind the variable a to the first argument of the test() function + + // at this point a == 4 + funct(); + // after funct() is called a == 5 + !*/ + + public: + + bound_function_pointer ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + bound_function_pointer( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + ~bound_function_pointer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bound_function_pointer& operator=( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + bool is_set ( + ) const; + /*! + ensures + - if (this->set() has been called) then + - returns true + - else + - returns false + !*/ + + operator some_undefined_pointer_type ( + ) const; + /*! + ensures + - if (is_set()) then + - returns a non 0 value + - else + - returns a 0 value + !*/ + + bool operator! ( + ) const; + /*! + ensures + - returns !is_set() + !*/ + + void operator () ( + ) const; + /*! + requires + - is_set() == true + ensures + - calls the bound function on the object(s) specified by the last + call to this->set() + throws + - any exception thrown by the function specified by + the previous call to this->set(). + If any of these exceptions are thrown then the call to this + function will have no effect on *this. + !*/ + + void swap ( + bound_function_pointer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + // ---------------------- + + template + void set ( + F& function_object + ); + /*! + requires + - function_object() is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object() + (This seems pointless but it is a useful base case) + !*/ + + template < typename T> + void set ( + T& object, + void (T::*funct)() + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + template < typename T> + void set ( + const T& object, + void (T::*funct)()const + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + void set ( + void (*funct)() + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct() + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1 + ); + /*! + requires + - function_object(arg1) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template + void set ( + void (*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1) + !*/ + + // ---------------------- + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ); + /*! + requires + - function_object(arg1,arg2) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + T& object, + void (T::*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + const T& object, + void (T::*funct)(T1,T2)const, + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template + void set ( + void (*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - function_object(arg1,arg2,arg3) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + T& object, + void (T::*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - function_object(arg1,arg2,arg3,arg4) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + T& object, + void (T::*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3,T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3,arg4) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + diff --git a/dlib/bridge.h b/dlib/bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..4b633c4053c11fca7f223b6a153c9fb1dc8daa6d --- /dev/null +++ b/dlib/bridge.h @@ -0,0 +1,17 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + + +#ifndef DLIB_BRIdGE_ +#define DLIB_BRIdGE_ + + +#include "bridge/bridge.h" + +#endif // DLIB_BRIdGE_ + + diff --git a/dlib/bridge/bridge.h b/dlib/bridge/bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..93e23995aff6172041134866b14d038993da91ec --- /dev/null +++ b/dlib/bridge/bridge.h @@ -0,0 +1,669 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BRIDGe_Hh_ +#define DLIB_BRIDGe_Hh_ + +#include +#include +#include + +#include "bridge_abstract.h" +#include "../pipe.h" +#include "../threads.h" +#include "../serialize.h" +#include "../sockets.h" +#include "../sockstreambuf.h" +#include "../logger.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip_, + unsigned short port_ + ): ip(ip_), port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ip_address(ip) && port != 0, + "\t connect_to_ip_and_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t ip: " << ip + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const std::string ip; + const unsigned short port; + }; + + inline connect_to_ip_and_port connect_to ( + const network_address& addr + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(addr.port != 0, + "\t connect_to_ip_and_port()" + << "\n\t The TCP port to connect to can't be 0." + << "\n\t addr.port: " << addr.port + ); + + if (is_ip_address(addr.host_address)) + { + return connect_to_ip_and_port(addr.host_address, addr.port); + } + else + { + std::string ip; + if(hostname_to_ip(addr.host_address,ip)) + throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()"); + + return connect_to_ip_and_port(ip, addr.port); + } + } + + struct listen_on_port + { + listen_on_port( + unsigned short port_ + ) : port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT( port != 0, + "\t listen_on_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const unsigned short port; + }; + + template + struct bridge_transmit_decoration + { + bridge_transmit_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_transmit_decoration transmit ( pipe_type& p) { return bridge_transmit_decoration(p); } + + template + struct bridge_receive_decoration + { + bridge_receive_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_receive_decoration receive ( pipe_type& p) { return bridge_receive_decoration(p); } + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + bridge_status() : is_connected(false), foreign_port(0){} + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + + inline void serialize ( const bridge_status& , std::ostream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + + inline void deserialize ( bridge_status& , std::istream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_brns + { + class impl_bridge_base + { + public: + + virtual ~impl_bridge_base() {} + + virtual bridge_status get_bridge_status ( + ) const = 0; + }; + + template < + typename transmit_pipe_type, + typename receive_pipe_type + > + class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object + { + /*! + CONVENTION + - if (list) then + - this object is supposed to be listening on the list object for incoming + connections when not connected. + - else + - this object is supposed to be attempting to connect to ip:port when + not connected. + + - get_bridge_status() == current_bs + !*/ + public: + + impl_bridge ( + unsigned short listen_port, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(0), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + int status = create_listener(list, listen_port); + if (status == PORTINUSE) + { + std::ostringstream sout; + sout << "Error, the port " << listen_port << " is already in use."; + throw socket_error(EPORT_IN_USE, sout.str()); + } + else if (status == OTHER_ERROR) + { + throw socket_error("Unable to create listening socket for an unknown reason."); + } + + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + impl_bridge ( + const std::string ip_, + unsigned short port_, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(port_), + ip(ip_), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + ~impl_bridge() + { + // tell the threads to terminate + stop(); + + // save current pipe enabled status so we can restore it to however + // it was before this destructor ran. + bool transmit_enabled = true; + bool receive_enabled = true; + + // make any calls blocked on a pipe return immediately. + if (transmit_pipe) + { + transmit_enabled = transmit_pipe->is_dequeue_enabled(); + transmit_pipe->disable_dequeue(); + } + if (receive_pipe) + { + receive_enabled = receive_pipe->is_enqueue_enabled(); + receive_pipe->disable_enqueue(); + } + + { + auto_mutex lock(m); + s.broadcast(); + // Shutdown the connection if we have one. This will cause + // all blocked I/O calls to return an error. + if (con) + con->shutdown(); + } + + // wait for all the threads to terminate. + wait(); + + if (transmit_pipe && transmit_enabled) + transmit_pipe->enable_dequeue(); + if (receive_pipe && receive_enabled) + receive_pipe->enable_enqueue(); + } + + bridge_status get_bridge_status ( + ) const + { + auto_mutex lock(current_bs_mutex); + return current_bs; + } + + private: + + + template + std::enable_if_t::value> enqueue_bridge_status ( + pipe_type* p, + const bridge_status& status + ) + { + if (p) + { + typename pipe_type::type temp(status); + p->enqueue(temp); + } + } + + template + std::enable_if_t::value> enqueue_bridge_status ( + pipe_type* , + const bridge_status& + ) + { + } + + void connect_thread ( + ) + { + while (!should_stop()) + { + auto_mutex lock(m); + int status = OTHER_ERROR; + if (list) + { + do + { + status = list->accept(con, 1000); + } while (status == TIMEOUT && !should_stop()); + } + else + { + status = create_connection(con, port, ip); + } + + if (should_stop()) + break; + + if (status != 0) + { + // The last connection attempt failed. So pause for a little bit before making another attempt. + s.wait_or_timeout(2000); + continue; + } + + dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + + bridge_status temp_bs; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = true; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + + + receive_thread_active = true; + transmit_thread_active = true; + + s.broadcast(); + + // Wait for the transmit and receive threads to end before we continue. + // This way we don't invalidate the con pointer while it is in use. + while (receive_thread_active || transmit_thread_active) + s.wait(); + + + dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = false; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + } + + } + + + void receive_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!receive_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + if (receive_pipe) + { + sockstreambuf buf(con); + std::istream in(&buf); + typename receive_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + while (in.peek() != EOF) + { + unsigned char code; + in.read((char*)&code, sizeof(code)); + if (code == message_code) + { + deserialize(item, in); + receive_pipe->enqueue(item); + } + } + } + else + { + // Since we don't have a receive pipe to put messages into we will + // just read the bytes from the connection and ignore them. + char buf[1000]; + while (con->read(buf, sizeof(buf)) > 0) ; + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + void transmit_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!transmit_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + sockstreambuf buf(con); + std::ostream out(&buf); + typename transmit_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + + while (out) + { + bool dequeue_timed_out = false; + if (transmit_pipe ) + { + if (transmit_pipe->dequeue_or_timeout(item,1000)) + { + out.write((char*)&message_code, sizeof(message_code)); + serialize(item, out); + if (transmit_pipe->size() == 0) + out.flush(); + + continue; + } + + dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled()); + } + + // Pause for about a second. Note that we use a wait_or_timeout() call rather + // than sleep() here because we want to wake up immediately if this object is + // being destructed rather than hang for a second. + if (!dequeue_timed_out) + { + auto_mutex lock(m); + if (should_stop()) + break; + + s.wait_or_timeout(1000); + } + // Just send the keepalive byte periodically so we can + // tell if the connection is alive. + out.write((char*)&keepalive_code, sizeof(keepalive_code)); + out.flush(); + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + mutex m; + signaler s; + bool receive_thread_active; + bool transmit_thread_active; + std::unique_ptr con; + std::unique_ptr list; + const unsigned short port; + const std::string ip; + transmit_pipe_type* const transmit_pipe; + receive_pipe_type* const receive_pipe; + logger dlog; + const unsigned char keepalive_code; + const unsigned char message_code; + + mutex current_bs_mutex; + bridge_status current_bs; + }; + } + + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + public: + + bridge () {} + + template < typename T, typename U, typename V > + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ) { reconfigure(network_parameters,pipe1,pipe2); } + + template < typename T, typename U> + bridge ( + T network_parameters, + U pipe + ) { reconfigure(network_parameters,pipe); } + + + void clear ( + ) + { + pimpl.reset(); + } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, 0)); } + + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, 0, &receive_pipe.p)); } + + + + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); } + + + bridge_status get_bridge_status ( + ) const + { + if (pimpl) + return pimpl->get_bridge_status(); + else + return bridge_status(); + } + + private: + + std::unique_ptr pimpl; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_Hh_ + diff --git a/dlib/bridge/bridge_abstract.h b/dlib/bridge/bridge_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..76ed21153af9119a82d265b40cc4926046362c94 --- /dev/null +++ b/dlib/bridge/bridge_abstract.h @@ -0,0 +1,347 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BRIDGe_ABSTRACT_ +#ifdef DLIB_BRIDGe_ABSTRACT_ + +#include +#include "../pipe/pipe_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip, + unsigned short port + ); + /*! + requires + - is_ip_address(ip) == true + - port != 0 + ensures + - this object will represent a request to make a TCP connection + to the given IP address and port number. + !*/ + }; + + connect_to_ip_and_port connect_to ( + const network_address& addr + ); + /*! + requires + - addr.port != 0 + ensures + - converts the given network_address object into a connect_to_ip_and_port + object. + !*/ + + struct listen_on_port + { + listen_on_port( + unsigned short port + ); + /*! + requires + - port != 0 + ensures + - this object will represent a request to listen on the given + port number for incoming TCP connections. + !*/ + }; + + template < + typename pipe_type + > + bridge_transmit_decoration transmit ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a transmit pipe, and + then returns it. + !*/ + + template < + typename pipe_type + > + bridge_receive_decoration receive ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a receive pipe, and + then returns it. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + /*! + WHAT THIS OBJECT REPRESENTS + This simple struct represents the state of a bridge object. A + bridge is either connected or not. If it is connected then it + is connected to a foreign host with an IP address and port number + as indicated by this object. + !*/ + + bridge_status( + ); + /*! + ensures + - #is_connected == false + - #foreign_port == 0 + - #foreign_ip == "" + !*/ + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for bridging a dlib::pipe object between + two network connected applications. + + + Note also that this object contains a dlib::logger object + which will log various events taking place inside a bridge. + If you want to see these log messages then enable the logger + named "dlib.bridge". + + + BRIDGE PROTOCOL DETAILS + The bridge object creates a single TCP connection between + two applications. Whenever it sends an object from a pipe + over a TCP connection it sends a byte with the value 1 followed + immediately by the serialized copy of the object from the pipe. + The serialization is performed by calling the global serialize() + function. + + Additionally, a bridge object will periodically send bytes with + a value of 0 to ensure the TCP connection remains alive. These + are just read and ignored. + !*/ + + public: + + bridge ( + ); + /*! + ensures + - this object is properly initialized + - #get_bridge_status().is_connected == false + !*/ + + template + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U and V are of type bridge_transmit_decoration or bridge_receive_decoration, + however, U and V must be of different types (i.e. one is a receive type and + another a transmit type). + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe1, pipe2) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + template + bridge ( + T network_parameters, + U pipe + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U is of type bridge_transmit_decoration or bridge_receive_decoration. + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + ~bridge ( + ); + /*! + ensures + - blocks until all resources associated with this object have been destroyed. + !*/ + + void clear ( + ); + /*! + ensures + - returns this object to its default constructed state. That is, it will + be inactive, neither maintaining a connection nor attempting to acquire one. + - Any active connections or listening sockets will be closed. + !*/ + + bridge_status get_bridge_status ( + ) const; + /*! + ensures + - returns the current status of this bridge object. In particular, returns + an object BS such that: + - BS.is_connected == true if and only if the bridge has an active TCP + connection to another computer. + - if (BS.is_connected) then + - BS.foreign_ip == the IP address of the remote host we are connected to. + - BS.foreign_port == the port number on the remote host we are connected to. + - else if (the bridge has previously been connected to a remote host but hasn't been + reconfigured or cleared since) then + - BS.foreign_ip == the IP address of the remote host we were connected to. + - BS.foreign_port == the port number on the remote host we were connected to. + - else + - BS.foreign_ip == "" + - BS.foreign_port == 0 + !*/ + + + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin listening on the port specified by network_parameters + for incoming TCP connections. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop accepting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to listening for a new connection. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + throws + - socket_error + This exception is thrown if we are unable to open the listening socket. + !*/ + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + + + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin making TCP connection attempts to the IP address and port + specified by network_parameters. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop attempting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to attempting to make a TCP connection with the + IP address and port specified by network_parameters. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_ABSTRACT_ + + diff --git a/dlib/bsp.h b/dlib/bsp.h new file mode 100644 index 0000000000000000000000000000000000000000..899b6a40517ae4b79318f231afd0a3eb70f5d1aa --- /dev/null +++ b/dlib/bsp.h @@ -0,0 +1,12 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSPh_ +#define DLIB_BSPh_ + + +#include "bsp/bsp.h" + +#endif // DLIB_BSPh_ + + + diff --git a/dlib/bsp/bsp.cpp b/dlib/bsp/bsp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32e23519e7ed0193932a8ce6d2c5d6ed1cfa1372 --- /dev/null +++ b/dlib/bsp/bsp.cpp @@ -0,0 +1,496 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSP_CPph_ +#define DLIB_BSP_CPph_ + +#include "bsp.h" +#include +#include + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace impl1 + { + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + std::unique_ptr con(new bsp_con(hosts[i])); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + unsigned long id = i+1; + cons.add(id, con); + } + } + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + try + { + std::unique_ptr con(new bsp_con(hosts[i].addr)); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + con->stream.flush(); + unsigned long id = hosts[i].node_id; + cons.add(id, con); + } + catch (std::exception&) + { + std::ostringstream sout; + sout << "Could not connect to " << hosts[i].addr; + error_string = sout.str(); + break; + } + } + } + + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ) + { + // tell everyone their node ids + cons.reset(); + while (cons.move_next()) + { + dlib::serialize(cons.element().key(), cons.element().value()->stream); + } + + // now tell them who to connect to + std::vector targets; + for (unsigned long i = 0; i < hosts.size(); ++i) + { + hostinfo info(hosts[i], i+1); + + dlib::serialize(targets, cons[info.node_id]->stream); + targets.push_back(info); + + // let the other host know how many incoming connections to expect + const unsigned long num = hosts.size()-targets.size(); + dlib::serialize(num, cons[info.node_id]->stream); + cons[info.node_id]->stream.flush(); + } + } + + // ------------------------------------------------------------------------------------ + + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl2 + { + // These control bytes are sent before each message between nodes. Note that many + // of these are only sent between the control node (node 0) and the other nodes. + // This is because the controller node is responsible for handling the + // synchronization that needs to happen when all nodes block on calls to + // receive_data() + // at the same time. + + // denotes a normal content message. + const static char MESSAGE_HEADER = 0; + + // sent to the controller node when someone receives a message via receive_data(). + const static char GOT_MESSAGE = 1; + + // sent to the controller node when someone sends a message via send(). + const static char SENT_MESSAGE = 2; + + // sent to the controller node when someone enters a call to receive_data() + const static char IN_WAITING_STATE = 3; + + // broadcast when a node terminates itself. + const static char NODE_TERMINATE = 5; + + // broadcast by the controller node when it determines that all nodes are blocked + // on calls to receive_data() and there aren't any messages in flight. This is also + // what makes us go to the next epoch. + const static char SEE_ALL_IN_WAITING_STATE = 6; + + // This isn't ever transmitted between nodes. It is used internally to indicate + // that an error occurred. + const static char READ_ERROR = 7; + + // ------------------------------------------------------------------------------------ + + void read_thread ( + impl1::bsp_con* con, + unsigned long node_id, + unsigned long sender_id, + impl1::thread_safe_message_queue& msg_buffer + ) + { + try + { + while(true) + { + impl1::msg_data msg; + deserialize(msg.msg_type, con->stream); + msg.sender_id = sender_id; + + if (msg.msg_type == MESSAGE_HEADER) + { + msg.data.reset(new std::vector); + deserialize(msg.epoch, con->stream); + deserialize(*msg.data, con->stream); + } + + msg_buffer.push_and_consume(msg); + + if (msg.msg_type == NODE_TERMINATE) + break; + } + } + catch (std::exception& e) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + sout << " Error message in the exception: " << e.what() << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + catch (...) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + } + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION OF bsp_context OBJECT MEMBERS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + close_all_connections_gracefully( + ) + { + if (node_id() != 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + } + + impl1::msg_data msg; + // now wait for all the other nodes to terminate + while (num_terminated_nodes < _cons.size() ) + { + if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + ++current_epoch; + } + + if (!msg_buffer.pop(msg)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + if (msg.msg_type == impl2::NODE_TERMINATE) + { + ++num_terminated_nodes; + _cons[msg.sender_id]->terminated = true; + } + else if (msg.msg_type == impl2::READ_ERROR) + { + throw dlib::socket_error(msg.data_to_string()); + } + else if (msg.msg_type == impl2::MESSAGE_HEADER) + { + throw dlib::socket_error("A BSP node received a message after it has terminated."); + } + else if (msg.msg_type == impl2::GOT_MESSAGE) + { + --num_waiting_nodes; + --outstanding_messages; + } + else if (msg.msg_type == impl2::SENT_MESSAGE) + { + ++outstanding_messages; + } + else if (msg.msg_type == impl2::IN_WAITING_STATE) + { + ++num_waiting_nodes; + } + } + + if (node_id() == 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + + if (outstanding_messages != 0) + { + std::ostringstream sout; + sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; + sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; + sout << "have a corresponding call to receive()."; + throw dlib::socket_error(sout.str()); + } + } + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + ~bsp_context() + { + _cons.reset(); + while (_cons.move_next()) + { + _cons.element().value()->con->shutdown(); + } + + msg_buffer.disable(); + + // this will wait for all the threads to terminate + threads.clear(); + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ) : + outstanding_messages(0), + num_waiting_nodes(0), + num_terminated_nodes(0), + current_epoch(1), + _cons(cons_), + _node_id(node_id_) + { + // spawn a bunch of read threads, one for each connection + _cons.reset(); + while (_cons.move_next()) + { + std::unique_ptr ptr(new thread_function(&impl2::read_thread, + _cons.element().value().get(), + _node_id, + _cons.element().key(), + ref(msg_buffer))); + threads.push_back(ptr); + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bsp_context:: + receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ) + { + notify_control_node(impl2::IN_WAITING_STATE); + + while (true) + { + // If there aren't any nodes left to give us messages then return right now. + // We need to check the msg_buffer size to make sure there aren't any + // unprocessed message there. Recall that this can happen because status + // messages always jump to the front of the message buffer. So we might have + // learned about the node terminations before processing their messages for us. + if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) + { + return false; + } + + // if all running nodes are currently blocking forever on receive_data() + if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + + // Note that the reason we have this epoch counter is so we can tell if a + // sent message is from before or after one of these "all nodes waiting" + // synchronization events. If we didn't have the epoch count we would have + // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE + // message before others and then sends out a message to another node + // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that + // node would think the normal message came before SEE_ALL_IN_WAITING_STATE + // which would be bad. + ++current_epoch; + return false; + } + + impl1::msg_data data; + if (!msg_buffer.pop(data, current_epoch)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + + switch(data.msg_type) + { + case impl2::MESSAGE_HEADER: { + item = data.data; + sending_node_id = data.sender_id; + notify_control_node(impl2::GOT_MESSAGE); + return true; + } break; + + case impl2::IN_WAITING_STATE: { + ++num_waiting_nodes; + } break; + + case impl2::GOT_MESSAGE: { + --outstanding_messages; + --num_waiting_nodes; + } break; + + case impl2::SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case impl2::NODE_TERMINATE: { + ++num_terminated_nodes; + _cons[data.sender_id]->terminated = true; + } break; + + case impl2::SEE_ALL_IN_WAITING_STATE: { + ++current_epoch; + return false; + } break; + + case impl2::READ_ERROR: { + throw dlib::socket_error(data.data_to_string()); + } break; + + default: { + throw dlib::socket_error("Unknown message received by dlib::bsp_context"); + } break; + } // end switch() + } // end while (true) + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + notify_control_node ( + char val + ) + { + if (node_id() == 0) + { + using namespace impl2; + switch(val) + { + case SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case GOT_MESSAGE: { + --outstanding_messages; + } break; + + case IN_WAITING_STATE: { + // nothing to do in this case + } break; + + default: + DLIB_CASSERT(false,"This should never happen"); + } + } + else + { + serialize(val, _cons[0]->stream); + _cons[0]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + broadcast_byte ( + char val + ) + { + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // don't send to yourself or to terminated nodes + if (i == node_id() || _cons[i]->terminated) + continue; + + serialize(val, _cons[i]->stream); + _cons[i]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + send_data( + const std::vector& item, + unsigned long target_node_id + ) + { + using namespace impl2; + if (_cons[target_node_id]->terminated) + throw socket_error("Attempt to send a message to a node that has terminated."); + + serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); + serialize(current_epoch, _cons[target_node_id]->stream); + serialize(item, _cons[target_node_id]->stream); + _cons[target_node_id]->stream.flush(); + + notify_control_node(SENT_MESSAGE); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BSP_CPph_ + diff --git a/dlib/bsp/bsp.h b/dlib/bsp/bsp.h new file mode 100644 index 0000000000000000000000000000000000000000..f0732c15380ca413850625887b3c8c21708261a4 --- /dev/null +++ b/dlib/bsp/bsp.h @@ -0,0 +1,1043 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BsP_Hh_ +#define DLIB_BsP_Hh_ + +#include "bsp_abstract.h" + +#include +#include +#include + +#include "../sockets.h" +#include "../array.h" +#include "../sockstreambuf.h" +#include "../string.h" +#include "../serialize.h" +#include "../map.h" +#include "../ref.h" +#include "../vectorstream.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl1 + { + inline void null_notify( + unsigned short + ) {} + + struct bsp_con + { + bsp_con( + const network_address& dest + ) : + con(connect(dest)), + buf(con), + stream(&buf), + terminated(false) + { + con->disable_nagle(); + } + + bsp_con( + std::unique_ptr& conptr + ) : + buf(conptr), + stream(&buf), + terminated(false) + { + // make sure we own the connection + conptr.swap(con); + + con->disable_nagle(); + } + + std::unique_ptr con; + sockstreambuf buf; + std::iostream stream; + bool terminated; + }; + + typedef dlib::map >::kernel_1a_c map_id_to_con; + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ); + /*! + ensures + - creates connections to all the given hosts and stores them into cons + !*/ + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ); + + // ------------------------------------------------------------------------------------ + + struct hostinfo + { + hostinfo() {} + hostinfo ( + const network_address& addr_, + unsigned long node_id_ + ) : + addr(addr_), + node_id(node_id_) + { + } + + network_address addr; + unsigned long node_id; + }; + + inline void serialize ( + const hostinfo& item, + std::ostream& out + ) + { + dlib::serialize(item.addr, out); + dlib::serialize(item.node_id, out); + } + + inline void deserialize ( + hostinfo& item, + std::istream& in + ) + { + dlib::deserialize(item.addr, in); + dlib::deserialize(item.node_id, in); + } + + // ------------------------------------------------------------------------------------ + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ); + + // ------------------------------------------------------------------------------------ + + template < + typename port_notify_function_type + > + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port, + port_notify_function_type port_notify_function + ) + { + cons.clear(); + std::unique_ptr list; + const int status = create_listener(list, port); + if (status == PORTINUSE) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) + + ". The port is already in use"); + } + else if (status != 0) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) ); + } + + port_notify_function(list->get_listening_port()); + + std::unique_ptr con; + if (list->accept(con)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + std::unique_ptr temp(new bsp_con(con)); + + unsigned long remote_node_id; + dlib::deserialize(remote_node_id, temp->stream); + dlib::deserialize(node_id, temp->stream); + std::vector targets; + dlib::deserialize(targets, temp->stream); + unsigned long num_incoming_connections; + dlib::deserialize(num_incoming_connections, temp->stream); + + cons.add(remote_node_id,temp); + + // make a thread that will connect to all the targets + map_id_to_con cons2; + std::string error_string; + thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); + if (error_string.size() != 0) + throw socket_error(error_string); + + // accept any incoming connections + for (unsigned long i = 0; i < num_incoming_connections; ++i) + { + // If it takes more than 10 seconds for the other nodes to connect to us + // then something has gone horribly wrong and it almost certainly will + // never connect at all. So just give up if that happens. + const unsigned long timeout_milliseconds = 10000; + if (list->accept(con, timeout_milliseconds)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + temp.reset(new bsp_con(con)); + + dlib::deserialize(remote_node_id, temp->stream); + cons.add(remote_node_id,temp); + } + + + // put all the connections created by the thread into cons + thread.wait(); + while (cons2.size() > 0) + { + unsigned long id; + std::unique_ptr temp; + cons2.remove_any(id,temp); + cons.add(id,temp); + } + } + + // ------------------------------------------------------------------------------------ + + struct msg_data + { + std::shared_ptr > data; + unsigned long sender_id; + char msg_type; + dlib::uint64 epoch; + + msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} + + std::string data_to_string() const + { + if (data && data->size() != 0) + return std::string(&(*data)[0], data->size()); + else + return ""; + } + }; + + // ------------------------------------------------------------------------------------ + + class thread_safe_message_queue : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple message queue for msg_data objects. Note that it + has the special property that, while messages will generally leave + the queue in the order they are inserted, any message with a smaller + epoch value will always be popped out first. But for all messages + with equal epoch values the queue functions as a normal FIFO queue. + !*/ + private: + struct msg_wrap + { + msg_wrap( + const msg_data& data_, + const dlib::uint64& sequence_number_ + ) : data(data_), sequence_number(sequence_number_) {} + + msg_wrap() : sequence_number(0){} + + msg_data data; + dlib::uint64 sequence_number; + + // Make it so that when msg_wrap objects are in a std::priority_queue, + // messages with a smaller epoch number always come first. Then, within an + // epoch, messages are ordered by their sequence number (so smaller first + // there as well). + bool operator<(const msg_wrap& item) const + { + if (data.epoch < item.data.epoch) + { + return false; + } + else if (data.epoch > item.data.epoch) + { + return true; + } + else + { + if (sequence_number < item.sequence_number) + return false; + else + return true; + } + } + }; + + public: + thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} + + ~thread_safe_message_queue() + { + disable(); + } + + void disable() + { + auto_mutex lock(class_mutex); + disabled = true; + sig.broadcast(); + } + + unsigned long size() const + { + auto_mutex lock(class_mutex); + return data.size(); + } + + void push_and_consume( msg_data& item) + { + auto_mutex lock(class_mutex); + data.push(msg_wrap(item, next_seq_num++)); + // do this here so that we don't have to worry about different threads touching the shared_ptr. + item.data.reset(); + sig.signal(); + } + + bool pop ( + msg_data& item + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while (data.size() == 0 && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + bool pop ( + msg_data& item, + const dlib::uint64& max_epoch + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue that has an epoch <= max_epoch + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + private: + std::priority_queue data; + dlib::mutex class_mutex; + dlib::signaler sig; + bool disabled; + dlib::uint64 next_seq_num; + }; + + + } + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(target_node_id < number_of_nodes() && + target_node_id != node_id(), + "\t void bsp_context::send()" + << "\n\t Invalid arguments were given to this function." + << "\n\t target_node_id: " << target_node_id + << "\n\t node_id(): " << node_id() + << "\n\t number_of_nodes(): " << number_of_nodes() + << "\n\t this: " << this + ); + + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + send_data(buf, target_node_id); + } + + template + void broadcast ( + const T& item + ) + { + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // Don't send to yourself. + if (i == node_id()) + continue; + + send_data(buf, i); + } + } + + unsigned long node_id ( + ) const { return _node_id; } + + unsigned long number_of_nodes ( + ) const { return _cons.size()+1; } + + void receive ( + ) + { + unsigned long id; + std::shared_ptr > temp; + if (receive_data(temp,id)) + throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); + } + + template + void receive ( + T& item + ) + { + if(!try_receive(item)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item + ) + { + unsigned long sending_node_id; + return try_receive(item, sending_node_id); + } + + template + void receive ( + T& item, + unsigned long& sending_node_id + ) + { + if(!try_receive(item, sending_node_id)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ) + { + std::shared_ptr > temp; + if (receive_data(temp, sending_node_id)) + { + vectorstream sin(*temp); + deserialize(item, sin); + if (sin.peek() != EOF) + throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " + "This probably means you are calling a receive method with a different type " + "of object than the one which was sent."); + return true; + } + else + { + return false; + } + } + + ~bsp_context(); + + private: + + bsp_context(); + + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ); + + void close_all_connections_gracefully(); + /*! + ensures + - closes all the connections to other nodes and lets them know that + we are terminating normally rather than as the result of some kind + of error. + !*/ + + bool receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ); + + + void notify_control_node ( + char val + ); + + void broadcast_byte ( + char val + ); + + void send_data( + const std::vector& item, + unsigned long target_node_id + ); + /*! + requires + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + + + + + unsigned long outstanding_messages; + unsigned long num_waiting_nodes; + unsigned long num_terminated_nodes; + dlib::uint64 current_epoch; + + impl1::thread_safe_message_queue msg_buffer; + + impl1::map_id_to_con& _cons; + const unsigned long _node_id; + array > threads; + + // ----------------------------------- + + template < + typename funct_type + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + + template < + typename funct_type, + typename ARG1 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bsp.cpp" +#endif + +#endif // DLIB_BsP_Hh_ + diff --git a/dlib/bsp/bsp_abstract.h b/dlib/bsp/bsp_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b87f3a0c3478ecb6c25f651e99d528148b37c4ee --- /dev/null +++ b/dlib/bsp/bsp_abstract.h @@ -0,0 +1,912 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BsP_ABSTRACT_Hh_ +#ifdef DLIB_BsP_ABSTRACT_Hh_ + +#include "../noncopyable.h" +#include "../sockets/sockets_extensions_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool used to implement algorithms using the Bulk Synchronous + Parallel (BSP) computing model. A BSP algorithm is composed of a number of + processing nodes, each executing in parallel. The general flow of + execution in each processing node is the following: + 1. Do work locally on some data. + 2. Send some messages to other nodes. + 3. Receive messages from other nodes. + 4. Go to step 1 or terminate if complete. + + To do this, each processing node needs an API used to send and receive + messages. This API is implemented by the bsp_connect object which provides + these services to a BSP node. + + Note that BSP processing nodes are spawned using the bsp_connect() and + bsp_listen() routines defined at the bottom of this file. For example, to + start a BSP algorithm consisting of N processing nodes, you would make N-1 + calls to bsp_listen() and one call to bsp_connect(). The call to + bsp_connect() then initiates the computation on all nodes. + + Finally, note that there is no explicit barrier synchronization function + you call at the end of step 3. Instead, you can simply call a method such + as try_receive() until it returns false. That is, the bsp_context's + receive methods incorporate a barrier synchronization that happens once all + the BSP nodes are blocked on receive calls and there are no more messages + in flight. + + + THREAD SAFETY + This object is not thread-safe. In particular, you should only ever have + one thread that works with an instance of this object. This means that, + for example, you should not spawn sub-threads from within a BSP processing + node and have them invoke methods on this object. Instead, you should only + invoke this object's methods from within the BSP processing node's main + thread (i.e. the thread that executes the user supplied function funct()). + !*/ + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ); + /*! + requires + - item is serializable + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + throws + - dlib::socket_error: + This exception is thrown if there is an error which prevents us from + delivering the message to the given node. One way this might happen is + if the target node has already terminated its execution or has lost + network connectivity. + !*/ + + template + void broadcast ( + const T& item + ); + /*! + ensures + - item is serializable + - sends a copy of item to all other processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents us from + delivering a message to one of the other nodes. This might happen, for + example, if one of the nodes has terminated its execution or has lost + network connectivity. + !*/ + + unsigned long node_id ( + ) const; + /*! + ensures + - Returns the id of the current processing node. That is, + returns a number N such that: + - N < number_of_nodes() + - N == the node id of the processing node that called node_id(). This + is a number that uniquely identifies the processing node. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of processing nodes participating in the BSP + computation. + !*/ + + template + bool try_receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing + node. + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + void receive ( + ); + /*! + ensures + - Waits for the following to all be true: + - All other nodes were blocked on calls to receive(), try_receive(), or + have terminated. + - There are not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive methods + then they all would have blocked forever. Therefore, this function + only returns once there are no more messages to process by any node + and there is no possibility of more being generated until control is + returned to the callers of receive methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all unblock + at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if a message is received + before this function would otherwise return. + + !*/ + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it + will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BsP_ABSTRACT_Hh_ + diff --git a/dlib/byte_orderer.h b/dlib/byte_orderer.h new file mode 100644 index 0000000000000000000000000000000000000000..bc8f6108da769eac99aea17cc51b48aa98bdaf51 --- /dev/null +++ b/dlib/byte_orderer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_ +#define DLIB_BYTE_ORDEREr_ + + +#include "byte_orderer/byte_orderer_kernel_1.h" + +#endif // DLIB_BYTE_ORDEREr_ + diff --git a/dlib/byte_orderer/byte_orderer_kernel_1.h b/dlib/byte_orderer/byte_orderer_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8e8342f4ecab9078db51cf0903e2143adef9cc --- /dev/null +++ b/dlib/byte_orderer/byte_orderer_kernel_1.h @@ -0,0 +1,176 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_KERNEL_1_ +#define DLIB_BYTE_ORDEREr_KERNEL_1_ + +#include "byte_orderer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + CONVENTION + - host_is_big_endian() == !little_endian + - host_is_little_endian() == little_endian + + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + + !*/ + + + public: + + // this is here for backwards compatibility with older versions of dlib. + typedef byte_orderer kernel_1a; + + byte_orderer ( + ) + { + // This will probably never be false but if it is then it means chars are not 8bits + // on this system. Which is a problem for this object. + COMPILE_TIME_ASSERT(sizeof(short) >= 2); + + unsigned long temp = 1; + unsigned char* ptr = reinterpret_cast(&temp); + if (*ptr == 1) + little_endian = true; + else + little_endian = false; + } + + virtual ~byte_orderer ( + ){} + + bool host_is_big_endian ( + ) const { return !little_endian; } + + bool host_is_little_endian ( + ) const { return little_endian; } + + template < + typename T + > + inline void host_to_network ( + T& item + ) const + { if (little_endian) flip(item); } + + template < + typename T + > + inline void network_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_big ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void big_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_little ( + T& item + ) const { if (!little_endian) flip(item); } + + template < + typename T + > + void little_to_host ( + T& item + ) const { if (!little_endian) flip(item); } + + + private: + + template < + typename T, + size_t size + > + inline void flip ( + T (&array)[size] + ) const + /*! + ensures + - flips the bytes in every element of this array + !*/ + { + for (size_t i = 0; i < size; ++i) + { + flip(array[i]); + } + } + + template < + typename T + > + inline void flip ( + T& item + ) const + /*! + ensures + - reverses the byte ordering in item + !*/ + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + + T value; + + // If you are getting this as an error then you are probably using + // this object wrong. If you think you aren't then send me (Davis) an + // email and I'll either set you straight or change/remove this check so + // your stuff works :) + COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double)); + + // If you are getting a compile error on this line then it means T is + // a pointer type. It doesn't make any sense to byte swap pointers + // since they have no meaning outside the context of their own process. + // So you probably just forgot to dereference that pointer before passing + // it to this function :) + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + + const size_t size = sizeof(T); + unsigned char* const ptr = reinterpret_cast(&item); + unsigned char* const ptr_temp = reinterpret_cast(&value); + for (size_t i = 0; i < size; ++i) + ptr_temp[size-i-1] = ptr[i]; + + item = value; + } + + bool little_endian; + }; + + // make flip not do anything at all for chars + template <> inline void byte_orderer::flip ( char& ) const {} + template <> inline void byte_orderer::flip ( unsigned char& ) const {} + template <> inline void byte_orderer::flip ( signed char& ) const {} +} + +#endif // DLIB_BYTE_ORDEREr_KERNEL_1_ + diff --git a/dlib/byte_orderer/byte_orderer_kernel_abstract.h b/dlib/byte_orderer/byte_orderer_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..f7ea151035f95cfb7b2114f015568e0501458904 --- /dev/null +++ b/dlib/byte_orderer/byte_orderer_kernel_abstract.h @@ -0,0 +1,149 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BYTE_ORDEREr_ABSTRACT_ +#ifdef DLIB_BYTE_ORDEREr_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + This object has no state. + + WHAT THIS OBJECT REPRESENTS + This object simply provides a mechanism to convert data from a + host machine's own byte ordering to big or little endian and to + also do the reverse. + + It also provides a pair of functions to convert to/from network byte + order where network byte order is big endian byte order. This pair of + functions does the exact same thing as the host_to_big() and big_to_host() + functions and is provided simply so that client code can use the most + self documenting name appropriate. + + Also note that this object is capable of correctly flipping the contents + of arrays when the arrays are declared on the stack. e.g. You can + say things like: + int array[10]; + bo.host_to_network(array); + !*/ + + public: + + byte_orderer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~byte_orderer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bool host_is_big_endian ( + ) const; + /*! + ensures + - if (the host computer is a big endian machine) then + - returns true + - else + - returns false + !*/ + + bool host_is_little_endian ( + ) const; + /*! + ensures + - if (the host computer is a little endian machine) then + - returns true + - else + - returns false + !*/ + + template < + typename T + > + void host_to_network ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to network byte order. + !*/ + + template < + typename T + > + void network_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from network byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_big ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to big endian byte order. + !*/ + + template < + typename T + > + void big_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from big endian byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_little ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to little endian byte order. + !*/ + + template < + typename T + > + void little_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from little endian byte order + to host byte order. + !*/ + + }; +} + +#endif // DLIB_BYTE_ORDEREr_ABSTRACT_ + diff --git a/dlib/cassert b/dlib/cassert new file mode 100644 index 0000000000000000000000000000000000000000..eb0e59e4176001d73c2070b3dd7cb2a6f9677eef --- /dev/null +++ b/dlib/cassert @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/dlib/clustering.h b/dlib/clustering.h new file mode 100644 index 0000000000000000000000000000000000000000..3cbd6cfd431171ae275561813abbc84d73ad2937 --- /dev/null +++ b/dlib/clustering.h @@ -0,0 +1,13 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CLuSTERING_ +#define DLIB_CLuSTERING_ + +#include "clustering/modularity_clustering.h" +#include "clustering/chinese_whispers.h" +#include "clustering/spectral_cluster.h" +#include "clustering/bottom_up_cluster.h" +#include "svm/kkmeans.h" + +#endif // DLIB_CLuSTERING_ + diff --git a/dlib/clustering/bottom_up_cluster.h b/dlib/clustering/bottom_up_cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..f80b651087dcda0d3d6e050208142e84c7fd1bc1 --- /dev/null +++ b/dlib/clustering/bottom_up_cluster.h @@ -0,0 +1,253 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ +#define DLIB_BOTTOM_uP_CLUSTER_Hh_ + +#include +#include + +#include "bottom_up_cluster_abstract.h" +#include "../algs.h" +#include "../matrix.h" +#include "../disjoint_subsets.h" +#include "../graph_utils.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace buc_impl + { + inline void merge_sets ( + matrix& dists, + unsigned long dest, + unsigned long src + ) + { + for (long r = 0; r < dists.nr(); ++r) + dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src)); + } + + struct compare_dist + { + bool operator() ( + const sample_pair& a, + const sample_pair& b + ) const + { + return a.distance() > b.distance(); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists_, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ) + { + matrix dists = matrix_cast(dists_); + // make sure requires clause is not broken + DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, + "\t unsigned long bottom_up_cluster()" + << "\n\t Invalid inputs were given to this function." + << "\n\t dists.nr(): " << dists.nr() + << "\n\t dists.nc(): " << dists.nc() + << "\n\t min_num_clusters: " << min_num_clusters + ); + + using namespace buc_impl; + + labels.resize(dists.nr()); + disjoint_subsets sets; + sets.set_size(dists.nr()); + if (labels.size() == 0) + return 0; + + // push all the edges in the graph into a priority queue so the best edges to merge + // come first. + std::priority_queue, compare_dist> que; + for (long r = 0; r < dists.nr(); ++r) + for (long c = r+1; c < dists.nc(); ++c) + que.push(sample_pair(r,c,dists(r,c))); + + // Now start merging nodes. + for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) + { + // find the next best thing to merge. + double best_dist = que.top().distance(); + unsigned long a = sets.find_set(que.top().index1()); + unsigned long b = sets.find_set(que.top().index2()); + que.pop(); + // we have been merging and modifying the distances, so make sure this distance + // is still valid and these guys haven't been merged already. + while(a == b || best_dist < dists(a,b)) + { + // Haven't merged it yet, so put it back in with updated distance for + // reconsideration later. + if (a != b) + que.push(sample_pair(a, b, dists(a, b))); + + best_dist = que.top().distance(); + a = sets.find_set(que.top().index1()); + b = sets.find_set(que.top().index2()); + que.pop(); + } + + + // now merge these sets if the best distance is small enough + if (best_dist > max_dist) + break; + unsigned long news = sets.merge_sets(a,b); + unsigned long olds = (news==a)?b:a; + merge_sets(dists, news, olds); + } + + // figure out which cluster each element is in. Also make sure the labels are + // contiguous. + std::map relabel; + for (unsigned long r = 0; r < labels.size(); ++r) + { + unsigned long l = sets.find_set(r); + // relabel to make contiguous + if (relabel.count(l) == 0) + { + unsigned long next = relabel.size(); + relabel[l] = next; + } + labels[r] = relabel[l]; + } + + + return relabel.size(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + snl_range() = default; + snl_range(double val) : lower(val), upper(val) {} + snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)} + + double lower = 0; + double upper = 0; + + double width() const { return upper-lower; } + bool operator<(const snl_range& item) const { return lower < item.lower; } + }; + + inline snl_range merge(const snl_range& a, const snl_range& b) + { + return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper)); + } + + inline double distance (const snl_range& a, const snl_range& b) + { + return std::max(a.lower,b.lower) - std::min(a.upper,b.upper); + } + + inline std::ostream& operator<< (std::ostream& out, const snl_range& item ) + { + out << "["< segment_number_line ( + const std::vector& x, + const double max_range_width + ) + { + DLIB_CASSERT(max_range_width >= 0); + + // create initial ranges, one for each value in x. So initially, all the ranges have + // width of 0. + std::vector ranges; + for (auto v : x) + ranges.push_back(v); + std::sort(ranges.begin(), ranges.end()); + + std::vector greedy_final_ranges; + if (ranges.size() == 0) + return greedy_final_ranges; + // We will try two different clustering strategies. One that does a simple greedy left + // to right sweep and another that does a bottom up agglomerative clustering. This + // first loop runs the greedy left to right sweep. Then at the end of this routine we + // will return the results that produced the tightest clustering. + greedy_final_ranges.push_back(ranges[0]); + for (size_t i = 1; i < ranges.size(); ++i) + { + auto m = merge(greedy_final_ranges.back(), ranges[i]); + if (m.width() <= max_range_width) + greedy_final_ranges.back() = m; + else + greedy_final_ranges.push_back(ranges[i]); + } + + + // Here we do the bottom up clustering. So compute the edges connecting our ranges. + // We will simply say there are edges between ranges if and only if they are + // immediately adjacent on the number line. + std::vector edges; + for (size_t i = 1; i < ranges.size(); ++i) + edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i]))); + std::sort(edges.begin(), edges.end(), order_by_distance); + + disjoint_subsets sets; + sets.set_size(ranges.size()); + + // Now start merging nodes. + for (auto edge : edges) + { + // find the next best thing to merge. + unsigned long a = sets.find_set(edge.index1()); + unsigned long b = sets.find_set(edge.index2()); + + // merge it if it doesn't result in an interval that's too big. + auto m = merge(ranges[a], ranges[b]); + if (m.width() <= max_range_width) + { + unsigned long news = sets.merge_sets(a,b); + ranges[news] = m; + } + } + + // Now create a list of the final ranges. We will do this by keeping track of which + // range we already added to final_ranges. + std::vector final_ranges; + std::vector already_output(ranges.size(), false); + for (unsigned long i = 0; i < sets.size(); ++i) + { + auto s = sets.find_set(i); + if (!already_output[s]) + { + final_ranges.push_back(ranges[s]); + already_output[s] = true; + } + } + + // only use the greedy clusters if they found a clustering with fewer clusters. + // Otherwise, the bottom up clustering probably produced a more sensible clustering. + if (final_ranges.size() <= greedy_final_ranges.size()) + return final_ranges; + else + return greedy_final_ranges; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_ + diff --git a/dlib/clustering/bottom_up_cluster_abstract.h b/dlib/clustering/bottom_up_cluster_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..72d362c12574ec3e1ee457232cea481551b600a5 --- /dev/null +++ b/dlib/clustering/bottom_up_cluster_abstract.h @@ -0,0 +1,136 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ +#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ); + /*! + requires + - dists.nr() == dists.nc() + - min_num_clusters > 0 + - dists == trans(dists) + (l.e. dists should be symmetric) + ensures + - Runs a bottom up agglomerative clustering algorithm. + - Interprets dists as a matrix that gives the distances between dists.nr() + items. In particular, we take dists(i,j) to be the distance between the ith + and jth element of some set. This function clusters the elements of this set + into at least min_num_clusters (or dists.nr() if there aren't enough + elements). Additionally, within each cluster, the maximum pairwise distance + between any two cluster elements is <= max_dist. + - returns the number of clusters found. + - #labels.size() == dists.nr() + - for all valid i: + - #labels[i] == the cluster ID of the node with index i (i.e. the node + corresponding to the distances dists(i,*)). + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an interval on the real number line. It is used + to store the outputs of the segment_number_line() routine defined below. + !*/ + + snl_range( + ); + /*! + ensures + - #lower == 0 + - #upper == 0 + !*/ + + snl_range( + double val + ); + /*! + ensures + - #lower == val + - #upper == val + !*/ + + snl_range( + double l, + double u + ); + /*! + requires + - l <= u + ensures + - #lower == l + - #upper == u + !*/ + + double lower; + double upper; + + double width( + ) const { return upper-lower; } + /*! + ensures + - returns the width of this interval on the number line. + !*/ + + bool operator<(const snl_range& item) const { return lower < item.lower; } + /*! + ensures + - provides a total ordering of snl_range objects assuming they are + non-overlapping. + !*/ + }; + + std::ostream& operator<< (std::ostream& out, const snl_range& item ); + /*! + ensures + - prints item to out in the form [lower,upper]. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector segment_number_line ( + const std::vector& x, + const double max_range_width + ); + /*! + requires + - max_range_width >= 0 + ensures + - Finds a clustering of the values in x and returns the ranges that define the + clustering. This routine uses a combination of bottom up clustering and a + simple greedy scan to try and find the most compact set of ranges that + contain all the values in x. + - This routine has approximately linear runtime. + - Every value in x will be contained inside one of the returned snl_range + objects; + - All returned snl_range object's will have a width() <= max_range_width and + will also be non-overlapping. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + + diff --git a/dlib/clustering/chinese_whispers.h b/dlib/clustering/chinese_whispers.h new file mode 100644 index 0000000000000000000000000000000000000000..332cce1a08c046d1e4d5f59d02c06392ddcd6884 --- /dev/null +++ b/dlib/clustering/chinese_whispers.h @@ -0,0 +1,135 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CHINESE_WHISPErS_Hh_ +#define DLIB_CHINESE_WHISPErS_Hh_ + +#include "chinese_whispers_abstract.h" +#include +#include "../rand.h" +#include "../graph_utils/edge_list_graphs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long chinese_whispers()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + std::vector > neighbors; + find_neighbor_ranges(edges, neighbors); + + // Initialize the labels, each node gets a different label. + labels.resize(neighbors.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + labels[i] = i; + + + for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter) + { + // Pick a random node. + const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size(); + + // Count how many times each label happens amongst our neighbors. + std::map labels_to_counts; + const unsigned long end = neighbors[idx].second; + for (unsigned long i = neighbors[idx].first; i != end; ++i) + { + labels_to_counts[labels[edges[i].index2()]] += edges[i].distance(); + } + + // find the most common label + std::map::iterator i; + double best_score = -std::numeric_limits::infinity(); + unsigned long best_label = labels[idx]; + for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i) + { + if (i->second > best_score) + { + best_score = i->second; + best_label = i->first; + } + } + + labels[idx] = best_label; + } + + + // Remap the labels into a contiguous range. First we find the + // mapping. + std::map label_remap; + for (unsigned long i = 0; i < labels.size(); ++i) + { + const unsigned long next_id = label_remap.size(); + if (label_remap.count(labels[i]) == 0) + label_remap[labels[i]] = next_id; + } + // now apply the mapping to all the labels. + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i] = label_remap[labels[i]]; + } + + return label_remap.size(); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return chinese_whispers(oedges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_Hh_ + diff --git a/dlib/clustering/chinese_whispers_abstract.h b/dlib/clustering/chinese_whispers_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..7a184c6f9449a38a6ac7140c325e5f3264f80551 --- /dev/null +++ b/dlib/clustering/chinese_whispers_abstract.h @@ -0,0 +1,97 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ +#ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + +#include +#include "../rand.h" +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - This function implements the graph clustering algorithm described in the + paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its + Application to Natural Language Processing Problems by Chris Biemann. + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). If an edge has a distance() value of infinity then it is considered + a "must link" edge. + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - The algorithm performs exactly num_iterations passes over the graph before + terminating. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + ensures + - This function is identical to the above chinese_whispers() routine except + that it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the chinese_whispers() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + diff --git a/dlib/clustering/modularity_clustering.h b/dlib/clustering/modularity_clustering.h new file mode 100644 index 0000000000000000000000000000000000000000..8b8a0b0a58b75b307c5a4f158cac4605ef0c299d --- /dev/null +++ b/dlib/clustering/modularity_clustering.h @@ -0,0 +1,515 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MODULARITY_ClUSTERING__H__ +#define DLIB_MODULARITY_ClUSTERING__H__ + +#include "modularity_clustering_abstract.h" +#include "../sparse_vector.h" +#include "../graph_utils/edge_list_graphs.h" +#include "../matrix.h" +#include "../rand.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + namespace impl + { + inline double newman_cluster_split ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + matrix& labels, + const double eps, + const unsigned long max_iterations + ) + /*! + requires + - node_degrees.size() == max_index_plus_one(edges) + - Bdiag.size() == max_index_plus_one(edges) + - edges must be sorted according to order_by_index() + ensures + - This routine splits a graph into two subgraphs using the Newman + clustering method. + - returns the modularity obtained when the graph is split according + to the contents of #labels. + - #labels.size() == node_degrees.size() + - for all valid i: #labels(i) == -1 or +1 + - if (this function returns 0) then + - all the labels are equal, i.e. the graph is not split. + !*/ + { + // Scale epsilon so that it is relative to the expected value of an element of a + // unit vector of length node_degrees.size(). + const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size()); + + // Make a random unit vector and put in labels. + labels.set_size(node_degrees.size()); + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + labels /= length(labels); + + matrix Bv, Bv_unit; + + // Do the power iteration for a while. + double eig = -1; + double offset = 0; + while (eig < 0) + { + + // any number larger than power_iter_eps + double iteration_change = power_iter_eps*2+1; + for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i) + { + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + if (offset != 0) + { + Bv -= offset*labels; + } + + + const double len = length(Bv); + if (len != 0) + { + Bv_unit = Bv/len; + iteration_change = max(abs(labels-Bv_unit)); + labels.swap(Bv_unit); + } + else + { + // Had a bad time, pick another random vector and try it with the + // power iteration. + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + } + } + + eig = dot(Bv,labels); + // we will repeat this loop if the largest eigenvalue is negative + offset = eig; + } + + + for (long i = 0; i < labels.size(); ++i) + { + if (labels(i) > 0) + labels(i) = 1; + else + labels(i) = -1; + } + + + // compute B*labels, store result in Bv. + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + // Do some label refinement. In this step we swap labels if it + // improves the modularity score. + bool flipped_label = true; + while(flipped_label) + { + flipped_label = false; + unsigned long idx = 0; + for (long i = 0; i < labels.size(); ++i) + { + const double val = -2*labels(i); + const double increase = 4*Bdiag(i) + 2*val*Bv(i); + + // if there is an increase in modularity for swapping this label + if (increase > 0) + { + labels(i) *= -1; + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + const long j = edges[idx].index2(); + Bv(j) += val*edges[idx].distance(); + ++idx; + } + + Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees; + + flipped_label = true; + } + else + { + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + ++idx; + } + } + } + } + + + const double modularity = dot(Bv, labels)/(4*edge_sum); + + return modularity; + } + + // ------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster_helper ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + std::vector& labels, + double modularity_threshold, + const double eps, + const unsigned long max_iterations + ) + /*! + ensures + - returns the number of clusters the data was split into + !*/ + { + matrix l; + const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations); + + + // We need to collapse the node index values down to contiguous values. So + // we use the following two vectors to contain the mappings from input index + // values to their corresponding index values in each split. + std::vector left_idx_map(node_degrees.size()); + std::vector right_idx_map(node_degrees.size()); + + // figure out how many nodes went into each side of the split. + unsigned long num_left_split = 0; + unsigned long num_right_split = 0; + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_idx_map[i] = num_left_split; + ++num_left_split; + } + else + { + right_idx_map[i] = num_right_split; + ++num_right_split; + } + } + + // do a recursive split if it will improve the modularity. + if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0) + { + + // split the node_degrees and Bdiag matrices into left and right split parts + matrix left_node_degrees(num_left_split); + matrix right_node_degrees(num_right_split); + matrix left_Bdiag(num_left_split); + matrix right_Bdiag(num_right_split); + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_node_degrees(left_idx_map[i]) = node_degrees(i); + left_Bdiag(left_idx_map[i]) = Bdiag(i); + } + else + { + right_node_degrees(right_idx_map[i]) = node_degrees(i); + right_Bdiag(right_idx_map[i]) = Bdiag(i); + } + } + + + // put the edges from one side of the split into split_edges + std::vector split_edges; + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) > 0 && l(j) > 0) + { + split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_left_clusters; + std::vector left_labels; + num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag, + edge_sum,left_labels,modularity_threshold, + eps, max_iterations); + + // now load the other side into split_edges and cluster it as well + split_edges.clear(); + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) < 0 && l(j) < 0) + { + split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_right_clusters; + std::vector right_labels; + num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag, + edge_sum,right_labels,modularity_threshold, + eps, max_iterations); + + // Now merge the labels from the two splits. + labels.resize(node_degrees.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + // if this node was in the left split + if (l(i) > 0) + { + labels[i] = left_labels[left_idx_map[i]]; + } + else // if this node was in the right split + { + labels[i] = right_labels[right_idx_map[i]] + num_left_clusters; + } + } + + + return num_left_clusters + num_right_clusters; + } + else + { + labels.assign(node_degrees.size(),0); + return 1; + } + + } + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long newman_cluster()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + const unsigned long num_nodes = max_index_plus_one(edges); + + // compute the node_degrees vector, edge_sum value, and diag(B). + matrix node_degrees(num_nodes); + matrix Bdiag(num_nodes); + Bdiag = 0; + double edge_sum = 0; + node_degrees = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + node_degrees(edges[i].index1()) += edges[i].distance(); + edge_sum += edges[i].distance(); + if (edges[i].index1() == edges[i].index2()) + Bdiag(edges[i].index1()) += edges[i].distance(); + } + edge_sum /= 2; + Bdiag -= squared(node_degrees)/(2*edge_sum); + + + dlib::rand rnd; + return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return newman_cluster(oedges, labels, eps, max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline std::vector remap_labels ( + const std::vector& labels, + unsigned long& num_labels + ) + /*! + ensures + - This function takes labels and produces a mapping which maps elements of + labels into the most compact range in [0, max] as possible. In particular, + there won't be any unused integers in the mapped range. + - #num_labels == the number of distinct values in labels. + - returns a vector V such that: + - V.size() == labels.size() + - max(mat(V))+1 == num_labels. + - for all valid i,j: + - if (labels[i] == labels[j]) then + - V[i] == V[j] + - else + - V[i] != V[j] + !*/ + { + std::map temp; + for (unsigned long i = 0; i < labels.size(); ++i) + { + if (temp.count(labels[i]) == 0) + { + const unsigned long next = temp.size(); + temp[labels[i]] = next; + } + } + + num_labels = temp.size(); + + std::vector result(labels.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + result[i] = temp[labels[i]]; + } + return result; + } + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + if (n1 != n2) + k[n2] += edges[i].distance(); + + if (n1 != n2) + m += edges[i].distance(); + else + m += edges[i].distance()/2; + + if (labels_[n1] == labels_[n2]) + { + if (n1 != n2) + Q += 2*edges[i].distance(); + else + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/(2*m); + } + + return 1.0/(2*m)*Q; + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + m += edges[i].distance(); + if (labels_[n1] == labels_[n2]) + { + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/m; + } + + return 1.0/m*Q; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING__H__ + diff --git a/dlib/clustering/modularity_clustering_abstract.h b/dlib/clustering/modularity_clustering_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e7c20c41304230cf992b1326e1f5364a41aacf --- /dev/null +++ b/dlib/clustering/modularity_clustering_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ +#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + +#include +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as an undirected graph. That is, it contains the edges on + the said graph and the sample_pair::distance() values define the edge weights + (larger values indicating a stronger edge connection between the nodes). + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). Note that, generally, modularity is only really defined for + undirected graphs. Therefore, the "directed graph" given to this function + should have symmetric edges between all nodes. The reason this function is + provided at all is because sometimes a vector of ordered_sample_pair objects + is a useful representation of an undirected graph. + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - is_ordered_by_index(edges) == true + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function performs the clustering algorithm described in the paper + Modularity and community structure in networks by M. E. J. Newman. + - This function interprets edges as a graph and attempts to find the labeling + that maximizes modularity(edges, #labels). + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - The main computation of the algorithm is involved in finding an eigenvector + of a certain matrix. To do this, we use the power iteration. In particular, + each time we try to find an eigenvector we will let the power iteration loop + at most max_iterations times or until it reaches an accuracy of eps. + Whichever comes first. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function is identical to the above newman_cluster() routine except that + it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the newman_cluster() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + diff --git a/dlib/clustering/spectral_cluster.h b/dlib/clustering/spectral_cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..2cac9870fc86f0af667a05bca5cadc1bccd51843 --- /dev/null +++ b/dlib/clustering/spectral_cluster.h @@ -0,0 +1,80 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SPECTRAL_CLUSTEr_H_ +#define DLIB_SPECTRAL_CLUSTEr_H_ + +#include "spectral_cluster_abstract.h" +#include +#include "../matrix.h" +#include "../svm/kkmeans.h" + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ) + { + DLIB_CASSERT(num_clusters > 0, + "\t std::vector spectral_cluster(k,samples,num_clusters)" + << "\n\t num_clusters can't be 0." + ); + + if (num_clusters == 1) + { + // nothing to do, just assign everything to the 0 cluster. + return std::vector(samples.size(), 0); + } + + // compute the similarity matrix. + matrix K(samples.size(), samples.size()); + for (long r = 0; r < K.nr(); ++r) + for (long c = r+1; c < K.nc(); ++c) + K(r,c) = K(c,r) = (double)k(samples[r], samples[c]); + for (long r = 0; r < K.nr(); ++r) + K(r,r) = 0; + + matrix D(K.nr()); + for (long r = 0; r < K.nr(); ++r) + D(r) = sum(rowm(K,r)); + D = sqrt(reciprocal(D)); + K = diagm(D)*K*diagm(D); + matrix u,w,v; + // Use the normal SVD routine unless the matrix is really big, then use the fast + // approximate version. + if (K.nr() < 1000) + svd3(K,u,w,v); + else + svd_fast(K,u,w,v, num_clusters+100, 5); + // Pick out the eigenvectors associated with the largest eigenvalues. + rsort_columns(v,w); + v = colm(v, range(0,num_clusters-1)); + // Now build the normalized spectral vectors, one for each input vector. + std::vector > spec_samps, centers; + for (long r = 0; r < v.nr(); ++r) + { + spec_samps.push_back(trans(rowm(v,r))); + const double len = length(spec_samps.back()); + if (len != 0) + spec_samps.back() /= len; + } + // Finally do the K-means clustering + pick_initial_centers(num_clusters, centers, spec_samps); + find_clusters_using_kmeans(spec_samps, centers); + // And then compute the cluster assignments based on the output of K-means. + std::vector assignments; + for (unsigned long i = 0; i < spec_samps.size(); ++i) + assignments.push_back(nearest_center(centers, spec_samps[i])); + + return assignments; + } + +} + +#endif // DLIB_SPECTRAL_CLUSTEr_H_ + diff --git a/dlib/clustering/spectral_cluster_abstract.h b/dlib/clustering/spectral_cluster_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..880ad80afb10b2b0775491119a39c66e128be6d9 --- /dev/null +++ b/dlib/clustering/spectral_cluster_abstract.h @@ -0,0 +1,43 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ +#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + +#include + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ); + /*! + requires + - samples must be something with an interface compatible with std::vector. + - The following expression must evaluate to a double or float: + k(samples[i], samples[j]) + - num_clusters > 0 + ensures + - Performs the spectral clustering algorithm described in the paper: + On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss. + and returns the results. + - This function clusters the input data samples into num_clusters clusters and + returns a vector that indicates which cluster each sample falls into. In + particular, we return an array A such that: + - A.size() == samples.size() + - A[i] == the cluster assignment of samples[i]. + - for all valid i: 0 <= A[i] < num_clusters + - The "similarity" of samples[i] with samples[j] is given by + k(samples[i],samples[j]). This means that k() should output a number >= 0 + and the number should be larger for samples that are more similar. + !*/ +} + +#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + + diff --git a/dlib/cmake b/dlib/cmake new file mode 100644 index 0000000000000000000000000000000000000000..224ba3a49160d962a1e1dfa65f1ed85af1bb2cde --- /dev/null +++ b/dlib/cmake @@ -0,0 +1,5 @@ + +cmake_minimum_required(VERSION 3.8.0) + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR} dlib_build) + diff --git a/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake b/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake new file mode 100644 index 0000000000000000000000000000000000000000..4f2cfef933f379d8b7038f1016ed02b7daadca56 --- /dev/null +++ b/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake @@ -0,0 +1,19 @@ +# This script checks if your compiler and host processor can generate and then run programs with AVX instructions. + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED AVX_IS_AVAILABLE_ON_HOST) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(AVX_IS_AVAILABLE_ON_HOST 0) + +try_compile(test_for_avx_worked ${PROJECT_BINARY_DIR}/avx_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_avx + avx_test) + +if(test_for_avx_worked) + message (STATUS "AVX instructions can be executed by the host processor.") + set(AVX_IS_AVAILABLE_ON_HOST 1) +endif() diff --git a/dlib/cmake_utils/check_if_neon_available.cmake b/dlib/cmake_utils/check_if_neon_available.cmake new file mode 100644 index 0000000000000000000000000000000000000000..895c810b74d71e288078e90a52068342965a951c --- /dev/null +++ b/dlib/cmake_utils/check_if_neon_available.cmake @@ -0,0 +1,20 @@ +# This script checks if __ARM_NEON__ is defined for your compiler + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED ARM_NEON_IS_AVAILABLE) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(ARM_NEON_IS_AVAILABLE 0) + +# test if __ARM_NEON__ is defined +try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon + neon_test) + +if(test_for_neon_worked) + message (STATUS "__ARM_NEON__ defined.") + set(ARM_NEON_IS_AVAILABLE 1) +endif() diff --git a/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake b/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c47560997002166e6f965b7f13c8d2b0b7e9d611 --- /dev/null +++ b/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake @@ -0,0 +1,19 @@ +# This script checks if your compiler and host processor can generate and then run programs with SSE4 instructions. + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED SSE4_IS_AVAILABLE_ON_HOST) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(SSE4_IS_AVAILABLE_ON_HOST 0) + +try_compile(test_for_sse4_worked ${PROJECT_BINARY_DIR}/sse4_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_sse4 + sse4_test) + +if(test_for_sse4_worked) + message (STATUS "SSE4 instructions can be executed by the host processor.") + set(SSE4_IS_AVAILABLE_ON_HOST 1) +endif() diff --git a/dlib/cmake_utils/dlib.pc.in b/dlib/cmake_utils/dlib.pc.in new file mode 100644 index 0000000000000000000000000000000000000000..906011033605aba89cba300c81e75066713c6ec6 --- /dev/null +++ b/dlib/cmake_utils/dlib.pc.in @@ -0,0 +1,8 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: @PROJECT_NAME@ +Description: Numerical and networking C++ library +Version: @VERSION@ +Libs: -L${libdir} -ldlib @pkg_config_dlib_needed_libraries@ +Cflags: -I${includedir} @pkg_config_dlib_needed_includes@ diff --git a/dlib/cmake_utils/dlibConfig.cmake.in b/dlib/cmake_utils/dlibConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2667a2e718bc9bb8dcb824616a49137c05dcf16a --- /dev/null +++ b/dlib/cmake_utils/dlibConfig.cmake.in @@ -0,0 +1,53 @@ +# =================================================================================== +# The dlib CMake configuration file +# +# ** File generated automatically, do not modify ** +# +# Usage from an external project: +# In your CMakeLists.txt, add these lines: +# +# find_package(dlib REQUIRED) +# target_link_libraries(MY_TARGET_NAME dlib::dlib) +# +# =================================================================================== + + + + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TARGET dlib-shared AND NOT dlib_BINARY_DIR) + # Compute paths + get_filename_component(dlib_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + include("${dlib_CMAKE_DIR}/dlib.cmake") + # Check if Threads::Threads target is required and find it if necessary + get_target_property(dlib_deps_threads_check dlib::dlib INTERFACE_LINK_LIBRARIES) + list(FIND dlib_deps_threads_check "Threads::Threads" dlib_deps_threads_idx) + if (${dlib_deps_threads_idx} GREATER -1) + if (NOT TARGET Threads) + find_package(Threads REQUIRED) + endif() + endif() + unset(dlib_deps_threads_idx) + unset(dlib_deps_threads_check) +endif() + +set(dlib_LIBRARIES dlib::dlib) +set(dlib_LIBS dlib::dlib) +set(dlib_INCLUDE_DIRS "@CMAKE_INSTALL_FULL_INCLUDEDIR@" "@dlib_needed_includes@") + +mark_as_advanced(dlib_LIBRARIES) +mark_as_advanced(dlib_LIBS) +mark_as_advanced(dlib_INCLUDE_DIRS) + +# Mark these variables above as deprecated. +function(__deprecated_var var access) + if(access STREQUAL "READ_ACCESS") + message(WARNING "The variable '${var}' is deprecated! Instead, simply use target_link_libraries(your_app dlib::dlib). See http://dlib.net/examples/CMakeLists.txt.html for an example.") + endif() +endfunction() +variable_watch(dlib_LIBRARIES __deprecated_var) +variable_watch(dlib_LIBS __deprecated_var) +variable_watch(dlib_INCLUDE_DIRS __deprecated_var) + + + diff --git a/dlib/cmake_utils/find_blas.cmake b/dlib/cmake_utils/find_blas.cmake new file mode 100644 index 0000000000000000000000000000000000000000..9a56d7c22ec236d955172c352dd45ee8ac1f1e50 --- /dev/null +++ b/dlib/cmake_utils/find_blas.cmake @@ -0,0 +1,480 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# +# +# This cmake file tries to find installed BLAS and LAPACK libraries. +# It looks for an installed copy of the Intel MKL library first and then +# attempts to find some other BLAS and LAPACK libraries if you don't have +# the Intel MKL. +# +# blas_found - True if BLAS is available +# lapack_found - True if LAPACK is available +# found_intel_mkl - True if the Intel MKL library is available +# found_intel_mkl_headers - True if Intel MKL headers are available +# blas_libraries - link against these to use BLAS library +# lapack_libraries - link against these to use LAPACK library +# mkl_libraries - link against these to use the MKL library +# mkl_include_dir - add to the include path to use the MKL library +# openmp_libraries - Set to Intel's OpenMP library if and only if we +# find the MKL. + +# setting this makes CMake allow normal looking if else statements +SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true) + +SET(blas_found 0) +SET(lapack_found 0) +SET(found_intel_mkl 0) +SET(found_intel_mkl_headers 0) +SET(lapack_with_underscore 0) +SET(lapack_without_underscore 0) + +message(STATUS "Searching for BLAS and LAPACK") +INCLUDE(CheckFunctionExists) + +if (UNIX OR MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + if (BUILDING_MATLAB_MEX_FILE) + # # This commented out stuff would link directly to MATLAB's built in + # BLAS and LAPACK. But it's better to not link to anything and do a + #find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} ) + #find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} ) + #if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY) + # add_subdirectory(external/cblas) + # set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas ) + # set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} ) + # set(blas_found 1) + # set(lapack_found 1) + # message(STATUS "Found MATLAB's BLAS and LAPACK libraries") + #endif() + + # We need cblas since MATLAB doesn't provide cblas symbols. + add_subdirectory(external/cblas) + set(blas_libraries cblas ) + set(blas_found 1) + set(lapack_found 1) + message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)") + + + ## Don't try to link to anything other than MATLAB's own internal blas + ## and lapack libraries because doing so generally upsets MATLAB. So + ## we just end here no matter what. + return() + endif() + + + # First, search for libraries via pkg-config, which is the cleanest path + find_package(PkgConfig) + pkg_check_modules(BLAS_REFERENCE cblas) + pkg_check_modules(LAPACK_REFERENCE lapack) + # Make sure the cblas found by pkgconfig actually has cblas symbols. + SET(CMAKE_REQUIRED_LIBRARIES "${BLAS_REFERENCE_LDFLAGS}") + CHECK_FUNCTION_EXISTS(cblas_ddot PKGCFG_HAVE_CBLAS) + if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND AND PKGCFG_HAVE_CBLAS) + set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}") + set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}") + set(blas_found 1) + set(lapack_found 1) + set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack") + message(STATUS "Found BLAS and LAPACK via pkg-config") + return() + endif() + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + /opt/intel/oneapi/mkl/latest/lib/intel64 + /opt/intel/mkl/*/lib/em64t + /opt/intel/mkl/lib/intel64 + /opt/intel/lib/intel64 + /opt/intel/mkl/lib + /opt/intel/tbb/*/lib/em64t/gcc4.7 + /opt/intel/tbb/lib/intel64/gcc4.7 + /opt/intel/tbb/lib/gcc4.7 + ) + + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + mark_as_advanced(mkl_intel) + else() + set( mkl_search_path + /opt/intel/oneapi/mkl/latest/lib/ia32 + /opt/intel/mkl/*/lib/32 + /opt/intel/mkl/lib/ia32 + /opt/intel/lib/ia32 + /opt/intel/tbb/*/lib/32/gcc4.7 + /opt/intel/tbb/lib/ia32/gcc4.7 + ) + + find_library(mkl_intel mkl_intel ${mkl_search_path}) + mark_as_advanced(mkl_intel) + endif() + + include(CheckLibraryExists) + + # Get mkl_include_dir + set(mkl_include_search_path + /opt/intel/oneapi/mkl/latest/include + /opt/intel/mkl/include + /opt/intel/include + ) + find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) + mark_as_advanced(mkl_include_dir) + + if(NOT DLIB_USE_MKL_SEQUENTIAL AND NOT DLIB_USE_MKL_WITH_TBB) + # Search for the needed libraries from the MKL. We will try to link against the mkl_rt + # file first since this way avoids linking bugs in some cases. + find_library(mkl_rt mkl_rt ${mkl_search_path}) + find_library(openmp_libraries iomp5 ${mkl_search_path}) + mark_as_advanced(mkl_rt openmp_libraries) + # if we found the MKL + if (mkl_rt) + set(mkl_libraries ${mkl_rt} ) + set(blas_libraries ${mkl_rt} ) + set(lapack_libraries ${mkl_rt} ) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + endif() + + + if (NOT found_intel_mkl) + # Search for the needed libraries from the MKL. This time try looking for a different + # set of MKL files and try to link against those. + find_library(mkl_core mkl_core ${mkl_search_path}) + set(mkl_libs ${mkl_intel} ${mkl_core}) + mark_as_advanced(mkl_libs mkl_intel mkl_core) + + if (DLIB_USE_MKL_WITH_TBB) + find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) + find_library(mkl_tbb tbb ${mkl_search_path}) + mark_as_advanced(mkl_tbb_thread mkl_tbb) + list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) + elseif (DLIB_USE_MKL_SEQUENTIAL) + find_library(mkl_sequential mkl_sequential ${mkl_search_path}) + mark_as_advanced(mkl_sequential) + list(APPEND mkl_libs ${mkl_sequential}) + else() + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + find_library(mkl_iomp iomp5 ${mkl_search_path}) + find_library(mkl_pthread pthread ${mkl_search_path}) + mark_as_advanced(mkl_thread mkl_iomp mkl_pthread) + list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) + endif() + + # If we found the MKL + if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR (mkl_thread AND mkl_iomp AND mkl_pthread) OR mkl_sequential)) + set(mkl_libraries ${mkl_libs}) + set(blas_libraries ${mkl_libs}) + set(lapack_libraries ${mkl_libs}) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + endif() + + if (found_intel_mkl AND mkl_include_dir) + set(found_intel_mkl_headers 1) + endif() + + # try to find some other LAPACK libraries if we didn't find the MKL + set(extra_paths + /usr/lib64 + /usr/lib64/atlas-sse3 + /usr/lib64/atlas-sse2 + /usr/lib64/atlas + /usr/lib + /usr/lib/atlas-sse3 + /usr/lib/atlas-sse2 + /usr/lib/atlas + /usr/lib/openblas-base + /opt/OpenBLAS/lib + $ENV{OPENBLAS_HOME}/lib + ) + + if (NOT blas_found) + find_library(cblas_lib NAMES openblasp openblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found OpenBLAS library") + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + # If you compiled OpenBLAS with LAPACK in it then it should have the + # sgetrf_single function in it. So if we find that function in + # OpenBLAS then just use OpenBLAS's LAPACK. + CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK) + if (OPENBLAS_HAS_LAPACK) + message(STATUS "Using OpenBLAS's built in LAPACK") + # set(lapack_libraries gfortran) + set(lapack_found 1) + endif() + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT lapack_found) + find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths}) + if (lapack_lib) + set(lapack_libraries ${lapack_lib}) + set(lapack_found 1) + message(STATUS "Found LAPACK library") + endif() + mark_as_advanced( lapack_lib) + endif() + + + # try to find some other BLAS libraries if we didn't find the MKL + + if (NOT blas_found) + find_library(atlas_lib atlas PATHS ${extra_paths}) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (atlas_lib AND cblas_lib) + set(blas_libraries ${atlas_lib} ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( atlas_lib cblas_lib) + endif() + + # CentOS 7 atlas + if (NOT blas_found) + find_library(tatlas_lib tatlas PATHS ${extra_paths}) + find_library(satlas_lib satlas PATHS ${extra_paths}) + if (tatlas_lib AND satlas_lib ) + set(blas_libraries ${tatlas_lib} ${satlas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( tatlas_lib satlas_lib) + endif() + + + if (NOT blas_found) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found CBLAS library") + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT blas_found) + find_library(generic_blas blas PATHS ${extra_paths}) + if (generic_blas) + set(blas_libraries ${generic_blas}) + set(blas_found 1) + message(STATUS "Found BLAS library") + endif() + mark_as_advanced( generic_blas) + endif() + + + + + # Make sure we really found a CBLAS library. That is, it needs to expose + # the proper cblas link symbols. So here we test if one of them is present + # and assume everything is good if it is. Note that we don't do this check if + # we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work + # with it. But it's fine since the MKL should always have cblas. + if (blas_found AND NOT found_intel_mkl) + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot FOUND_BLAS_HAS_CBLAS) + if (NOT FOUND_BLAS_HAS_CBLAS) + message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + endif() + + + +elseif(WIN32 AND NOT MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc_mt" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64" + "C:/Program Files/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files/Intel/Composer XE/tbb/lib/intel64/vc14" + "C:/Program Files/Intel/Composer XE/compiler/lib/intel64" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/intel64" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/intel64_win" + ) + set (mkl_redist_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/intel64/compiler" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/intel64_win/compiler" + ) + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + else() + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc_mt" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32" + "C:/Program Files/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files/Intel/Composer XE/tbb/lib/ia32/vc14" + "C:/Program Files/Intel/Composer XE/compiler/lib/ia32" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/ia32" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/ia32_win" + + ) + set (mkl_redist_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/ia32/compiler" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/ia32_win/compiler" + ) + find_library(mkl_intel mkl_intel_c ${mkl_search_path}) + endif() + + + # Get mkl_include_dir + set(mkl_include_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/include" + "C:/Program Files (x86)/Intel/Composer XE/mkl/include" + "C:/Program Files (x86)/Intel/Composer XE/compiler/include" + "C:/Program Files/Intel/Composer XE/mkl/include" + "C:/Program Files/Intel/Composer XE/compiler/include" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/include" + ) + find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) + mark_as_advanced(mkl_include_dir) + + # Search for the needed libraries from the MKL. + find_library(mkl_core mkl_core ${mkl_search_path}) + set(mkl_libs ${mkl_intel} ${mkl_core}) + mark_as_advanced(mkl_libs mkl_intel mkl_core) + if (DLIB_USE_MKL_WITH_TBB) + find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) + find_library(mkl_tbb tbb ${mkl_search_path}) + mark_as_advanced(mkl_tbb_thread mkl_tbb) + list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) + elseif (DLIB_USE_MKL_SEQUENTIAL) + find_library(mkl_sequential mkl_sequential ${mkl_search_path}) + mark_as_advanced(mkl_sequential) + list(APPEND mkl_libs ${mkl_sequential}) + else() + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + mark_as_advanced(mkl_thread) + if (mkl_thread) + find_library(mkl_iomp libiomp5md ${mkl_search_path}) + mark_as_advanced(mkl_iomp) + list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp}) + + # See if we can find the dll that goes with this, so we can copy it to + # the output folder, since a very large number of windows users don't + # understand that they need to add the Intel MKL's folders to their + # PATH to use the Intel MKL. They then complain on the dlib forums. + # Copying the Intel MKL dlls to the output directory removes the need + # to add the Intel MKL to the PATH. + find_file(mkl_iomp_dll "libiomp5md.dll" ${mkl_redist_path}) + if (mkl_iomp_dll) + message(STATUS "FOUND libiomp5md.dll: ${mkl_iomp_dll}") + endif() + endif() + endif() + + # If we found the MKL + if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR mkl_sequential OR (mkl_thread AND mkl_iomp))) + set(blas_libraries ${mkl_libs}) + set(lapack_libraries ${mkl_libs}) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + + # Make sure the version of the Intel MKL we found is compatible with + # the compiler we are using. One way to do this check is to see if we can + # link to it right now. + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot MKL_HAS_CBLAS) + if (NOT MKL_HAS_CBLAS) + message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + endif() + + if (found_intel_mkl AND mkl_include_dir) + set(found_intel_mkl_headers 1) + endif() + +endif() + + +# When all else fails use CMake's built in functions to find BLAS and LAPACK +if (NOT blas_found) + find_package(BLAS QUIET) + if (${BLAS_FOUND}) + set(blas_libraries ${BLAS_LIBRARIES}) + set(blas_found 1) + if (NOT lapack_found) + find_package(LAPACK QUIET) + if (${LAPACK_FOUND}) + set(lapack_libraries ${LAPACK_LIBRARIES}) + set(lapack_found 1) + endif() + endif() + endif() +endif() + + +# If using lapack, determine whether to mangle functions +if (lapack_found) + include(CheckFortranFunctionExists) + set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries}) + + check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED) + check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED) + if (CMAKE_Fortran_COMPILER_LOADED) + check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED) + check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED) + endif () + if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED) + set(lapack_with_underscore 1) + elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED) + set(lapack_without_underscore 1) + endif () +endif() + + +if (UNIX OR MINGW) + if (NOT blas_found) + message(" *****************************************************************************") + message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***") + message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***") + message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***") + message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***") + message(" *** Or you can easily install OpenBLAS from source by downloading the ***") + message(" *** source tar file from http://www.openblas.net, extracting it, and ***") + message(" *** running: ***") + message(" *** make; sudo make install ***") + message(" *****************************************************************************") + endif() +endif() diff --git a/dlib/cmake_utils/find_ffmpeg.cmake b/dlib/cmake_utils/find_ffmpeg.cmake new file mode 100644 index 0000000000000000000000000000000000000000..b05955387b3292d5dbb42f1eba81ad3f3c251f60 --- /dev/null +++ b/dlib/cmake_utils/find_ffmpeg.cmake @@ -0,0 +1,29 @@ +cmake_minimum_required(VERSION 3.8.0) + +message(STATUS "Searching for FFMPEG/LIBAV") +find_package(PkgConfig) +if (PkgConfig_FOUND) + pkg_check_modules(FFMPEG IMPORTED_TARGET + libavdevice + libavfilter + libavformat + libavcodec + libswresample + libswscale + libavutil + ) + if (FFMPEG_FOUND) + message(STATUS "Found FFMPEG/LIBAV via pkg-config in `${FFMPEG_LIBRARY_DIRS}`") + else() + message(" *****************************************************************************") + message(" *** No FFMPEG/LIBAV libraries found. ***") + message(" *** On Ubuntu you can install them by executing ***") + message(" *** sudo apt install libavdevice-dev libavfilter-dev libavformat-dev ***") + message(" *** sudo apt install libavcodec-dev libswresample-dev libswscale-dev ***") + message(" *** sudo apt install libavutil-dev ***") + message(" *****************************************************************************") + endif() +else() + message(STATUS "PkgConfig could not be found, FFMPEG won't be available") + set(FFMPEG_FOUND 0) +endif() diff --git a/dlib/cmake_utils/find_libjpeg.cmake b/dlib/cmake_utils/find_libjpeg.cmake new file mode 100644 index 0000000000000000000000000000000000000000..028217b0752ce52e6f75d262b384efa689432f41 --- /dev/null +++ b/dlib/cmake_utils/find_libjpeg.cmake @@ -0,0 +1,38 @@ +#This script just runs CMake's built in JPEG finding tool. But it also checks that the +#copy of libjpeg that cmake finds actually builds and links. + +cmake_minimum_required(VERSION 3.8.0) + +if (BUILDING_PYTHON_IN_MSVC) + # Never use any system copy of libjpeg when building python in visual studio + set(JPEG_FOUND 0) + return() +endif() + +# Don't rerun this script if its already been executed. +if (DEFINED JPEG_FOUND) + return() +endif() + +find_package(JPEG QUIET) + +if(JPEG_FOUND) + set(JPEG_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libjpeg_worked + ${PROJECT_BINARY_DIR}/test_for_libjpeg_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libjpeg + test_if_libjpeg_is_broken + CMAKE_FLAGS "${JPEG_TEST_CMAKE_FLAGS}") + + message (STATUS "Found system copy of libjpeg: ${JPEG_LIBRARY}") + if(NOT test_for_libjpeg_worked) + set(JPEG_FOUND 0) + message (STATUS "System copy of libjpeg is broken or too old. Will build our own libjpeg and use that instead.") + endif() +endif() + + diff --git a/dlib/cmake_utils/find_libpng.cmake b/dlib/cmake_utils/find_libpng.cmake new file mode 100644 index 0000000000000000000000000000000000000000..676073922675681a606245d7326b4c4660ffba8d --- /dev/null +++ b/dlib/cmake_utils/find_libpng.cmake @@ -0,0 +1,37 @@ +#This script just runs CMake's built in PNG finding tool. But it also checks that the +#copy of libpng that cmake finds actually builds and links. + +cmake_minimum_required(VERSION 3.8.0) + +if (BUILDING_PYTHON_IN_MSVC) + # Never use any system copy of libpng when building python in visual studio + set(PNG_FOUND 0) + return() +endif() + +# Don't rerun this script if its already been executed. +if (DEFINED PNG_FOUND) + return() +endif() + +find_package(PNG QUIET) + +if(PNG_FOUND) + set(PNG_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libpng_worked + ${PROJECT_BINARY_DIR}/test_for_libpng_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libpng + test_if_libpng_is_broken + CMAKE_FLAGS "${PNG_TEST_CMAKE_FLAGS}") + + message (STATUS "Found system copy of libpng: ${PNG_LIBRARIES}") + if(NOT test_for_libpng_worked) + set(PNG_FOUND 0) + message (STATUS "System copy of libpng is broken. Will build our own libpng and use that instead.") + endif() +endif() + diff --git a/dlib/cmake_utils/find_libwebp.cmake b/dlib/cmake_utils/find_libwebp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..022528592634856e2780694fe13a9c3a802798ec --- /dev/null +++ b/dlib/cmake_utils/find_libwebp.cmake @@ -0,0 +1,52 @@ +#============================================================================= +# Find WebP library +# From OpenCV +#============================================================================= +# Find the native WebP headers and libraries. +# +# WEBP_INCLUDE_DIRS - where to find webp/decode.h, etc. +# WEBP_LIBRARIES - List of libraries when using webp. +# WEBP_FOUND - True if webp is found. +#============================================================================= + +# Look for the header file. + +unset(WEBP_FOUND) + +find_path(WEBP_INCLUDE_DIR NAMES webp/decode.h) + +if(NOT WEBP_INCLUDE_DIR) + unset(WEBP_FOUND) +else() + mark_as_advanced(WEBP_INCLUDE_DIR) + + # Look for the library. + find_library(WEBP_LIBRARY NAMES webp) + mark_as_advanced(WEBP_LIBRARY) + + # handle the QUIETLY and REQUIRED arguments and set WEBP_FOUND to TRUE if + # all listed variables are TRUE + include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + find_package_handle_standard_args(WebP DEFAULT_MSG WEBP_LIBRARY WEBP_INCLUDE_DIR) + + set(WEBP_LIBRARIES ${WEBP_LIBRARY}) + set(WEBP_INCLUDE_DIRS ${WEBP_INCLUDE_DIR}) +endif() + +if(WEBP_FOUND) + set(WEBP_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libwebp_worked + ${PROJECT_BINARY_DIR}/test_for_libwebp_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libwebp + test_if_libwebp_is_broken + CMAKE_FLAGS "${WEBP_TEST_CMAKE_FLAGS}") + + if(NOT test_for_libwebp_worked) + set(WEBP_FOUND 0) + message (STATUS "System copy of libwebp is either too old or broken. Will disable WebP support.") + endif() +endif() diff --git a/dlib/cmake_utils/release_build_by_default b/dlib/cmake_utils/release_build_by_default new file mode 100644 index 0000000000000000000000000000000000000000..1b0e958317c7c07af958f6c9e54ae84f682eb47f --- /dev/null +++ b/dlib/cmake_utils/release_build_by_default @@ -0,0 +1,9 @@ + +#set default build type to Release +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING + "Choose the type of build, options are: Debug Release + RelWithDebInfo MinSizeRel." FORCE) +endif() + + diff --git a/dlib/cmake_utils/set_compiler_specific_options.cmake b/dlib/cmake_utils/set_compiler_specific_options.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8093ca6d30330444c698b092b08e723c7cbb2ae4 --- /dev/null +++ b/dlib/cmake_utils/set_compiler_specific_options.cmake @@ -0,0 +1,120 @@ + +cmake_minimum_required(VERSION 3.8.0) + + +# Check if we are being built as part of a pybind11 module. +if (COMMAND pybind11_add_module) + # For python users, enable SSE4 and AVX if they have these instructions. + include(${CMAKE_CURRENT_LIST_DIR}/check_if_sse4_instructions_executable_on_host.cmake) + if (SSE4_IS_AVAILABLE_ON_HOST) + set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Compile your program with SSE4 instructions") + endif() + include(${CMAKE_CURRENT_LIST_DIR}/check_if_avx_instructions_executable_on_host.cmake) + if (AVX_IS_AVAILABLE_ON_HOST) + set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Compile your program with AVX instructions") + endif() + include(${CMAKE_CURRENT_LIST_DIR}/check_if_neon_available.cmake) + if (ARM_NEON_IS_AVAILABLE) + set(USE_NEON_INSTRUCTIONS ON CACHE BOOL "Compile your program with ARM-NEON instructions") + endif() +endif() + + + + +set(gcc_like_compilers GNU Clang Intel) +set(intel_archs x86_64 i386 i686 AMD64 amd64 x86) + + +# Setup some options to allow a user to enable SSE and AVX instruction use. +if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + (";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR) + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts -mavx) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + list(APPEND active_compile_opts -msse4) + message(STATUS "Enabling SSE4 instructions") + elseif(USE_SSE2_INSTRUCTIONS) + list(APPEND active_compile_opts -msse2) + message(STATUS "Enabling SSE2 instructions") + endif() +elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio + # Use SSE2 by default when using Visual Studio. + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts /arch:AVX) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE4 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41") + elseif(USE_SSE2_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE2 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + endif() + +elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm")) + option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF) + if(USE_NEON_INSTRUCTIONS) + list(APPEND active_compile_opts -mfpu=neon) + message(STATUS "Enabling ARM-NEON instructions") + endif() +endif() + + + + +if (CMAKE_COMPILER_IS_GNUCXX) + # By default, g++ won't warn or error if you forget to return a value in a + # function which requires you to do so. This option makes it give a warning + # for doing this. + list(APPEND active_compile_opts "-Wreturn-type") +endif() + +if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID} AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0) + # Clang 6 had a default template recursion depth of 256. This was changed to 1024 in Clang 7. + # It must be increased on Clang 6 and below to ensure that the dnn examples don't error out. + list(APPEND active_compile_opts "-ftemplate-depth=500") +endif() + +if (MSVC) + # By default Visual Studio does not support .obj files with more than 65k sections. + # However, code generated by file_to_code_ex and code using DNN module can have + # them. So this flag enables > 65k sections, but produces .obj files + # that will not be readable by VS 2005. + list(APPEND active_compile_opts "/bigobj") + + # Build dlib with all cores. Don't propagate the setting to client programs + # though since they might compile large translation units that use too much + # RAM. + list(APPEND active_compile_opts_private "/MP") + + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3) + # Clang can compile all Dlib's code at Windows platform. Tested with Clang 5 + list(APPEND active_compile_opts -Xclang) + list(APPEND active_compile_opts -fcxx-exceptions) + endif() +endif() + + diff --git a/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake b/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake new file mode 100644 index 0000000000000000000000000000000000000000..80122d8018fdd9f7d2de2b5656522e176df4e1d1 --- /dev/null +++ b/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake @@ -0,0 +1,19 @@ + +# Including this cmake script into your cmake project will cause visual studio +# to build your project against the static C runtime. + +cmake_minimum_required(VERSION 3.8.0) + +if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + option (DLIB_FORCE_MSVC_STATIC_RUNTIME "use static runtime" ON) + if (DLIB_FORCE_MSVC_STATIC_RUNTIME) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif() + endforeach(flag_var) + endif () +endif() + diff --git a/dlib/cmake_utils/test_for_avx/CMakeLists.txt b/dlib/cmake_utils/test_for_avx/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..10890153cf680fa580a92aef1b14ec8a9d9f295f --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/CMakeLists.txt @@ -0,0 +1,23 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(avx_test) + +set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Use AVX instructions") + +# Pull this in since it sets the AVX compile options by putting that kind of stuff into the active_compile_opts list. +include(../set_compiler_specific_options.cmake) + + +try_run(run_result compile_result ${PROJECT_BINARY_DIR}/avx_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/avx_test.cpp + COMPILE_DEFINITIONS ${active_compile_opts}) + +message(STATUS "run_result = ${run_result}") +message(STATUS "compile_result = ${compile_result}") + +if ("${run_result}" EQUAL 0 AND compile_result) + message(STATUS "Ran AVX test program successfully, you have AVX available.") +else() + message(STATUS "Unable to run AVX test program, you don't seem to have AVX instructions available.") + # make this build fail so that calling try_compile statements will error in this case. + add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp) +endif() diff --git a/dlib/cmake_utils/test_for_avx/avx_test.cpp b/dlib/cmake_utils/test_for_avx/avx_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98e9b0b99000a012a88cbd8f3672d6025a1ba5b9 --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/avx_test.cpp @@ -0,0 +1,13 @@ + +#include + +int main() +{ + __m256 x; + x = _mm256_set1_ps(1.23); + x = _mm256_add_ps(x,x); + return 0; +} + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp b/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83f89c0c753db34dbcf82babecfa61a5c88f0cd3 --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp @@ -0,0 +1,3 @@ + +#error "This file doesn't compile!" + diff --git a/dlib/cmake_utils/test_for_cuda/CMakeLists.txt b/dlib/cmake_utils/test_for_cuda/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b6cd7e7b02e2f2d98123c24287d2e192e023a82 --- /dev/null +++ b/dlib/cmake_utils/test_for_cuda/CMakeLists.txt @@ -0,0 +1,14 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(cuda_test) + +include_directories(../../cuda) +add_definitions(-DDLIB_USE_CUDA) + +# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain: +option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF) +find_package(CUDA 7.5 REQUIRED) +set(CUDA_HOST_COMPILATION_CPP ON) +list(APPEND CUDA_NVCC_FLAGS "-arch=sm_50;-std=c++14;-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES") + +cuda_add_library(cuda_test STATIC cuda_test.cu ) diff --git a/dlib/cmake_utils/test_for_cuda/cuda_test.cu b/dlib/cmake_utils/test_for_cuda/cuda_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..fb1ffe0dadcd2da7eb22b91148ed7b9b8839fa92 --- /dev/null +++ b/dlib/cmake_utils/test_for_cuda/cuda_test.cu @@ -0,0 +1,21 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "cuda_utils.h" +#include "cuda_dlib.h" + + +// ------------------------------------------------------------------------------------ + +__global__ void cuda_add_arrays(const float* a, const float* b, float* out, size_t n) +{ + out[0] += a[0]+b[0]; +} + +void add_arrays() +{ + cuda_add_arrays<<<512,512>>>(0,0,0,0); +} + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt b/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..57ac21f1eaff61c1db41aee8b2f951498db4a91d --- /dev/null +++ b/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt @@ -0,0 +1,18 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(cudnn_test) + +# Override the FindCUDA.cmake setting to avoid duplication of host flags if using a toolchain: +option(CUDA_PROPAGATE_HOST_FLAGS "Propage C/CXX_FLAGS and friends to the host compiler via -Xcompile" OFF) +find_package(CUDA 7.5 REQUIRED) +set(CUDA_HOST_COMPILATION_CPP ON) +list(APPEND CUDA_NVCC_FLAGS "-arch=sm_50;-std=c++14;-D__STRICT_ANSI__") +add_definitions(-DDLIB_USE_CUDA) + +include(find_cudnn.txt) + +if (cudnn_include AND cudnn) + include_directories(${cudnn_include}) + cuda_add_library(cudnn_test STATIC ../../cuda/cudnn_dlibapi.cpp ${cudnn} ) + target_compile_features(cudnn_test PUBLIC cxx_std_14) +endif() diff --git a/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt b/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt new file mode 100644 index 0000000000000000000000000000000000000000..dd5f14e3ff1e7f760485c2893d296174a1888ee2 --- /dev/null +++ b/dlib/cmake_utils/test_for_cudnn/find_cudnn.txt @@ -0,0 +1,24 @@ + +message(STATUS "Looking for cuDNN install...") +# Look for cudnn, we will look in the same place as other CUDA +# libraries and also a few other places as well. +find_path(cudnn_include cudnn.h + HINTS ${CUDA_INCLUDE_DIRS} ENV CUDNN_INCLUDE_DIR ENV CUDNN_HOME + PATHS /usr/local ENV CPATH + PATH_SUFFIXES include + ) +get_filename_component(cudnn_hint_path "${CUDA_CUBLAS_LIBRARIES}" PATH) +find_library(cudnn cudnn + HINTS ${cudnn_hint_path} ENV CUDNN_LIBRARY_DIR ENV CUDNN_HOME + PATHS /usr/local /usr/local/cuda ENV LD_LIBRARY_PATH + PATH_SUFFIXES lib64 lib x64 + ) +mark_as_advanced(cudnn cudnn_include) + +if (cudnn AND cudnn_include) + message(STATUS "Found cuDNN: " ${cudnn}) +else() + message(STATUS "*** cuDNN V5.0 OR GREATER NOT FOUND. ***") + message(STATUS "*** Dlib requires cuDNN V5.0 OR GREATER. Since cuDNN is not found DLIB WILL NOT USE CUDA. ***") + message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder. ***") +endif() diff --git a/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt b/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a3c6910dbef12c8ceb947f5c5310ed0522f6ced7 --- /dev/null +++ b/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt @@ -0,0 +1,11 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libjpeg_is_broken) + +find_package(JPEG) + +include_directories(${JPEG_INCLUDE_DIR}) +add_executable(libjpeg_test libjpeg_test.cpp) +target_link_libraries(libjpeg_test ${JPEG_LIBRARY}) + + diff --git a/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp b/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3811d0949b5907034b19bfe4692e88aaa484b60a --- /dev/null +++ b/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp @@ -0,0 +1,68 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include + +struct jpeg_loader_error_mgr +{ + jpeg_error_mgr pub; + jmp_buf setjmp_buffer; +}; + +void jpeg_loader_error_exit (j_common_ptr cinfo) +{ + jpeg_loader_error_mgr* myerr = (jpeg_loader_error_mgr*) cinfo->err; + + longjmp(myerr->setjmp_buffer, 1); +} + +// This code doesn't really make a lot of sense. It's just calling all the libjpeg functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + abort(); + + FILE *fp = fopen("whatever.jpg", "rb" ); + + jpeg_decompress_struct cinfo; + jpeg_loader_error_mgr jerr; + + cinfo.err = jpeg_std_error(&jerr.pub); + + jerr.pub.error_exit = jpeg_loader_error_exit; + + setjmp(jerr.setjmp_buffer); + + jpeg_create_decompress(&cinfo); + + jpeg_stdio_src(&cinfo, fp); + if (false) { + unsigned char imgbuffer[1234]; + jpeg_mem_src(&cinfo, imgbuffer, sizeof(imgbuffer)); + } + + jpeg_read_header(&cinfo, TRUE); + + jpeg_start_decompress(&cinfo); + + unsigned long height_ = cinfo.output_height; + unsigned long width_ = cinfo.output_width; + unsigned long output_components_ = cinfo.output_components; + + unsigned char* rows[123]; + + while (cinfo.output_scanline < cinfo.output_height) + { + jpeg_read_scanlines(&cinfo, &rows[cinfo.output_scanline], 100); + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + + fclose( fp ); +} diff --git a/dlib/cmake_utils/test_for_libpng/CMakeLists.txt b/dlib/cmake_utils/test_for_libpng/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0be92061323718241a6727f8a993f357b28e5335 --- /dev/null +++ b/dlib/cmake_utils/test_for_libpng/CMakeLists.txt @@ -0,0 +1,11 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libpng_is_broken) + +find_package(PNG) + +include_directories(${PNG_INCLUDE_DIR}) +add_executable(libpng_test libpng_test.cpp) +target_link_libraries(libpng_test ${PNG_LIBRARIES}) + + diff --git a/dlib/cmake_utils/test_for_libpng/libpng_test.cpp b/dlib/cmake_utils/test_for_libpng/libpng_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05f439ee94d0abb6e09e71ed3a49bcef1252bd9f --- /dev/null +++ b/dlib/cmake_utils/test_for_libpng/libpng_test.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include + +void png_loader_user_error_fn_silent(png_structp png_struct, png_const_charp ) +{ + longjmp(png_jmpbuf(png_struct),1); +} +void png_loader_user_warning_fn_silent(png_structp , png_const_charp ) +{ +} + +// This code doesn't really make a lot of sense. It's just calling all the libpng functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + abort(); + + png_bytep* row_pointers_; + png_structp png_ptr_; + png_infop info_ptr_; + png_infop end_info_; + + FILE *fp = fopen( "whatever.png", "rb" ); + png_byte sig[8]; + fread( sig, 1, 8, fp ); + png_sig_cmp( sig, 0, 8 ); + png_ptr_ = png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, &png_loader_user_error_fn_silent, &png_loader_user_warning_fn_silent ); + + png_get_header_ver(NULL); + info_ptr_ = png_create_info_struct( png_ptr_ ); + end_info_ = png_create_info_struct( png_ptr_ ); + setjmp(png_jmpbuf(png_ptr_)); + png_set_palette_to_rgb(png_ptr_); + png_init_io( png_ptr_, fp ); + png_set_sig_bytes( png_ptr_, 8 ); + // flags force one byte per channel output + int png_transforms = PNG_TRANSFORM_PACKING; + png_read_png( png_ptr_, info_ptr_, png_transforms, NULL ); + png_get_image_height( png_ptr_, info_ptr_ ); + png_get_image_width( png_ptr_, info_ptr_ ); + png_get_bit_depth( png_ptr_, info_ptr_ ); + png_get_color_type( png_ptr_, info_ptr_ ); + + png_get_rows( png_ptr_, info_ptr_ ); + + fclose(fp); + png_destroy_read_struct(&png_ptr_, &info_ptr_, &end_info_); +} diff --git a/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt b/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ea3ed10eac178f9a4b3c046f3a6c95ffde344e9 --- /dev/null +++ b/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt @@ -0,0 +1,7 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libwebp_is_broken) + +include_directories(${WEBP_INCLUDE_DIR}) +add_executable(libwebp_test libwebp_test.cpp) +target_link_libraries(libwebp_test ${WEBP_LIBRARY}) diff --git a/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp b/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e561b4d4fb46035b6c1ae79343f90e21b7c6536e --- /dev/null +++ b/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include + +// This code doesn't really make a lot of sense. It's just calling all the libjpeg functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + std::abort(); + + uint8_t* data; + size_t output_size = 0; + int width, height, stride; + float quality; + output_size = WebPEncodeRGB(data, width, height, stride, quality, &data); + WebPDecodeRGBInto(data, output_size, data, output_size, stride); + WebPFree(data); +} diff --git a/dlib/cmake_utils/test_for_neon/CMakeLists.txt b/dlib/cmake_utils/test_for_neon/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b9c6e51e8ae5dcecb6f7b21028aa1e0b6872c42d --- /dev/null +++ b/dlib/cmake_utils/test_for_neon/CMakeLists.txt @@ -0,0 +1,6 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(neon_test) + +add_library(neon_test STATIC neon_test.cpp ) + diff --git a/dlib/cmake_utils/test_for_neon/neon_test.cpp b/dlib/cmake_utils/test_for_neon/neon_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4abdbadeee68bef850ec338b34b9a9c00289ed3 --- /dev/null +++ b/dlib/cmake_utils/test_for_neon/neon_test.cpp @@ -0,0 +1,9 @@ +#ifdef __ARM_NEON__ +#else +#error "No NEON" +#endif +int main(){} + + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_sse4/CMakeLists.txt b/dlib/cmake_utils/test_for_sse4/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4cd4e95cef131eff56e6ea2f072888a34e76cb9e --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/CMakeLists.txt @@ -0,0 +1,23 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(sse4_test) + +set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Use SSE4 instructions") + +# Pull this in since it sets the SSE4 compile options by putting that kind of stuff into the active_compile_opts list. +include(../set_compiler_specific_options.cmake) + + +try_run(run_result compile_result ${PROJECT_BINARY_DIR}/sse4_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/sse4_test.cpp + COMPILE_DEFINITIONS ${active_compile_opts}) + +message(STATUS "run_result = ${run_result}") +message(STATUS "compile_result = ${compile_result}") + +if ("${run_result}" EQUAL 0 AND compile_result) + message(STATUS "Ran SSE4 test program successfully, you have SSE4 available.") +else() + message(STATUS "Unable to run SSE4 test program, you don't seem to have SSE4 instructions available.") + # make this build fail so that calling try_compile statements will error in this case. + add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp) +endif() diff --git a/dlib/cmake_utils/test_for_sse4/sse4_test.cpp b/dlib/cmake_utils/test_for_sse4/sse4_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8065dab952e36ad4674b3d21af66642882ecbc70 --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/sse4_test.cpp @@ -0,0 +1,18 @@ + +#include +#include +#include +#include // SSE3 +#include +#include // SSE4 + +int main() +{ + __m128 x; + x = _mm_set1_ps(1.23); + x = _mm_ceil_ps(x); + return 0; +} + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp b/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83f89c0c753db34dbcf82babecfa61a5c88f0cd3 --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp @@ -0,0 +1,3 @@ + +#error "This file doesn't compile!" + diff --git a/dlib/cmd_line_parser.h b/dlib/cmd_line_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..fd1148038dd4276579ad50036acaea27fc73bde1 --- /dev/null +++ b/dlib/cmd_line_parser.h @@ -0,0 +1,84 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSEr_ +#define DLIB_CMD_LINE_PARSEr_ + +#include "cmd_line_parser/cmd_line_parser_kernel_1.h" +#include "cmd_line_parser/cmd_line_parser_kernel_c.h" +#include "cmd_line_parser/cmd_line_parser_print_1.h" +#include "cmd_line_parser/cmd_line_parser_check_1.h" +#include "cmd_line_parser/cmd_line_parser_check_c.h" +#include +#include "cmd_line_parser/get_option.h" + +#include "map.h" +#include "sequence.h" + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class impl_cmd_line_parser + { + /*! + This class is basically just a big templated typedef for building + a complete command line parser type out of all the parts it needs. + !*/ + + impl_cmd_line_parser() {} + + typedef typename sequence >::kernel_2a sequence_2a; + typedef typename sequence*>::kernel_2a psequence_2a; + typedef typename map,void*>::kernel_1a map_1a_string; + + public: + + typedef cmd_line_parser_kernel_1 kernel_1a; + typedef cmd_line_parser_kernel_c kernel_1a_c; + typedef cmd_line_parser_print_1 print_1a_c; + typedef cmd_line_parser_check_c > check_1a_c; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class cmd_line_parser : public impl_cmd_line_parser::check_1a_c + { + public: + + // These typedefs are here for backwards compatibility with previous versions of dlib. + typedef cmd_line_parser kernel_1a; + typedef cmd_line_parser kernel_1a_c; + typedef cmd_line_parser print_1a; + typedef cmd_line_parser print_1a_c; + typedef cmd_line_parser check_1a; + typedef cmd_line_parser check_1a_c; + }; + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSEr_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_check_1.h b/dlib/cmd_line_parser/cmd_line_parser_check_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1736b4b56a03e0638428dbd26b5c24192dac4a15 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_check_1.h @@ -0,0 +1,580 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_1_ +#define DLIB_CMD_LINE_PARSER_CHECk_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include +#include +#include "../string.h" +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_check_1 : public clp_base + { + + /*! + This extension doesn't add any state. + + !*/ + + + public: + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + + // ------------------------------------------------------------------------------------ + + class cmd_line_check_error : public dlib::error + { + friend class cmd_line_parser_check_1; + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& arg_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(arg_), + required_opts() + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& opt2_, + int // this is just to make this constructor different from the one above + ) : + dlib::error(t), + opt(opt_), + opt2(opt2_), + arg(), + required_opts() + { set_info_string(); } + + cmd_line_check_error ( + error_type t, + const string_type& opt_, + const std::vector& vect + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts(vect) + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts() + { set_info_string(); } + + ~cmd_line_check_error() throw() {} + + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION_ARG: + sout << "Command line error: '" << narrow(arg) << "' is not a valid argument to " + << "the '" << narrow(opt) << "' option."; + break; + case EMISSING_REQUIRED_OPTION: + if (required_opts.size() == 1) + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "the '" << required_opts[0] << "' option."; + } + else + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "one of the following options: "; + for (unsigned long i = 0; i < required_opts.size(); ++i) + { + if (i == required_opts.size()-2) + sout << "'" << required_opts[i] << "' or "; + else if (i == required_opts.size()-1) + sout << "'" << required_opts[i] << "'."; + else + sout << "'" << required_opts[i] << "', "; + } + } + break; + case EINCOMPATIBLE_OPTIONS: + sout << "Command line error: The '" << narrow(opt) << "' and '" << narrow(opt2) + << "' options cannot be given together on the command line."; + break; + case EMULTIPLE_OCCURANCES: + sout << "Command line error: The '" << narrow(opt) << "' option can only " + << "be given on the command line once."; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_check_1& a, + cmd_line_parser_check_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_type ( + const string_type& option_name + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + string_cast(opt.argument(i,j)); + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + if (temp < first || last < temp) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == temp) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == opt.argument(i,j)) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + for (size_t i = 0; i < length; ++i) + { + for (size_t j = i+1; j < length; ++j) + { + if (this->option(option_set[i]).count() > 0 && + this->option(option_set[j]).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_set[i], + option_set[j], + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + if (this->option(option_name1).count() > 0 && + this->option(option_name2).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_name1, + option_name2, + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + if (this->option(parent_option).count() == 0) + { + if (this->option(sub_option).count() != 0) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + if (this->option(parent_option).count() == 0) + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != length) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // first check if the sub_option is present + if (this->option(sub_option).count() > 0) + { + // now check if any of the parents are present + bool parents_present = false; + for (size_t i = 0; i < length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + std::vector vect(parent_option_set, parent_option_set+length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // first check if any of the parent options are present + bool parents_present = false; + for (size_t i = 0; i < parent_length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + // none of these sub options should be present + size_t i = 0; + for (; i < sub_length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != sub_length) + { + std::vector vect(parent_option_set, parent_option_set+parent_length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(option_set[i]).count() > 1) + break; + } + if (i != length) + { + throw cmd_line_check_error( + EMULTIPLE_OCCURANCES, + option_set[i] + ); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_1_ + + diff --git a/dlib/cmd_line_parser/cmd_line_parser_check_c.h b/dlib/cmd_line_parser/cmd_line_parser_check_c.h new file mode 100644 index 0000000000000000000000000000000000000000..7ff858e8985efc86526484df824b13f4dcfa2b94 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_check_c.h @@ -0,0 +1,453 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_C_ +#define DLIB_CMD_LINE_PARSER_CHECk_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_check + > + class cmd_line_parser_check_c : public clp_check + { + public: + + typedef typename clp_check::char_type char_type; + typedef typename clp_check::string_type string_type; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + + template < + typename clp_check + > + inline void swap ( + cmd_line_parser_check_c& a, + cmd_line_parser_check_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_type ( + const string_type& option_name + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_type()" + << "\n\tYou must have already parsed the command line and option_name must be valid." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::template check_option_arg_type(option_name); + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name) && + first <= last, + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + << "\n\tfirst: " << first + << "\n\tlast: " << last + ); + + clp_check::check_option_arg_range(option_name,first,last); + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_incompatible_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name1) && + this->option_is_defined(option_name2), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name1): " << ((this->option_is_defined(option_name1))?"true":"false") + << "\n\toption_is_defined(option_name2): " << ((this->option_is_defined(option_name2))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name1: " << option_name1 + << "\n\toption_name2: " << option_name2 + ); + + clp_check::check_incompatible_options(option_name1,option_name2); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option) && + this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_option()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\toption_is_defined(parent_option): " << this->option_is_defined(parent_option) + << "\n\toption_is_defined(sub_option): " << this->option_is_defined(sub_option) + << "\n\tparent_option: " << parent_option + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_option(parent_option,sub_option); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option): " << ((this->option_is_defined(parent_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tparent_option: " << parent_option + ); + clp_check::check_sub_options(parent_option,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option): " << ((this->option_is_defined(sub_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_options(parent_option_set,sub_option); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < sub_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + for (size_t i = 0; i < parent_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_parent_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + + + DLIB_CASSERT( this->parsed_line() == true , + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tYou must have parsed the command line before you call this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + ); + + clp_check::check_sub_options(parent_option_set,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_one_time_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_one_time_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_C_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..68ea5a135078d8063b372100467d7e96691f617c --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h @@ -0,0 +1,799 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_1_ +#define DLIB_CMD_LINE_PARSER_KERNEl_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include "../assert.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + class cmd_line_parser_kernel_1 : public enumerable > + { + /*! + REQUIREMENTS ON map + is an implementation of map/map_kernel_abstract.h + is instantiated to map items of type std::basic_string to void* + + REQUIREMENTS ON sequence + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string + + REQUIREMENTS ON sequence2 + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string* + + INITIAL VALUE + options.size() == 0 + argv.size() == 0 + have_parsed_line == false + + CONVENTION + have_parsed_line == parsed_line() + argv[index] == operator[](index) + argv.size() == number_of_arguments() + *((option_t*)options[name]) == option(name) + options.is_in_domain(name) == option_is_defined(name) + !*/ + + + + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION: + sout << "Command line error: '" << narrow(item) << "' is not a valid option."; + break; + case ETOO_FEW_ARGS: + if (num > 1) + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " arguments."; + } + else + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " argument."; + } + break; + case ETOO_MANY_ARGS: + sout << "Command line error: The '" << narrow(item) << "' option does not take any arguments.\n"; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + cmd_line_parse_error( + error_type t, + const std::basic_string& _item + ) : + dlib::error(t), + item(_item), + num(0) + { set_info_string();} + + cmd_line_parse_error( + error_type t, + const std::basic_string& _item, + unsigned long _num + ) : + dlib::error(t), + item(_item), + num(_num) + { set_info_string();} + + cmd_line_parse_error( + ) : + dlib::error(), + item(), + num(0) + { set_info_string();} + + ~cmd_line_parse_error() throw() {} + + const std::basic_string item; + const unsigned long num; + }; + + + private: + + class option_t : public cmd_line_parser_option + { + /*! + INITIAL VALUE + options.size() == 0 + + CONVENTION + name_ == name() + description_ == description() + number_of_arguments_ == number_of_arguments() + options[N][arg] == argument(arg,N) + num_present == count() + !*/ + + friend class cmd_line_parser_kernel_1; + + public: + + const std::basic_string& name ( + ) const { return name_; } + + const std::basic_string& group_name ( + ) const { return group_name_; } + + const std::basic_string& description ( + ) const { return description_; } + + unsigned long number_of_arguments( + ) const { return number_of_arguments_; } + + unsigned long count ( + ) const { return num_present; } + + const std::basic_string& argument ( + unsigned long arg, + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( N < count() && arg < number_of_arguments(), + "\tconst string_type& cmd_line_parser_option::argument(unsigned long,unsigned long)" + << "\n\tInvalid arguments were given to this function." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\targ: " << arg + << "\n\tname(): " << narrow(name()) + << "\n\tcount(): " << count() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return options[N][arg]; + } + + protected: + + option_t ( + ) : + num_present(0) + {} + + ~option_t() + { + clear(); + } + + private: + + void clear() + /*! + ensures + - #count() == 0 + - clears everything out of options and frees memory + !*/ + { + for (unsigned long i = 0; i < options.size(); ++i) + { + delete [] options[i]; + } + options.clear(); + num_present = 0; + } + + // data members + std::basic_string name_; + std::basic_string group_name_; + std::basic_string description_; + sequence2 options; + unsigned long number_of_arguments_; + unsigned long num_present; + + + + // restricted functions + option_t(option_t&); // copy constructor + option_t& operator=(option_t&); // assignment operator + }; + + // -------------------------- + + public: + + cmd_line_parser_kernel_1 ( + ); + + virtual ~cmd_line_parser_kernel_1 ( + ); + + void clear( + ); + + void parse ( + int argc, + const charT** argv + ); + + void parse ( + int argc, + charT** argv + ) + { + parse(argc, const_cast(argv)); + } + + bool parsed_line( + ) const; + + bool option_is_defined ( + const string_type& name + ) const; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + void set_group_name ( + const string_type& group_name + ); + + string_type get_group_name ( + ) const { return group_name; } + + const cmd_line_parser_option& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const string_type& operator[] ( + unsigned long index + ) const; + + void swap ( + cmd_line_parser_kernel_1& item + ); + + // functions from the enumerable interface + bool at_start ( + ) const { return options.at_start(); } + + void reset ( + ) const { options.reset(); } + + bool current_element_valid ( + ) const { return options.current_element_valid(); } + + const cmd_line_parser_option& element ( + ) const { return *static_cast*>(options.element().value()); } + + cmd_line_parser_option& element ( + ) { return *static_cast*>(options.element().value()); } + + bool move_next ( + ) const { return options.move_next(); } + + size_t size ( + ) const { return options.size(); } + + private: + + // data members + map options; + sequence argv; + bool have_parsed_line; + string_type group_name; + + // restricted functions + cmd_line_parser_kernel_1(cmd_line_parser_kernel_1&); // copy constructor + cmd_line_parser_kernel_1& operator=(cmd_line_parser_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + inline void swap ( + cmd_line_parser_kernel_1& a, + cmd_line_parser_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + cmd_line_parser_kernel_1 ( + ) : + have_parsed_line(false) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + ~cmd_line_parser_kernel_1 ( + ) + { + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + clear( + ) + { + have_parsed_line = false; + argv.clear(); + + + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + options.clear(); + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + parse ( + int argc_, + const charT** argv + ) + { + using namespace std; + + // make sure there aren't any arguments hanging around from the last time + // parse was called + this->argv.clear(); + + // make sure that the options have been cleared of any arguments since + // the last time parse() was called + if (have_parsed_line) + { + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + } + + // this tells us if we have seen -- on the command line all by itself + // or not. + bool escape = false; + + const unsigned long argc = static_cast(argc_); + try + { + + for (unsigned long i = 1; i < argc; ++i) + { + if (argv[i][0] == _dT(charT,'-') && !escape) + { + // we are looking at the start of an option + + // -------------------------------------------------------------------- + if (argv[i][1] == _dT(charT,'-')) + { + // we are looking at the start of a "long named" option + string_type temp = &argv[i][2]; + string_type first_argument; + typename string_type::size_type pos = temp.find_first_of(_dT(charT,'=')); + // This variable will be 1 if there is an argument supplied via the = sign + // and 0 otherwise. + unsigned long extra_argument = 0; + if (pos != string_type::npos) + { + // there should be an extra argument + extra_argument = 1; + first_argument = temp.substr(pos+1); + temp = temp.substr(0,pos); + } + + // make sure this name is defined + if (!options.is_in_domain(temp)) + { + // the long name is not a valid option + if (argv[i][2] == _dT(charT,'\0')) + { + // there was nothing after the -- on the command line + escape = true; + continue; + } + else + { + // there was something after the command line but it + // wasn't a valid option + throw cmd_line_parse_error(EINVALID_OPTION,temp); + } + } + + + option_t* o = static_cast(options[temp]); + + // check the number of arguments after this option and make sure + // it is correct + if (argc + extra_argument <= o->number_of_arguments() + i) + { + // there are too few arguments + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + } + if (extra_argument && first_argument.size() == 0 ) + { + // if there would be exactly the right number of arguments if + // the first_argument wasn't empty + if (argc == o->number_of_arguments() + i) + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + else + { + // in this case we just ignore the trailing = and parse everything + // the same. + extra_argument = 0; + } + } + // you can't force an option that doesn't have any arguments to take + // one by using the --option=arg syntax + if (extra_argument == 1 && o->number_of_arguments() == 0) + { + throw cmd_line_parse_error(ETOO_MANY_ARGS,temp); + } + + + + + + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + + string_type* stemp = new string_type[o->number_of_arguments()]; + unsigned long j = 0; + + // add the argument after the = sign if one is present + if (extra_argument) + { + stemp[0] = first_argument; + ++j; + } + + for (; j < o->number_of_arguments(); ++j) + { + stemp[j] = argv[i+j+1-extra_argument]; + } + o->options.add(o->options.size(),stemp); + } + o->num_present += 1; + + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments() - extra_argument; + } + // -------------------------------------------------------------------- + else + { + // we are looking at the start of a list of a single char options + + // make sure there is something in this string other than - + if (argv[i][1] == _dT(charT,'\0')) + { + throw cmd_line_parse_error(); + } + + string_type temp = &argv[i][1]; + const typename string_type::size_type num = temp.size(); + for (unsigned long k = 0; k < num; ++k) + { + string_type name; + // Doing this instead of name = temp[k] seems to avoid a bug in g++ (Ubuntu/Linaro 4.5.2-8ubuntu4) 4.5.2 + // which results in name[0] having the wrong value. + name.resize(1); + name[0] = temp[k]; + + + // make sure this name is defined + if (!options.is_in_domain(name)) + { + // the name is not a valid option + throw cmd_line_parse_error(EINVALID_OPTION,name); + } + + option_t* o = static_cast(options[name]); + + // if there are chars immediately following this option + int delta = 0; + if (num != k+1) + { + delta = 1; + } + + // check the number of arguments after this option and make sure + // it is correct + if (argc + delta <= o->number_of_arguments() + i) + { + // there are too few arguments + std::ostringstream sout; + throw cmd_line_parse_error(ETOO_FEW_ARGS,name,o->number_of_arguments()); + } + + + o->num_present += 1; + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + string_type* stemp = new string_type[o->number_of_arguments()]; + if (delta == 1) + { + temp = &argv[i][2+k]; + k = (unsigned long)num; // this ensures that the argument to this + // option isn't going to be treated as a + // list of options + + stemp[0] = temp; + } + for (unsigned long j = 0; j < o->number_of_arguments()-delta; ++j) + { + stemp[j+delta] = argv[i+j+1]; + } + o->options.add(o->options.size(),stemp); + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments()-delta; + } + } // for (unsigned long k = 0; k < num; ++k) + } + // -------------------------------------------------------------------- + + } + else + { + // this is just a normal argument + string_type temp = argv[i]; + this->argv.add(this->argv.size(),temp); + } + + } + have_parsed_line = true; + + } + catch (...) + { + have_parsed_line = false; + + // clear all the option objects + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + parsed_line( + ) const + { + return have_parsed_line; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + option_is_defined ( + const string_type& name + ) const + { + return options.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + set_group_name ( + const string_type& group_name_ + ) + { + group_name = group_name_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + option_t* temp = new option_t; + try + { + temp->name_ = name; + temp->group_name_ = group_name; + temp->description_ = description; + temp->number_of_arguments_ = number_of_arguments; + void* t = temp; + string_type n(name); + options.add(n,t); + }catch (...) { delete temp; throw;} + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const cmd_line_parser_option& cmd_line_parser_kernel_1:: + option ( + const string_type& name + ) const + { + return *static_cast*>(options[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + unsigned long cmd_line_parser_kernel_1:: + number_of_arguments( + ) const + { + return argv.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const std::basic_string& cmd_line_parser_kernel_1:: + operator[] ( + unsigned long index + ) const + { + return argv[index]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + swap ( + cmd_line_parser_kernel_1& item + ) + { + options.swap(item.options); + argv.swap(item.argv); + exchange(have_parsed_line,item.have_parsed_line); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_1_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..ac9036132c578a2fae2de5b790660198864b37b5 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h @@ -0,0 +1,673 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ +#ifdef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include +#include + +namespace dlib +{ + + template < + typename charT + > + class cmd_line_parser : public enumerable > + { + /*! + REQUIREMENTS ON charT + Must be an integral type suitable for storing characters. (e.g. char + or wchar_t) + + INITIAL VALUE + - parsed_line() == false + - option_is_defined(x) == false, for all values of x + - get_group_name() == "" + + ENUMERATION ORDER + The enumerator will enumerate over all the options defined in *this + in alphabetical order according to the name of the option. + + POINTERS AND REFERENCES TO INTERNAL DATA + parsed_line(), option_is_defined(), option(), number_of_arguments(), + operator[](), and swap() functions do not invalidate pointers or + references to internal data. All other functions have no such guarantee. + + + WHAT THIS OBJECT REPRESENTS + This object represents a command line parser. + The command lines must match the following BNF. + + command_line ::= { | } [ -- {} ] + program_name ::= + arg ::= any that does not start with - + option_arg ::= + option_name ::= + long_option_name ::= { | - } + options ::= - {} {} | + -- [=] { } + char ::= any character other than - or = + word ::= any string from argv where argv is the second + parameter to main() + sword ::= any suffix of a string from argv where argv is the + second parameter to main() + bword ::= This is an empty string which denotes the beginning of a + . + + + Options with arguments: + An option with N arguments will consider the next N swords to be + its arguments. + + so for example, if we have an option o that expects 2 arguments + then the following are a few legal examples: + + program -o arg1 arg2 general_argument + program -oarg1 arg2 general_argument + + arg1 and arg2 are associated with the option o and general_argument + is not. + + Arguments not associated with an option: + An argument that is not associated with an option is considered a + general command line argument and is indexed by operator[] defined + by the cmd_line_parser object. Additionally, if the string + "--" appears in the command line all by itself then all words + following it are considered to be general command line arguments. + + + Consider the following two examples involving a command line and + a cmd_line_parser object called parser. + + Example 1: + command line: program general_arg1 -o arg1 arg2 general_arg2 + Then the following is true (assuming the o option is defined + and takes 2 arguments). + + parser[0] == "general_arg1" + parser[1] == "general_arg2" + parser.number_of_arguments() == 2 + parser.option("o").argument(0) == "arg1" + parser.option("o").argument(1) == "arg2" + parser.option("o").count() == 1 + + Example 2: + command line: program general_arg1 -- -o arg1 arg2 general_arg2 + Then the following is true (the -- causes everything following + it to be treated as a general argument). + + parser[0] == "general_arg1" + parser[1] == "-o" + parser[2] == "arg1" + parser[3] == "arg2" + parser[4] == "general_arg2" + parser.number_of_arguments() == 5 + parser.option("o").count() == 0 + !*/ + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error detected in a + command line while it is being parsed. You can consult this + object's type and item members to determine the nature of the + error. (note that the type member is inherited from dlib::error). + + INTERPRETING THIS EXCEPTION + - if (type == EINVALID_OPTION) then + - There was an undefined option on the command line + - item == The invalid option that was on the command line + - if (type == ETOO_FEW_ARGS) then + - An option was given on the command line but it was not + supplied with the required number of arguments. + - item == The name of this option. + - num == The number of arguments expected by this option. + - if (type == ETOO_MANY_ARGS) then + - An option was given on the command line such as --option=arg + but this option doesn't take any arguments. + - item == The name of this option. + !*/ + public: + const std::basic_string item; + const unsigned long num; + }; + + // -------------------------- + + cmd_line_parser ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~cmd_line_parser ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void parse ( + int argc, + const charT** argv + ); + /*! + requires + - argv == an array of strings that was obtained from the second argument + of the function main(). + (i.e. argv[0] should be the token, argv[1] should be + an or token, etc.) + - argc == the number of strings in argv + ensures + - parses the command line given by argc and argv + - #parsed_line() == true + - #at_start() == true + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable until clear() + is called successfully + - cmd_line_parse_error + This exception is thrown if there is an error parsing the command line. + If this exception is thrown then #parsed_line() == false and all + options will have their count() set to 0 but otherwise there will + be no effect (i.e. all registered options will remain registered). + !*/ + + void parse ( + int argc, + charT** argv + ); + /*! + This just calls this->parse(argc,argv) and performs the necessary const_cast + on argv. + !*/ + + bool parsed_line( + ) const; + /*! + ensures + - returns true if parse() has been called successfully + - returns false otherwise + !*/ + + bool option_is_defined ( + const string_type& name + ) const; + /*! + ensures + - returns true if the option has been added to the parser object + by calling add_option(name). + - returns false otherwise + !*/ + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + /*! + requires + - parsed_line() == false + - option_is_defined(name) == false + - name does not contain any ' ', '\t', '\n', or '=' characters + - name[0] != '-' + - name.size() > 0 + ensures + - #option_is_defined(name) == true + - #at_start() == true + - #option(name).count() == 0 + - #option(name).description() == description + - #option(name).number_of_arguments() == number_of_arguments + - #option(name).group_name() == get_group_name() + throws + - std::bad_alloc + if this exception is thrown then the add_option() function has no + effect + !*/ + + const option_type& option ( + const string_type& name + ) const; + /*! + requires + - option_is_defined(name) == true + ensures + - returns the option specified by name + !*/ + + unsigned long number_of_arguments( + ) const; + /*! + requires + - parsed_line() == true + ensures + - returns the number of arguments present in the command line. + This count does not include options or their arguments. Only + arguments unrelated to any option are counted. + !*/ + + const string_type& operator[] ( + unsigned long N + ) const; + /*! + requires + - parsed_line() == true + - N < number_of_arguments() + ensures + - returns the Nth command line argument + !*/ + + void swap ( + cmd_line_parser& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void print_options ( + std::basic_ostream& out + ) const; + /*! + ensures + - prints all the command line options to out. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + void print_options ( + ) const; + /*! + ensures + - prints all the command line options to cout. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + string_type get_group_name ( + ) const; + /*! + ensures + - returns the current group name. This is the group new options will be + added into when added via add_option(). + - The group name of an option is used by print_options(). In particular, + it groups all options with the same group name together and displays them + under a title containing the text of the group name. This allows you to + group similar options together in the output of print_options(). + - A group name of "" (i.e. the empty string) means that no group name is + set. + !*/ + + void set_group_name ( + const string_type& group_name + ); + /*! + ensures + - #get_group_name() == group_name + !*/ + + // ------------------------------------------------------------- + // Input Validation Tools + // ------------------------------------------------------------- + + class cmd_line_check_error : public dlib::error + { + /*! + This is the exception thrown by the check_*() routines if they find a + command line error. The interpretation of the member variables is defined + below in each check_*() routine. + !*/ + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - first <= last + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T and the resulting value is + in the range first to last inclusive. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - for each argument to the given option: + - this argument is convertible by string_cast() to an object of + type T and the resulting value is equal to some element in the + arg_set array. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + ensures + - for each argument to the given option: + - there is a string in the arg_set array that is equal to this argument. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - all the options in the option_set array occur at most once on the + command line. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMULTIPLE_OCCURANCES + - opt == the option that occurred more than once on the command line. + !*/ + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name1) == true + - option_is_defined(option_name2) == true + ensures + - option(option_name1).count() == 0 || option(option_name2).count() == 0 + (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == option_name1 + - opt2 == option_name2 + !*/ + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - At most only one of the options in the array option_set has a count() + greater than 0. (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == One of the incompatible options found. + - opt2 == The next incompatible option found. + !*/ + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - option_is_defined(sub_option) == true + ensures + - if (option(parent_option).count() == 0) then + - option(sub_option).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == sub_option. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(sub_option) == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + ensures + - if (option(sub_option).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - for all valid i: + - option_is_defined(sub_option_set[i]) == true + ensures + - if (option(parent_option).count() == 0) then + - for all valid i: + - option(sub_option_set[i]).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + - for all valid j: + - option_is_defined(sub_option_set[j]) == true + ensures + - for all valid j: + - if (option(sub_option_set[j]).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + + private: + + // restricted functions + cmd_line_parser(cmd_line_parser&); // copy constructor + cmd_line_parser& operator=(cmd_line_parser&); // assignment operator + + }; + +// ----------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ----------------------------------------------------------------------------------------- + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ----------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..e80543018e17bc3700b3a824ad23954200be674d --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h @@ -0,0 +1,203 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_C_ +#define DLIB_CMD_LINE_PARSER_KERNEl_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_kernel_c : public clp_base + { + public: + + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + typedef typename clp_base::option_type option_type; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + const option_type& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const option_type& element ( + ) const; + + option_type& element ( + ); + + const string_type& operator[] ( + unsigned long N + ) const; + + }; + + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_kernel_c& a, + cmd_line_parser_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::string_type& cmd_line_parser_kernel_c:: + operator[] ( + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && N < number_of_arguments(), + "\tvoid cmd_line_parser::operator[](unsigned long N)" + << "\n\tYou must specify a valid index N and the parser must have run already." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return clp_base::operator[](N); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_kernel_c:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == false && + name.size() > 0 && + this->option_is_defined(name) == false && + name.find_first_of(_dT(char_type," \t\n=")) == string_type::npos && + name[0] != '-', + "\tvoid cmd_line_parser::add_option(const string_type&,const string_type&,unsigned long)" + << "\n\tsee the requires clause of add_option()" + << "\n\tthis: " << this + << "\n\tname.size(): " << static_cast(name.size()) + << "\n\tname: \"" << narrow(name) << "\"" + << "\n\tparsed_line(): " << (this->parsed_line()? "true" : "false") + << "\n\tis_option_defined(\"" << narrow(name) << "\"): " << (this->option_is_defined(name)? "true" : "false") + ); + + clp_base::add_option(name,description,number_of_arguments); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + option ( + const string_type& name + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->option_is_defined(name) == true, + "\toption cmd_line_parser::option(const string_type&)" + << "\n\tto get an option it must be defined by a call to add_option()" + << "\n\tthis: " << this + << "\n\tname: \"" << narrow(name) << "\"" + ); + + return clp_base::option(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + unsigned long cmd_line_parser_kernel_c:: + number_of_arguments( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true , + "\tunsigned long cmd_line_parser::number_of_arguments()" + << "\n\tyou must parse the command line before you can find out how many arguments it has" + << "\n\tthis: " << this + ); + + return clp_base::number_of_arguments(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst cmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tcmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_C_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_print_1.h b/dlib/cmd_line_parser/cmd_line_parser_print_1.h new file mode 100644 index 0000000000000000000000000000000000000000..3f52c842f23fe57655e93b33b807fe7a68b2d821 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_print_1.h @@ -0,0 +1,205 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_PRINt_1_ +#define DLIB_CMD_LINE_PARSER_PRINt_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../string.h" +#include +#include +#include +#include +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_print_1 : public clp_base + { + + public: + + void print_options ( + std::basic_ostream& out + ) const; + + void print_options ( + ) const + { + print_options(std::cout); + } + + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_print_1& a, + cmd_line_parser_print_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_print_1:: + print_options ( + std::basic_ostream& out + ) const + { + typedef typename clp_base::char_type ct; + typedef std::basic_string string; + typedef typename string::size_type size_type; + + typedef std::basic_ostringstream ostringstream; + + try + { + + + size_type max_len = 0; + this->reset(); + + // this loop here is just the bottom loop but without the print statements. + // I'm doing this to figure out what len should be. + while (this->move_next()) + { + size_type len = 0; + len += 3; + if (this->element().name().size() > 1) + { + ++len; + } + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + len += 7; + if (i+1 > 9) + ++len; + } + } + + len += 3; + if (len < 33) + max_len = std::max(max_len,len); + } + + + // Make a separate ostringstream for each option group. We are going to write + // the output for each group to a separate ostringstream so that we can keep + // them grouped together in the final output. + std::map > groups; + this->reset(); + while(this->move_next()) + { + if (!groups[this->element().group_name()]) + groups[this->element().group_name()].reset(new ostringstream); + } + + + + + this->reset(); + + while (this->move_next()) + { + ostringstream& sout = *groups[this->element().group_name()]; + + size_type len = 0; + sout << _dT(ct,"\n -"); + len += 3; + if (this->element().name().size() > 1) + { + sout << _dT(ct,"-"); + ++len; + } + sout << this->element().name(); + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + sout << _dT(ct," "); + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + sout << _dT(ct," "); + len += 7; + if (i+1 > 9) + ++len; + } + } + + sout << _dT(ct," "); + len += 3; + + while (len < max_len) + { + ++len; + sout << _dT(ct," "); + } + + const unsigned long ml = static_cast(max_len); + // now print the description but make it wrap around nicely if it + // is to long to fit on one line. + if (len <= max_len) + sout << wrap_string(this->element().description(),0,ml); + else + sout << _dT(ct,"\n") << wrap_string(this->element().description(),ml,ml); + } + + // Only print out a generic Options: group name if there is an unnamed option + // present. + if (groups.count(string()) == 1) + out << _dT(ct,"Options:"); + + // Now print everything out + typename std::map >::iterator i; + for (i = groups.begin(); i != groups.end(); ++i) + { + // print the group name if we have one + if (i->first.size() != 0) + { + if (i != groups.begin()) + out << _dT(ct,"\n\n"); + out << i->first << _dT(ct,":"); + } + + // print the options in the group + out << i->second->str(); + } + out << _dT(ct,"\n\n"); + this->reset(); + } + catch (...) + { + this->reset(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_PRINt_1_ + diff --git a/dlib/cmd_line_parser/get_option.h b/dlib/cmd_line_parser/get_option.h new file mode 100644 index 0000000000000000000000000000000000000000..2c8d1644f7c2b58905b69ae563a6d8169dcae5a1 --- /dev/null +++ b/dlib/cmd_line_parser/get_option.h @@ -0,0 +1,181 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GET_OPTiON_Hh_ +#define DLIB_GET_OPTiON_Hh_ + +#include "get_option_abstract.h" +#include "../string.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + public: + option_parse_error(const std::string& option_string, const std::string& str): + error(EOPTION_PARSE,"Error parsing argument for option '" + option_string + "', offending string is '" + str + "'.") {} + }; + +// ---------------------------------------------------------------------------------------- + + template + T impl_config_reader_get_option ( + const config_reader_type& cr, + const std::string& option_name, + const std::string& full_option_name, + T default_value + ) + { + std::string::size_type pos = option_name.find_first_of("."); + if (pos == std::string::npos) + { + if (cr.is_key_defined(option_name)) + { + try{ return string_cast(cr[option_name]); } + catch (string_cast_error&) { throw option_parse_error(full_option_name, cr[option_name]); } + } + } + else + { + std::string block_name = option_name.substr(0,pos); + if (cr.is_block_defined(block_name)) + { + return impl_config_reader_get_option(cr.block(block_name), + option_name.substr(pos+1), + full_option_name, + default_value); + } + } + + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,T>::type get_option ( + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + return impl_config_reader_get_option(cr, option_name, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + { + try + { + default_value = string_cast(parser.option(option_name).argument()); + } + catch (string_cast_error&) + { + throw option_parse_error(option_name, parser.option(option_name).argument()); + } + } + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const cr_type& cr, + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& parser, + const U& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(parser, cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_Hh_ + diff --git a/dlib/cmd_line_parser/get_option_abstract.h b/dlib/cmd_line_parser/get_option_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..90dc16721791921d92ca1c2c76fc24bd1cdbc2b5 --- /dev/null +++ b/dlib/cmd_line_parser/get_option_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GET_OPTiON_ABSTRACT_Hh_ +#ifdef DLIB_GET_OPTiON_ABSTRACT_Hh_ + +#inclue + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by the get_option() functions. It is + thrown when the option string given by a command line parser or + config reader can't be converted into the type T. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - T is a type which can be read from an input stream + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - option_name is used to index into the given config_reader. + - if (cr contains an entry corresponding to option_name) then + - converts the string value in cr corresponding to option_name into + an object of type T and returns it. + - else + - returns default_value + - The scheme for indexing into cr based on option_name is best + understood by looking at a few examples: + - an option name of "name" corresponds to cr["name"] + - an option name of "block1.name" corresponds to cr.block("block1")["name"] + - an option name of "block1.block2.name" corresponds to cr.block("block1").block("block2")["name"] + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string value + in cr into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - converts parser.option(option_name).argument() into an object + of type T and returns it. That is, the string argument to this + command line option is converted into a T and returned. + - else + - returns default_value + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string + argument into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_ABSTRACT_Hh_ + + diff --git a/dlib/compress_stream.h b/dlib/compress_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8ccc1d52faf49b311a24b7b6740d3ec48e553771 --- /dev/null +++ b/dlib/compress_stream.h @@ -0,0 +1,133 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAm_ +#define DLIB_COMPRESS_STREAm_ + +#include "compress_stream/compress_stream_kernel_1.h" +#include "compress_stream/compress_stream_kernel_2.h" +#include "compress_stream/compress_stream_kernel_3.h" + +#include "conditioning_class.h" +#include "entropy_encoder.h" +#include "entropy_decoder.h" + +#include "entropy_encoder_model.h" +#include "entropy_decoder_model.h" +#include "lz77_buffer.h" +#include "sliding_buffer.h" +#include "lzp_buffer.h" +#include "crc32.h" + + +namespace dlib +{ + + class compress_stream + { + compress_stream() {} + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_1b fce1; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_1b fcd1; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2b fce2; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2b fcd2; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_3b fce3; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_3b fcd3; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4a fce4a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4a fcd4a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4b fce4b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4b fcd4b; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5a fce5a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5a fcd5a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5b fce5b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5b fcd5b; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5c fce5c; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5c fcd5c; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_6a fce6; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_6a fcd6; + + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2d fce2d; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2d fcd2d; + + typedef sliding_buffer::kernel_1a sliding_buffer1; + typedef lz77_buffer::kernel_2a lz77_buffer2a; + + + typedef lzp_buffer::kernel_1a lzp_buf_1; + typedef lzp_buffer::kernel_2a lzp_buf_2; + + + typedef entropy_encoder_model<513,entropy_encoder::kernel_2a>::kernel_1b fce_length; + typedef entropy_decoder_model<513,entropy_decoder::kernel_2a>::kernel_1b fcd_length; + + typedef entropy_encoder_model<65534,entropy_encoder::kernel_2a>::kernel_1b fce_length_2; + typedef entropy_decoder_model<65534,entropy_decoder::kernel_2a>::kernel_1b fcd_length_2; + + + typedef entropy_encoder_model<32257,entropy_encoder::kernel_2a>::kernel_1b fce_index; + typedef entropy_decoder_model<32257,entropy_decoder::kernel_2a>::kernel_1b fcd_index; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef compress_stream_kernel_1 + kernel_1a; + + // kernel_1b + typedef compress_stream_kernel_1 + kernel_1b; + + // kernel_1c + typedef compress_stream_kernel_1 + kernel_1c; + + // kernel_1da + typedef compress_stream_kernel_1 + kernel_1da; + + // kernel_1ea + typedef compress_stream_kernel_1 + kernel_1ea; + + // kernel_1db + typedef compress_stream_kernel_1 + kernel_1db; + + // kernel_1eb + typedef compress_stream_kernel_1 + kernel_1eb; + + // kernel_1ec + typedef compress_stream_kernel_1 + kernel_1ec; + + + + + // kernel_2a + typedef compress_stream_kernel_2 + kernel_2a; + + + + + // kernel_3a + typedef compress_stream_kernel_3 + kernel_3a; + // kernel_3b + typedef compress_stream_kernel_3 + kernel_3b; + + + }; +} + +#endif // DLIB_COMPRESS_STREAm_ + diff --git a/dlib/compress_stream/compress_stream_kernel_1.h b/dlib/compress_stream/compress_stream_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1a75ec6ced988ec9108379d47d29617e3930b511 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_1.h @@ -0,0 +1,252 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_1_ +#define DLIB_COMPRESS_STREAM_KERNEl_1_ + +#include "../algs.h" +#include +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename crc32 + > + class compress_stream_kernel_1 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_1 ( + ) + {} + + ~compress_stream_kernel_1 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_1(compress_stream_kernel_1&); // copy constructor + compress_stream_kernel_1& operator=(compress_stream_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + + crc32 crc; + + unsigned long count = 0; + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(1500,1501,8000); + } + ++count; + + // get the next character + temp = in.sbumpc(); + + // if we have hit EOF then encode the marker symbol + if (temp != EOF) + { + // encode the symbol + model.encode(static_cast(temp)); + crc.add(static_cast(temp)); + continue; + } + else + { + model.encode(eof_symbol); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + + unsigned long symbol; + unsigned long count = 0; + + crc32 crc; + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(8000) != 1500) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(1500,1501); + } + ++count; + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_1::decompress"); + } + continue; + } + else + { + // we read eof from the encoded data. now we just have to check the checksum and we are done. + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_1_ + diff --git a/dlib/compress_stream/compress_stream_kernel_2.h b/dlib/compress_stream/compress_stream_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..e46b23fad874758b2553ac1db1a1006f8ec8b064 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_2.h @@ -0,0 +1,431 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_2_ +#define DLIB_COMPRESS_STREAM_KERNEl_2_ + +#include "../algs.h" +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + class compress_stream_kernel_2 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON lz77_buffer + is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h + + REQUIREMENTS ON sliding_buffer + is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + is instantiated with T = unsigned char + + REQUIREMENTS ON fce_length + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 513. This will be used to encode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fcd_length + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 513. This will be used to decode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fce_index + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 32257. This will be used to encode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON fcd_index + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 32257. This will be used to decode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_2 ( + ) + {} + + ~compress_stream_kernel_2 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_2(compress_stream_kernel_2&); // copy constructor + compress_stream_kernel_2& operator=(compress_stream_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + fce_length model_length(coder); + fce_index model_index(coder); + + const unsigned long LOOKAHEAD_LIMIT = 512; + lz77_buffer buffer(15,LOOKAHEAD_LIMIT); + + crc32 crc; + + + unsigned long count = 0; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(150,151,400); + } + ++count; + + // try to fill the lookahead buffer + if (buffer.get_lookahead_buffer_size() < buffer.get_lookahead_buffer_limit()) + { + temp = in.sbumpc(); + while (temp != EOF) + { + crc.add(static_cast(temp)); + buffer.add(static_cast(temp)); + if (buffer.get_lookahead_buffer_size() == buffer.get_lookahead_buffer_limit()) + break; + temp = in.sbumpc(); + } + } + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // if there are still more symbols in the lookahead buffer to encode + if (buffer.get_lookahead_buffer_size() > 0) + { + unsigned long match_index, match_length; + buffer.find_match(match_index,match_length,6); + if (match_length != 0) + { + + // signal the decoder that we are using lz77 + coder.encode(0,lz77_count,sum); + ++lz77_count; + + // encode the index and length pair + model_index.encode(match_index); + model_length.encode(match_length); + + } + else + { + + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + ++ppm_count; + + // encode the symbol using the ppm model + model.encode(buffer.lookahead_buffer(0)); + buffer.shift_buffers(1); + } + } + else + { + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + + + model.encode(eof_symbol); + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + fcd_length model_length(coder); + fcd_index model_index(coder); + + unsigned long symbol; + unsigned long count = 0; + + sliding_buffer buffer; + buffer.set_size(15); + + // Initialize the buffer to all zeros. There is no algorithmic reason to + // do this. But doing so avoids a warning from valgrind so that is why + // I'm doing this. + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + crc32 crc; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + bool next_block_lz77; + + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(400) != 150) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(150,151); + } + ++count; + + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // check if we are decoding a lz77 or ppm block + if (coder.get_target(sum) < lz77_count) + { + coder.decode(0,lz77_count); + next_block_lz77 = true; + ++lz77_count; + } + else + { + coder.decode(lz77_count,sum); + next_block_lz77 = false; + ++ppm_count; + } + + + if (next_block_lz77) + { + + unsigned long match_length, match_index; + // decode the match index + model_index.decode(match_index); + + // decode the match length + model_length.decode(match_length); + + + match_index += match_length; + buffer.rotate_left(match_length); + for (unsigned long i = 0; i < match_length; ++i) + { + unsigned char ch = buffer[match_index-i]; + buffer[match_length-i-1] = ch; + + crc.add(ch); + // write this ch to out + if (out.sputc(static_cast(ch)) != static_cast(ch)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + + } + else + { + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + buffer.rotate_left(1); + buffer[0] = static_cast(symbol); + + + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + else + { + // this was the eof marker symbol so we are done. now check the checksum + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } + + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_2_ + diff --git a/dlib/compress_stream/compress_stream_kernel_3.h b/dlib/compress_stream/compress_stream_kernel_3.h new file mode 100644 index 0000000000000000000000000000000000000000..ed4eee290d08a28036ca9a6a49ef1bd6454d7eb5 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_3.h @@ -0,0 +1,381 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_3_ +#define DLIB_COMPRESS_STREAM_KERNEl_3_ + +#include "../algs.h" +#include "compress_stream_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + class compress_stream_kernel_3 + { + /*! + REQUIREMENTS ON lzp_buf + is an implementation of lzp_buffer/lzp_buffer_kernel_abstract.h + + REQUIREMENTS ON buffer_size + 10 < buffer_size < 32 + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + + + This implementation uses the lzp_buffer and writes out matches + in a byte aligned format. + + !*/ + + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_3 ( + ) + { + COMPILE_TIME_ASSERT(10 < buffer_size && buffer_size < 32); + } + + ~compress_stream_kernel_3 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + + + private: + + inline void write ( + unsigned char symbol + ) const + { + if (out->sputn(reinterpret_cast(&symbol),1)==0) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + + inline void decode ( + unsigned char& symbol, + unsigned char& flag + ) const + { + if (count == 0) + { + if (((size_t)in->sgetn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw decompression_error("Error detected in compressed data stream."); + count = 8; + } + --count; + symbol = buffer[8-count]; + flag = buffer[0] >> 7; + buffer[0] <<= 1; + } + + inline void encode ( + unsigned char symbol, + unsigned char flag + ) const + /*! + requires + - 0 <= flag <= 1 + ensures + - writes symbol with the given one bit flag + !*/ + { + // add this symbol and flag to the buffer + ++count; + buffer[0] <<= 1; + buffer[count] = symbol; + buffer[0] |= flag; + + if (count == 8) + { + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + count = 0; + buffer[0] = 0; + } + } + + void clear ( + ) const + /*! + ensures + - resets the buffers + !*/ + { + count = 0; + } + + void flush ( + ) const + /*! + ensures + - flushes any data in the buffers to out + !*/ + { + if (count != 0) + { + buffer[0] <<= (8-count); + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + } + + mutable unsigned int count; + // count tells us how many bytes are buffered in buffer and how many flag + // bit are currently in buffer[0] + mutable unsigned char buffer[9]; + // buffer[0] holds the flag bits to be writen. + // the rest of the buffer holds the bytes to be writen. + + mutable std::streambuf* in; + mutable std::streambuf* out; + + // restricted functions + compress_stream_kernel_3(compress_stream_kernel_3&); // copy constructor + compress_stream_kernel_3& operator=(compress_stream_kernel_3&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + std::streambuf::int_type temp = in->sbumpc(); + unsigned long index; + unsigned char symbol; + unsigned char length; + + while (temp != EOF) + { + symbol = static_cast(temp); + if (buffer.predict_match(index)) + { + if (buffer[index] == symbol) + { + // this is a match so we must find out how long it is + length = 1; + + buffer.add(symbol); + crc.add(symbol); + + temp = in->sbumpc(); + while (length < 255) + { + if (temp == EOF) + { + break; + } + else if (static_cast(length) >= index) + { + break; + } + else if (static_cast(temp) == buffer[index]) + { + ++length; + buffer.add(static_cast(temp)); + crc.add(static_cast(temp)); + temp = in->sbumpc(); + } + else + { + break; + } + } + + encode(length,1); + } + else + { + // this is also not a match + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + else + { + // there wasn't a match so just write this symbol + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + + // use a match of zero length to indicate EOF + encode(0,1); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + encode(byte1,0); + encode(byte2,0); + encode(byte3,0); + encode(byte4,0); + + flush(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + + unsigned long index = 0; + unsigned char symbol; + unsigned char length; + unsigned char flag; + + decode(symbol,flag); + while (flag == 0 || symbol != 0) + { + buffer.predict_match(index); + + if (flag == 1) + { + length = symbol; + do + { + --length; + symbol = buffer[index]; + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } while (length != 0); + } + else + { + // this is just a literal + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } + decode(symbol,flag); + } + + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + decode(byte1,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte2,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte3,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte4,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_3_ + diff --git a/dlib/compress_stream/compress_stream_kernel_abstract.h b/dlib/compress_stream/compress_stream_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..48f46d9e11afac0120c4c56274905f032a199865 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_abstract.h @@ -0,0 +1,94 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class compress_stream + { + /*! + INITIAL VALUE + This object does not have any state associated with it. + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions compress and decompress. + These functions allow you to compress and decompress data. + !*/ + + public: + + class decompression_error : public dlib::error {}; + + compress_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~compress_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + + void compress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and compresses it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in, decompresses it and writes it to out. note that + it stops reading data from in when it encounters the end of the + compressed data, not when it encounters EOF. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decompression_error + if an error was detected in the compressed data that prevented + it from being correctly decompressed then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + private: + + // restricted functions + compress_stream(compress_stream&); // copy constructor + compress_stream& operator=(compress_stream&); // assignment operator + + }; + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + diff --git a/dlib/conditioning_class.h b/dlib/conditioning_class.h new file mode 100644 index 0000000000000000000000000000000000000000..409b9871602ef2942dce9e7c88a5c81db789e9a6 --- /dev/null +++ b/dlib/conditioning_class.h @@ -0,0 +1,80 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASs_ +#define DLIB_CONDITIONING_CLASs_ + +#include "conditioning_class/conditioning_class_kernel_1.h" +#include "conditioning_class/conditioning_class_kernel_2.h" +#include "conditioning_class/conditioning_class_kernel_3.h" +#include "conditioning_class/conditioning_class_kernel_4.h" +#include "conditioning_class/conditioning_class_kernel_c.h" + + +#include "memory_manager.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + conditioning_class() {} + + typedef memory_manager::kernel_2b mm; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef conditioning_class_kernel_1 + kernel_1a; + typedef conditioning_class_kernel_c + kernel_1a_c; + + // kernel_2a + typedef conditioning_class_kernel_2 + kernel_2a; + typedef conditioning_class_kernel_c + kernel_2a_c; + + // kernel_3a + typedef conditioning_class_kernel_3 + kernel_3a; + typedef conditioning_class_kernel_c + kernel_3a_c; + + + // -------- kernel_4 --------- + + // kernel_4a + typedef conditioning_class_kernel_4 + kernel_4a; + typedef conditioning_class_kernel_c + kernel_4a_c; + + // kernel_4b + typedef conditioning_class_kernel_4 + kernel_4b; + typedef conditioning_class_kernel_c + kernel_4b_c; + + // kernel_4c + typedef conditioning_class_kernel_4 + kernel_4c; + typedef conditioning_class_kernel_c + kernel_4c_c; + + // kernel_4d + typedef conditioning_class_kernel_4 + kernel_4d; + typedef conditioning_class_kernel_c + kernel_4d_c; + + }; +} + +#endif // DLIB_CONDITIONING_CLASS_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_1.h b/dlib/conditioning_class/conditioning_class_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..d26d80244aa76acbe7f25bf6d52c8eaf53fdc047 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_1.h @@ -0,0 +1,333 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_1_ +#define DLIB_CONDITIONING_CLASS_KERNEl_1_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_1 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size unsigned shorts + for all i except i == alphabet_size-1: counts[i] == 0 + counts[alphabet_size-1] == 1 + + CONVENTION + counts == pointer to an array of alphabet_size unsigned shorts + get_total() == total + get_count(symbol) == counts[symbol] + + LOW_COUNT(symbol) == sum of counts[0] though counts[symbol-1] + or 0 if symbol == 0 + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_1; + }; + + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_1 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + // restricted functions + conditioning_class_kernel_1(conditioning_class_kernel_1&); // copy constructor + conditioning_class_kernel_1& operator=(conditioning_class_kernel_1&); // assignment operator + + // data members + unsigned short total; + unsigned short* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ) : + total(1), + counts(new unsigned short[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + + // update memory usage + global_state.memory_usage += sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + ~conditioning_class_kernel_1 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + clear( + ) + { + total = 1; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_1::global_state_type& conditioning_class_kernel_1:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_1:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size; + while (start != end) + { + *start >>= 1; + total += *start; + ++start; + } + // make sure it is at least one + if (counts[alphabet_size-1]==0) + { + ++total; + counts[alphabet_size-1] = 1; + } + } + counts[symbol] += amount; + total += amount; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_count ( + unsigned long symbol + ) const + { + return counts[symbol]; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol] == 0) + return 0; + + total_count = total; + + const unsigned short* start = counts; + const unsigned short* end = counts+symbol; + unsigned short high_count_temp = *start; + while (start != end) + { + ++start; + high_count_temp += *start; + } + low_count = high_count_temp - *start; + high_count = high_count_temp; + return *start; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = *counts; + const unsigned short* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += *start; + } + + low_count = high_count_temp - *start; + high_count = high_count_temp; + symbol = static_cast(start-counts); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_2.h b/dlib/conditioning_class/conditioning_class_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..c9b38c8e3b95e6b98393ad511a92bb7a62423bc4 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_2.h @@ -0,0 +1,500 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_ +#define DLIB_CONDITIONING_CLASS_KERNEl_2_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_2 + { + /*! + INITIAL VALUE + total == 1 + symbols == pointer to array of alphabet_size data structs + for all i except i == alphabet_size-1: symbols[i].count == 0 + symbols[i].left_count == 0 + + symbols[alphabet_size-1].count == 1 + symbols[alpahbet_size-1].left_count == 0 + + CONVENTION + symbols == pointer to array of alphabet_size data structs + get_total() == total + get_count(symbol) == symbols[symbol].count + + symbols is organized as a tree with symbols[0] as the root. + + the left subchild of symbols[i] is symbols[i*2+1] and + the right subchild is symbols[i*2+2]. + the partent of symbols[i] == symbols[(i-1)/2] + + symbols[i].left_count == the sum of the counts of all the + symbols to the left of symbols[i] + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_2; + }; + + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_2 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_2(conditioning_class_kernel_2&); // copy constructor + conditioning_class_kernel_2& operator=(conditioning_class_kernel_2&); // assignment operator + + // data members + unsigned short total; + struct data + { + unsigned short count; + unsigned short left_count; + }; + + data* symbols; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ) : + total(1), + symbols(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = symbols; + data* end = symbols + alphabet_size-1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + + // update the left_counts for the symbol alphabet_size-1 + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + ++symbols[symbol].left_count; + } + + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + ~conditioning_class_kernel_2 ( + ) + { + delete [] symbols; + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + clear( + ) + { + data* start = symbols; + data* end = symbols + alphabet_size-1; + + total = 1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + // update the left_counts + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + symbols[symbol].left_count += temp; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_2::global_state_type& conditioning_class_kernel_2:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_2:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we need to renormalize then do so + if (static_cast(total)+static_cast(amount) >= 65536) + { + unsigned long s; + unsigned short temp; + for (unsigned short i = 0; i < alphabet_size-1; ++i) + { + s = i; + + // divide the count for this symbol by 2 + symbols[i].count >>= 1; + + symbols[i].left_count = 0; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[i].count; + } + } + + // update symbols alphabet_size-1 + { + s = alphabet_size-1; + + // divide alphabet_size-1 symbol by 2 if it's > 1 + if (symbols[alphabet_size-1].count > 1) + symbols[alphabet_size-1].count >>= 1; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[alphabet_size-1].count; + } + } + + + + + + + // calculate the new total + total = 0; + unsigned long m = 0; + while (m < alphabet_size) + { + total += symbols[m].count + symbols[m].left_count; + m = (m<<1) + 2; + } + + } + + + + + // increment the count for the specified symbol + symbols[symbol].count += amount;; + total += amount; + + + unsigned short temp; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[symbol].left_count += amount; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_count ( + unsigned long symbol + ) const + { + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbols[symbol].count == 0) + return 0; + + unsigned long current = symbol; + total_count = total; + unsigned long high_count_temp = 0; + bool came_from_right = true; + while (true) + { + if (came_from_right) + { + high_count_temp += symbols[current].count + symbols[current].left_count; + } + + // note that if current is even then it is a right child + came_from_right = !(current&0x1); + + if (current == 0) + break; + + // set current to its parent + current = (current-1)>>1 ; + } + + + low_count = high_count_temp - symbols[symbol].count; + high_count = high_count_temp; + + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long current = 0; + unsigned long low_count_temp = 0; + + while (true) + { + if (static_cast(target) < symbols[current].left_count) + { + // we should go left + current = (current<<1) + 1; + } + else + { + target -= symbols[current].left_count; + low_count_temp += symbols[current].left_count; + if (static_cast(target) < symbols[current].count) + { + // we have found our target + symbol = current; + high_count = low_count_temp + symbols[current].count; + low_count = low_count_temp; + break; + } + else + { + // go right + target -= symbols[current].count; + low_count_temp += symbols[current].count; + current = (current<<1) + 2; + } + } + + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_3.h b/dlib/conditioning_class/conditioning_class_kernel_3.h new file mode 100644 index 0000000000000000000000000000000000000000..b6de485558b47bb65f7b32e66bbeabfbbd5fae1d --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_3.h @@ -0,0 +1,438 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_ +#define DLIB_CONDITIONING_CLASS_KERNEl_3_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_3 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size data structs + for all i except i == 0: counts[i].count == 0 + counts[0].count == 1 + counts[0].symbol == alphabet_size-1 + for all i except i == alphabet_size-1: counts[i].present == false + counts[alphabet_size-1].present == true + + CONVENTION + counts == pointer to an array of alphabet_size data structs + get_total() == total + get_count(symbol) == counts[x].count where + counts[x].symbol == symbol + + + LOW_COUNT(symbol) == sum of counts[0].count though counts[x-1].count + where counts[x].symbol == symbol + if (counts[0].symbol == symbol) LOW_COUNT(symbol)==0 + + + if (counts[i].count == 0) then + counts[i].symbol == undefined value + + if (symbol has a nonzero count) then + counts[symbol].present == true + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_3; + }; + + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_3 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_3(conditioning_class_kernel_3&); // copy constructor + conditioning_class_kernel_3& operator=(conditioning_class_kernel_3&); // assignment operator + + struct data + { + unsigned short count; + unsigned short symbol; + bool present; + }; + + // data members + unsigned short total; + data* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ) : + total(1), + counts(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + // update memory usage + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + ~conditioning_class_kernel_3 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + clear( + ) + { + total = 1; + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_3::global_state_type& conditioning_class_kernel_3:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_3:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + data* start = counts; + data* end = counts+alphabet_size; + + while (start != end) + { + if (start->count == 1) + { + if (start->symbol == alphabet_size-1) + { + // this symbol must never be zero so we will leave its count at 1 + ++total; + } + else + { + start->count = 0; + counts[start->symbol].present = false; + } + } + else + { + start->count >>= 1; + total += start->count; + } + + ++start; + } + } + + + data* start = counts; + data* swap_spot = counts; + + if (counts[symbol].present) + { + while (true) + { + if (start->symbol == symbol && start->count!=0) + { + unsigned short temp = start->count + amount; + + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = temp; + break; + } + + if ( (start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + + ++start; + } + } + else + { + counts[symbol].present = true; + while (true) + { + if (start->count == 0) + { + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = amount; + break; + } + + if ((start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + ++start; + } + } + + total += amount; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_count ( + unsigned long symbol + ) const + { + if (counts[symbol].present == false) + return 0; + + data* start = counts; + while (start->symbol != symbol) + { + ++start; + } + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol].present == false) + return 0; + + total_count = total; + unsigned long low_count_temp = 0; + data* start = counts; + while (start->symbol != symbol) + { + low_count_temp += start->count; + ++start; + } + + low_count = low_count_temp; + high_count = low_count_temp + start->count; + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = counts->count; + const data* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += start->count; + } + + low_count = high_count_temp - start->count; + high_count = high_count_temp; + symbol = static_cast(start->symbol); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_3_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_4.h b/dlib/conditioning_class/conditioning_class_kernel_4.h new file mode 100644 index 0000000000000000000000000000000000000000..cb48ac196267d23e2341102d0cac173e8ba5a877 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_4.h @@ -0,0 +1,533 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_ +#define DLIB_CONDITIONING_CLASS_KERNEl_4_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + class conditioning_class_kernel_4 + { + /*! + REQUIREMENTS ON pool_size + pool_size > 0 + this will be the number of nodes contained in our memory pool + + REQUIREMENTS ON mem_manager + mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h + + INITIAL VALUE + total == 1 + escapes == 1 + next == 0 + + CONVENTION + get_total() == total + get_count(alphabet_size-1) == escapes + + if (next != 0) then + next == pointer to the start of a linked list and the linked list + is terminated by a node with a next pointer of 0. + + get_count(symbol) == node::count for the node where node::symbol==symbol + or 0 if no such node currently exists. + + if (there is a node for the symbol) then + LOW_COUNT(symbol) == the sum of all node's counts in the linked list + up to but not including the node for the symbol. + + get_memory_usage() == global_state.memory_usage + !*/ + + + struct node + { + unsigned short symbol; + unsigned short count; + node* next; + }; + + public: + + class global_state_type + { + public: + global_state_type ( + ) : + memory_usage(pool_size*sizeof(node)+sizeof(global_state_type)) + {} + private: + unsigned long memory_usage; + + typename mem_manager::template rebind::other pool; + + friend class conditioning_class_kernel_4; + }; + + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_4 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + void half_counts ( + ); + /*! + ensures + - divides all counts by 2 but ensures that escapes is always at least 1 + !*/ + + // restricted functions + conditioning_class_kernel_4(conditioning_class_kernel_4&); // copy constructor + conditioning_class_kernel_4& operator=(conditioning_class_kernel_4&); // assignment operator + + // data members + unsigned short total; + unsigned short escapes; + node* next; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ) : + total(1), + escapes(1), + next(0), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + // update memory usage + global_state.memory_usage += sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + ~conditioning_class_kernel_4 ( + ) + { + clear(); + // update memory usage + global_state.memory_usage -= sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + clear( + ) + { + total = 1; + escapes = 1; + while (next) + { + node* temp = next; + next = next->next; + global_state.pool.deallocate(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + typename conditioning_class_kernel_4::global_state_type& conditioning_class_kernel_4:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + bool conditioning_class_kernel_4:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + if (symbol == alphabet_size-1) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + escapes += amount; + total += amount; + return true; + } + + + // find the symbol and increment it or add a new node to the list + if (next) + { + node* temp = next; + node* previous = 0; + while (true) + { + if (temp->symbol == static_cast(symbol)) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + // we have found the symbol + total += amount; + temp->count += amount; + + // if this node now has a count greater than its parent node + if (previous && temp->count > previous->count) + { + // swap the nodes so that the nodes will be in semi-sorted order + swap(temp->count,previous->count); + swap(temp->symbol,previous->symbol); + } + return true; + } + else if (temp->next == 0) + { + // we did not find the symbol so try to add it to the list + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + node* t = global_state.pool.allocate(); + t->next = 0; + t->symbol = static_cast(symbol); + t->count = amount; + temp->next = t; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + else if (temp->count == 0) + { + // remove nodes that have a zero count + if (previous) + { + previous->next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + else + { + next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + } + else + { + previous = temp; + temp = temp->next; + } + } // while (true) + } + // if there aren't any nodes in the list yet then do this instead + else + { + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + next = global_state.pool.allocate(); + next->next = 0; + next->symbol = static_cast(symbol); + next->count = amount; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_count ( + unsigned long symbol + ) const + { + if (symbol == alphabet_size-1) + { + return escapes; + } + else + { + node* temp = next; + while (temp) + { + if (temp->symbol == symbol) + return temp->count; + temp = temp->next; + } + return 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbol != alphabet_size-1) + { + node* temp = next; + unsigned long low = 0; + while (temp) + { + if (temp->symbol == static_cast(symbol)) + { + high_count = temp->count + low; + low_count = low; + total_count = total; + return temp->count; + } + low += temp->count; + temp = temp->next; + } + return 0; + } + else + { + total_count = total; + high_count = total; + low_count = total-escapes; + return escapes; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + node* temp = next; + unsigned long high = 0; + while (true) + { + if (temp != 0) + { + high += temp->count; + if (target < high) + { + symbol = temp->symbol; + high_count = high; + low_count = high - temp->count; + return; + } + temp = temp->next; + } + else + { + // this must be the escape symbol + symbol = alphabet_size-1; + low_count = total-escapes; + high_count = total; + return; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + half_counts ( + ) + { + total = 0; + if (escapes > 1) + escapes >>= 1; + + //divide all counts by 2 + node* temp = next; + while (temp) + { + temp->count >>= 1; + total += temp->count; + temp = temp->next; + } + total += escapes; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_4_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_abstract.h b/dlib/conditioning_class/conditioning_class_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..411aea56653a471dbc9226a718aa061e4078f78d --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_abstract.h @@ -0,0 +1,228 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ +#ifdef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + /*! + REQUIREMENTS ON alphabet_size + 1 < alphabet_size < 65536 + + INITIAL VALUE + get_total() == 1 + get_count(X) == 0 : for all valid values of X except alphabet_size-1 + get_count(alphabet_size-1) == 1 + + WHAT THIS OBJECT REPRESENTS + This object represents a conditioning class used for arithmetic style + compression. It maintains the cumulative counts which are needed + by the entropy_coder and entropy_decoder objects. + + At any moment a conditioning_class object represents a set of + alphabet_size symbols. Each symbol is associated with an integer + called its count. + + All symbols start out with a count of zero except for alphabet_size-1. + This last symbol will always have a count of at least one. It is + intended to be used as an escape into a lower context when coding + and so it must never have a zero probability or the decoder won't + be able to identify the escape symbol. + + NOTATION: + Let MAP(i) be a function which maps integers to symbols. MAP(i) is + one to one and onto. Its domain is 1 to alphabet_size inclusive. + + Let RMAP(s) be the inverse of MAP(i). + ( i.e. RMAP(MAP(i)) == i and MAP(RMAP(s)) == s ) + + Let COUNT(i) give the count for the symbol MAP(i). + ( i.e. COUNT(i) == get_count(MAP(i)) ) + + + Let LOW_COUNT(s) == the sum of COUNT(x) for x == 1 to x == RMAP(s)-1 + (note that the sum of COUNT(x) for x == 1 to x == 0 is 0) + Let HIGH_COUNT(s) == LOW_COUNT(s) + get_count(s) + + + + Basically what this is saying is just that you shoudln't assume you know + what order the symbols are placed in when calculating the cumulative + sums. The specific mapping provided by the MAP() function is unspecified. + + THREAD SAFETY + This object can be used safely in a multithreaded program as long as the + global state is not shared between conditioning classes which run on + different threads. + + GLOBAL_STATE_TYPE + The global_state_type obejct allows instances of the conditioning_class + object to share any kind of global state the implementer desires. + However, the global_state_type object exists primarily to facilitate the + sharing of a memory pool between many instances of a conditioning_class + object. But note that it is not required that there be any kind of + memory pool at all, it is just a possibility. + !*/ + + public: + + class global_state_type + { + global_state_type ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + // my contents are implementation specific. + }; + + conditioning_class ( + global_state_type& global_state + ); + /*! + ensures + - #*this is properly initialized + - &#get_global_state() == &global_state + throws + - std::bad_alloc + !*/ + + ~conditioning_class ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + !*/ + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + /*! + requires + - 0 <= symbol < alphabet_size + - 0 < amount < 32768 + ensures + - if (sufficient memory is available to complete this operation) then + - returns true + - if (get_total()+amount < 65536) then + - #get_count(symbol) == get_count(symbol) + amount + - else + - #get_count(symbol) == get_count(symbol)/2 + amount + - if (get_count(alphabet_size-1) == 1) then + - #get_count(alphabet_size-1) == 1 + - else + - #get_count(alphabet_size-1) == get_count(alphabet_size-1)/2 + - for all X where (X != symbol)&&(X != alpahbet_size-1): + #get_count(X) == get_count(X)/2 + - else + - returns false + !*/ + + unsigned long get_count ( + unsigned long symbol + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns the count for the specified symbol + !*/ + + unsigned long get_total ( + ) const; + /*! + ensures + - returns the sum of get_count(X) for all valid values of X + (i.e. returns the sum of the counts for all the symbols) + !*/ + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns get_count(symbol) + - if (get_count(symbol) != 0) then + - #total_count == get_total() + - #low_count == LOW_COUNT(symbol) + - #high_count == HIGH_COUNT(symbol) + - #low_count < #high_count <= #total_count + !*/ + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + /*! + requires + - 0 <= target < get_total() + ensures + - LOW_COUNT(#symbol) <= target < HIGH_COUNT(#symbol) + - #low_count == LOW_COUNT(#symbol) + - #high_count == HIGH_COUNT(#symbol) + - #low_count < #high_count <= get_total() + !*/ + + global_state_type& get_global_state ( + ); + /*! + ensures + - returns a reference to the global state used by *this + !*/ + + unsigned long get_memory_usage ( + ) const; + /*! + ensures + - returns the number of bytes of memory allocated by all conditioning_class + objects that share the global state given by get_global_state() + !*/ + + static unsigned long get_alphabet_size ( + ); + /*! + ensures + - returns alphabet_size + !*/ + + private: + + // restricted functions + conditioning_class(conditioning_class&); // copy constructor + conditioning_class& operator=(conditioning_class&); // assignment operator + + }; + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_c.h b/dlib/conditioning_class/conditioning_class_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..964240be862a3bc5b5ed769bd1cf0ce9f4288700 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_c.h @@ -0,0 +1,162 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_C_ +#define DLIB_CONDITIONING_CLASS_KERNEl_C_ + +#include "conditioning_class_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename cc_base + > + class conditioning_class_kernel_c : public cc_base + { + const unsigned long alphabet_size; + + public: + + conditioning_class_kernel_c ( + typename cc_base::global_state_type& global_state + ) : cc_base(global_state),alphabet_size(cc_base::get_alphabet_size()) {} + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + bool conditioning_class_kernel_c:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size && + 0 < amount && amount < 32768, + "\tvoid conditioning_class::increment_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1. and" + << "\n\tamount must be in the range 1 to 32767" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tamount: " << amount + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::increment_count(symbol,amount); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_count ( + unsigned long symbol + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_count(symbol); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_range()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_range(symbol,low_count,high_count,total_count); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + void conditioning_class_kernel_c:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( target < this->get_total(), + "\tvoid conditioning_class::get_symbol()" + << "\n\tthe target must be in the range 0 to get_total()-1" + << "\n\tget_total(): " << this->get_total() + << "\n\ttarget: " << target + << "\n\tthis: " << this + ); + + // call the real function + cc_base::get_symbol(target,symbol,low_count,high_count); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_C_ + diff --git a/dlib/config.h b/dlib/config.h new file mode 100644 index 0000000000000000000000000000000000000000..a48ac415f6f08f6841e99e0003a3d8babaea9674 --- /dev/null +++ b/dlib/config.h @@ -0,0 +1,31 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +//#define ENABLE_ASSERTS // asserts always enabled +//#define DLIB_DISABLE_ASSERTS // asserts always disabled + +//#define DLIB_ISO_CPP_ONLY +//#define DLIB_NO_GUI_SUPPORT +//#define DLIB_ENABLE_STACK_TRACE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +// #define DLIB_JPEG_SUPPORT +// #define DLIB_PNG_SUPPORT +// #define DLIB_GIF_SUPPORT +// #define DLIB_USE_FFTW +// #define DLIB_USE_BLAS +// #define DLIB_USE_LAPACK +// #define DLIB_USE_CUDA + + +// Define this so the code in dlib/test_for_odr_violations.h can detect ODR violations +// related to users doing bad things with config.h +#define DLIB_NOT_CONFIGURED + diff --git a/dlib/config.h.in b/dlib/config.h.in new file mode 100644 index 0000000000000000000000000000000000000000..abaa655c1ff62b9dc4d6369c04d55b8b415619ff --- /dev/null +++ b/dlib/config.h.in @@ -0,0 +1,36 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +#cmakedefine ENABLE_ASSERTS // asserts always enabled +#cmakedefine DLIB_DISABLE_ASSERTS // asserts always disabled + +#cmakedefine DLIB_ISO_CPP_ONLY +#cmakedefine DLIB_NO_GUI_SUPPORT +#cmakedefine DLIB_ENABLE_STACK_TRACE + +#cmakedefine LAPACK_FORCE_UNDERSCORE +#cmakedefine LAPACK_FORCE_NOUNDERSCORE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +#cmakedefine DLIB_JPEG_SUPPORT +#cmakedefine DLIB_WEBP_SUPPORT +#cmakedefine DLIB_PNG_SUPPORT +#cmakedefine DLIB_GIF_SUPPORT +#cmakedefine DLIB_USE_FFTW +#cmakedefine DLIB_USE_BLAS +#cmakedefine DLIB_USE_LAPACK +#cmakedefine DLIB_USE_CUDA +#cmakedefine DLIB_USE_MKL_FFT +#cmakedefine DLIB_USE_FFMPEG + +// This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use +// headers from one version of dlib with a compiled dlib binary from a different dlib version. +#cmakedefine DLIB_CHECK_FOR_VERSION_MISMATCH @DLIB_CHECK_FOR_VERSION_MISMATCH@ + diff --git a/dlib/config_reader.h b/dlib/config_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..d140a310ce2ddb9e7ace1c3a7abaf18d80ffb6e8 --- /dev/null +++ b/dlib/config_reader.h @@ -0,0 +1,39 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READEr_ +#define DLIB_CONFIG_READEr_ + +#include "config_reader/config_reader_kernel_1.h" +#include "map.h" +#include "tokenizer.h" +#include "cmd_line_parser/get_option.h" + +#include "algs.h" +#include "is_kind.h" + + +namespace dlib +{ + + typedef config_reader_kernel_1< + map::kernel_1b, + map::kernel_1b, + tokenizer::kernel_1a + > config_reader; + + template <> struct is_config_reader { const static bool value = true; }; + +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader, + map::kernel_1b + > config_reader_thread_safe; + + template <> struct is_config_reader { const static bool value = true; }; +#endif // DLIB_ISO_CPP_ONLY + + +} + +#endif // DLIB_CONFIG_READEr_ + diff --git a/dlib/config_reader/config_reader_kernel_1.h b/dlib/config_reader/config_reader_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f9e5a7110eb5743183642964d8849bda3d4b53 --- /dev/null +++ b/dlib/config_reader/config_reader_kernel_1.h @@ -0,0 +1,738 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_KERNEl_1_ +#define DLIB_CONFIG_READER_KERNEl_1_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include +#include "../algs.h" +#include "../stl_checked/std_vector_c.h" + +#ifndef DLIB_ISO_CPP_ONLY +#include "config_reader_thread_safe_1.h" +#endif + +namespace dlib +{ + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + class config_reader_kernel_1 + { + + /*! + REQUIREMENTS ON map_string_string + is an implementation of map/map_kernel_abstract.h that maps std::string to std::string + + REQUIREMENTS ON map_string_void + is an implementation of map/map_kernel_abstract.h that maps std::string to void* + + REQUIREMENTS ON tokenizer + is an implementation of tokenizer/tokenizer_kernel_abstract.h + + CONVENTION + key_table.is_in_domain(x) == is_key_defined(x) + block_table.is_in_domain(x) == is_block_defined(x) + + key_table[x] == operator[](x) + block_table[x] == (void*)&block(x) + !*/ + + public: + + // These two typedefs are defined for backwards compatibility with older versions of dlib. + typedef config_reader_kernel_1 kernel_1a; +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader_kernel_1, + map_string_void + > thread_safe_1a; +#endif // DLIB_ISO_CPP_ONLY + + + config_reader_kernel_1(); + + class config_reader_error : public dlib::error + { + friend class config_reader_kernel_1; + config_reader_error( + unsigned long ln, + bool r = false + ) : + dlib::error(ECONFIG_READER), + line_number(ln), + redefinition(r) + { + std::ostringstream sout; + sout << "Error in config_reader while parsing at line number " << line_number << "."; + if (redefinition) + sout << "\nThe identifier on this line has already been defined in this scope."; + const_cast(info) = sout.str(); + } + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + friend class config_reader_kernel_1; + file_not_found( + const std::string& file_name_ + ) : + dlib::error(ECONFIG_READER, "Error in config_reader, unable to open file " + file_name_), + file_name(file_name_) + {} + + ~file_not_found() throw() {} + + public: + const std::string file_name; + }; + + class config_reader_access_error : public dlib::error + { + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ) : + dlib::error(ECONFIG_READER), + block_name(block_name_), + key_name(key_name_) + { + std::ostringstream sout; + sout << "Error in config_reader.\n"; + if (block_name.size() > 0) + sout << " A block with the name '" << block_name << "' was expected but not found."; + else if (key_name.size() > 0) + sout << " A key with the name '" << key_name << "' was expected but not found."; + + const_cast(info) = sout.str(); + } + + ~config_reader_access_error() throw() {} + const std::string block_name; + const std::string key_name; + }; + + config_reader_kernel_1( + const std::string& config_file + ); + + config_reader_kernel_1( + std::istream& in + ); + + virtual ~config_reader_kernel_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_kernel_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std::vector& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std_vector_c& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std::vector& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std_vector_c& blocks + ) const; + + private: + + static void parse_config_file ( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion = true + ); + /*! + requires + - line_number == 1 + - cr == *this + - top_of_recursion == true + ensures + - parses the data coming from tok and puts it into cr. + throws + - config_reader_error + !*/ + + map_string_string key_table; + map_string_void block_table; + + // restricted functions + config_reader_kernel_1(config_reader_kernel_1&); + config_reader_kernel_1& operator=(config_reader_kernel_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + clear( + ) + { + // free all our blocks + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + key_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + std::istream& in + ) + { + clear(); + + tokenizer tok; + tok.set_stream(in); + tok.set_identifier_token( + tok.lowercase_letters() + tok.uppercase_letters(), + tok.lowercase_letters() + tok.uppercase_letters() + tok.numbers() + "_-." + ); + + unsigned long line_number = 1; + try + { + parse_config_file(*this,tok,line_number); + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + const std::string& config_file + ) + { + clear(); + std::ifstream fin(config_file.c_str()); + if (!fin) + throw file_not_found(config_file); + + load_from(fin); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + std::istream& in + ) + { + load_from(in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + const std::string& config_file + ) + { + load_from(config_file); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + parse_config_file( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion + ) + { + int type; + std::string token; + bool in_comment = false; + bool seen_identifier = false; + std::string identifier; + while (true) + { + tok.get_token(type,token); + // ignore white space + if (type == tokenizer::WHITE_SPACE) + continue; + + // basically ignore end of lines + if (type == tokenizer::END_OF_LINE) + { + ++line_number; + in_comment = false; + continue; + } + + // we are in a comment still so ignore this + if (in_comment) + continue; + + // if this is the start of a comment + if (type == tokenizer::CHAR && token[0] == '#') + { + in_comment = true; + continue; + } + + // if this is the case then we have just finished parsing a block so we should + // quit this function + if ( (type == tokenizer::CHAR && token[0] == '}' && !top_of_recursion) || + (type == tokenizer::END_OF_FILE && top_of_recursion) ) + { + break; + } + + if (seen_identifier) + { + seen_identifier = false; + // the next character should be either a '=' or a '{' + if (type != tokenizer::CHAR || (token[0] != '=' && token[0] != '{')) + throw config_reader_error(line_number); + + if (token[0] == '=') + { + // we should parse the value out now + // first discard any white space + if (tok.peek_type() == tokenizer::WHITE_SPACE) + tok.get_token(type,token); + + std::string value; + type = tok.peek_type(); + token = tok.peek_token(); + while (true) + { + if (type == tokenizer::END_OF_FILE || type == tokenizer::END_OF_LINE) + break; + + if (type == tokenizer::CHAR && token[0] == '\\') + { + tok.get_token(type,token); + if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '#') + { + tok.get_token(type,token); + value += '#'; + } + else if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '}') + { + tok.get_token(type,token); + value += '}'; + } + else + { + value += '\\'; + } + } + else if (type == tokenizer::CHAR && + (token[0] == '#' || token[0] == '}')) + { + break; + } + else + { + value += token; + tok.get_token(type,token); + } + type = tok.peek_type(); + token = tok.peek_token(); + } // while(true) + + // strip of any tailing white space from value + std::string::size_type pos = value.find_last_not_of(" \t\r\n"); + if (pos == std::string::npos) + value.clear(); + else + value.erase(pos+1); + + // make sure this key isn't already in the key_table + if (cr.key_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + // add this key/value pair to the key_table + cr.key_table.add(identifier,value); + + } + else // when token[0] == '{' + { + // make sure this identifier isn't already in the block_table + if (cr.block_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + config_reader_kernel_1* new_cr = new config_reader_kernel_1; + void* vtemp = new_cr; + try { cr.block_table.add(identifier,vtemp); } + catch (...) { delete new_cr; throw; } + + // now parse this block + parse_config_file(*new_cr,tok,line_number,false); + } + } + else + { + // the next thing should be an identifier but if it isn't this is an error + if (type != tokenizer::IDENTIFIER) + throw config_reader_error(line_number); + + seen_identifier = true; + identifier = token; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + ~config_reader_kernel_1( + ) + { + clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_key_defined ( + const std::string& key + ) const + { + return key_table.is_in_domain(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_block_defined ( + const std::string& name + ) const + { + return block_table.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mss, + typename msv, + typename tokenizer + > + const config_reader_kernel_1& config_reader_kernel_1:: + block ( + const std::string& name + ) const + { + if (is_block_defined(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + const std::string& config_reader_kernel_1:: + operator[] ( + const std::string& key + ) const + { + if (is_key_defined(key) == false) + { + throw config_reader_access_error("",key); + } + + return key_table[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_keys ( + queue_of_strings& keys + ) const + { + keys.clear(); + key_table.reset(); + std::string temp; + while (key_table.move_next()) + { + temp = key_table.element().key(); + keys.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std::vector& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std_vector_c& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + blocks.clear(); + block_table.reset(); + std::string temp; + while (block_table.move_next()) + { + temp = block_table.element().key(); + blocks.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std::vector& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std_vector_c& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_KERNEl_1_ + diff --git a/dlib/config_reader/config_reader_kernel_abstract.h b/dlib/config_reader/config_reader_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..e8c44c2b2f965249b382e2a3708096f69f4bc68a --- /dev/null +++ b/dlib/config_reader/config_reader_kernel_abstract.h @@ -0,0 +1,363 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class config_reader + { + + /*! + INITIAL VALUE + - there aren't any keys defined for this object + - there aren't any blocks defined for this object + + POINTERS AND REFERENCES TO INTERNAL DATA + The destructor, clear(), and load_from() invalidate pointers + and references to internal data. All other functions are guaranteed + to NOT invalidate pointers or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents something which is intended to be used to read + text configuration files that are defined by the following EBNF (with + config_file as the starting symbol): + + config_file = block; + block = { key_value_pair | sub_block }; + key_value_pair = key_name, "=", value; + sub_block = block_name, "{", block, "}"; + + key_name = identifier; + block_name = identifier; + value = matches any string of text that ends with a newline character, # or }. + note that the trailing newline, # or } is not part of the value though. + identifier = Any string that matches the following regular expression: + [a-zA-Z][a-zA-Z0-9_-\.]* + i.e. Any string that starts with a letter and then is continued + with any number of letters, numbers, _ . or - characters. + + Whitespace and comments are ignored. A comment is text that starts with # (but not \# + since the \ escapes the # so that you can have a # symbol in a value if you want) and + ends in a new line. You can also escape a } (e.g. "\}") if you want to have one in a + value. + + Note that in a value the leading and trailing white spaces are stripped off but any + white space inside the value is preserved. + + Also note that all key_names and block_names within a block syntax group must be unique + but don't have to be globally unique. I.e. different blocks can reuse names. + + EXAMPLE CONFIG FILES: + + Example 1: + #comment. This line is ignored because it starts with # + + #here we have key1 which will have the value of "my value" + key1 = my value + + another_key= another value # this is another key called "another_key" with + # a value of "another value" + + # this key's value is the empty string. I.e. "" + key2= + + Example 2: + #this example illustrates the use of blocks + some_key = blah blah + + # now here is a block + our_block + { + # here we can define some keys and values that are local to this block. + a_key = something + foo = bar + some_key = more stuff # note that it is ok to name our key this even though + # there is a key called some_key above. This is because + # we are doing so inside a different block + } + + another_block { foo = bar2 } # this block has only one key and is all on a single line + !*/ + + public: + + // exception classes + class config_reader_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error while parsing the + config file. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - line_number == the line number the parser was at when the + error occurred. + - if (redefinition) then + - The key or block name on line line_number has already + been defined in this scope which is an error. + - else + - Some other general syntax error was detected + !*/ + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + /*! + GENERAL + This exception is thrown if the config file can't be opened for + some reason. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - file_name == the name of the config file which we failed to open + !*/ + public: + const std::string file_name; + }; + + + class config_reader_access_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if you try to access a key or + block that doesn't exist inside a config reader. The type + member of this exception will be set to ECONFIG_READER. + !*/ + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ); + /*! + ensures + - #block_name == block_name_ + - #key_name == key_name_ + !*/ + + const std::string block_name; + const std::string key_name; + }; + + // -------------------------- + + config_reader( + ); + /*! + ensures + - #*this is properly initialized + - This object will not have any keys or blocks defined in it. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + std::istream& in + ); + /*! + ensures + - #*this is properly initialized + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + const std::string& config_file + ); + /*! + ensures + - #*this is properly initialized + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + - file_not_found + !*/ + + virtual ~config_reader( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void load_from ( + std::istream& in + ); + /*! + ensures + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - *this will represent the top most block of the config file contained + in the input stream in. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + !*/ + + void load_from ( + const std::string& config_file + ); + /*! + ensures + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + - file_not_found + If this exception is thrown then this object will + revert to its initial value. + !*/ + + bool is_key_defined ( + const std::string& key_name + ) const; + /*! + ensures + - if (there is a key with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + bool is_block_defined ( + const std::string& block_name + ) const; + /*! + ensures + - if (there is a sub block with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + typedef config_reader this_type; + const this_type& block ( + const std::string& block_name + ) const; + /*! + ensures + - if (is_block_defined(block_name) == true) then + - returns a const reference to the config_reader that represents the given named sub block + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its block_name field will be set to the + given block_name string. + !*/ + + const std::string& operator[] ( + const std::string& key_name + ) const; + /*! + ensures + - if (is_key_defined(key_name) == true) then + - returns a const reference to the value string associated with the given key in + this config_reader's block. + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its key_name field will be set to the + given key_name string. + !*/ + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #keys == a collection containing all the keys defined in this config_reader's block. + (i.e. for all strings str in keys it is the case that is_key_defined(str) == true) + !*/ + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #blocks == a collection containing the names of all the blocks defined in this + config_reader's block. + (i.e. for all strings str in blocks it is the case that is_block_defined(str) == true) + !*/ + + private: + + // restricted functions + config_reader(config_reader&); // copy constructor + config_reader& operator=(config_reader&); // assignment operator + + }; + +} + +#endif // DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + diff --git a/dlib/config_reader/config_reader_thread_safe_1.h b/dlib/config_reader/config_reader_thread_safe_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1ad250c99a4a0c1e65d0790c2432a727543f1a3a --- /dev/null +++ b/dlib/config_reader/config_reader_thread_safe_1.h @@ -0,0 +1,456 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_THREAD_SAFe_ +#define DLIB_CONFIG_READER_THREAD_SAFe_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../threads.h" +#include "config_reader_thread_safe_abstract.h" + +namespace dlib +{ + + template < + typename config_reader_base, + typename map_string_void + > + class config_reader_thread_safe_1 + { + + /*! + CONVENTION + - get_mutex() == *m + - *cr == the config reader being extended + - block_table[x] == (void*)&block(x) + - block_table.size() == the number of blocks in *cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + - if (own_pointers) then + - this object owns the m and cr pointers and should delete them when destructed + !*/ + + public: + + config_reader_thread_safe_1 ( + const config_reader_base* base, + rmutex* m_ + ); + + config_reader_thread_safe_1(); + + typedef typename config_reader_base::config_reader_error config_reader_error; + typedef typename config_reader_base::config_reader_access_error config_reader_access_error; + + config_reader_thread_safe_1( + std::istream& in + ); + + config_reader_thread_safe_1( + const std::string& config_file + ); + + virtual ~config_reader_thread_safe_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_thread_safe_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + inline const rmutex& get_mutex ( + ) const; + + private: + + void fill_block_table ( + ); + /*! + ensures + - block_table.size() == the number of blocks in cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + !*/ + + rmutex* m; + config_reader_base* cr; + map_string_void block_table; + const bool own_pointers; + + // restricted functions + config_reader_thread_safe_1(config_reader_thread_safe_1&); + config_reader_thread_safe_1& operator=(config_reader_thread_safe_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const config_reader_base* base, + rmutex* m_ + ) : + m(m_), + cr(const_cast(base)), + own_pointers(false) + { + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base; + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + clear( + ) + { + auto_mutex M(*m); + cr->clear(); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + std::istream& in + ) + { + auto_mutex M(*m); + cr->load_from(in); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + const std::string& config_file + ) + { + auto_mutex M(*m); + cr->load_from(config_file); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + std::istream& in + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(in); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const std::string& config_file + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(config_file); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + ~config_reader_thread_safe_1( + ) + { + if (own_pointers) + { + delete m; + delete cr; + } + + // clear out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_key_defined ( + const std::string& key + ) const + { + auto_mutex M(*m); + return cr->is_key_defined(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_block_defined ( + const std::string& name + ) const + { + auto_mutex M(*m); + return cr->is_block_defined(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const config_reader_thread_safe_1& config_reader_thread_safe_1:: + block ( + const std::string& name + ) const + { + auto_mutex M(*m); + if (block_table.is_in_domain(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const std::string& config_reader_thread_safe_1:: + operator[] ( + const std::string& key + ) const + { + auto_mutex M(*m); + return (*cr)[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_keys ( + queue_of_strings& keys + ) const + { + auto_mutex M(*m); + cr->get_keys(keys); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + auto_mutex M(*m); + cr->get_blocks(blocks); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const rmutex& config_reader_thread_safe_1:: + get_mutex ( + ) const + { + return *m; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// private member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + fill_block_table ( + ) + { + using namespace std; + // first empty out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + + std::vector blocks; + cr->get_blocks(blocks); + + // now fill the block table up to match what is in cr + for (unsigned long i = 0; i < blocks.size(); ++i) + { + config_reader_thread_safe_1* block = new config_reader_thread_safe_1(&cr->block(blocks[i]),m); + void* temp = block; + block_table.add(blocks[i],temp); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ + + diff --git a/dlib/config_reader/config_reader_thread_safe_abstract.h b/dlib/config_reader/config_reader_thread_safe_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..25bcbae4a9c2dd009f3772c6b4879a2620cf9433 --- /dev/null +++ b/dlib/config_reader/config_reader_thread_safe_abstract.h @@ -0,0 +1,45 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + +#include +#include +#include "config_reader_kernel_abstract.h" +#include "../threads/threads_kernel_abstract.h" + +namespace dlib +{ + + class config_reader_thread_safe + { + + /*! + WHAT THIS EXTENSION DOES FOR config_reader + This object extends a normal config_reader by simply wrapping all + its member functions inside mutex locks to make it safe to use + in a threaded program. + + So this object provides an interface identical to the one defined + in the config_reader/config_reader_kernel_abstract.h file except that + the rmutex returned by get_mutex() is always locked when this + object's member functions are called. + !*/ + + public: + + const rmutex& get_mutex ( + ) const; + /*! + ensures + - returns the rmutex used to make this object thread safe. i.e. returns + the rmutex that is locked when this object's functions are called. + !*/ + + }; + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + + diff --git a/dlib/console_progress_indicator.h b/dlib/console_progress_indicator.h new file mode 100644 index 0000000000000000000000000000000000000000..42e6fa679070e2eaf799a878ef207fd6b4538f63 --- /dev/null +++ b/dlib/console_progress_indicator.h @@ -0,0 +1,256 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ +#define DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + +#include +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class console_progress_indicator + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for reporting how long a task will take + to complete. + + For example, consider the following bit of code: + + console_progress_indicator pbar(100) + for (int i = 1; i <= 100; ++i) + { + pbar.print_status(i); + long_running_operation(); + } + + The above code will print a message to the console each iteration + which shows the current progress and how much time is remaining until + the loop terminates. + !*/ + + public: + + inline explicit console_progress_indicator ( + double target_value + ); + /*! + ensures + - #target() == target_value + !*/ + + inline void reset ( + double target_value + ); + /*! + ensures + - #target() == target_value + - performs the equivalent of: + *this = console_progress_indicator(target_value) + (i.e. resets this object with a new target value) + + !*/ + + inline double target ( + ) const; + /*! + ensures + - This object attempts to measure how much time is + left until we reach a certain targeted value. This + function returns that targeted value. + !*/ + + inline bool print_status ( + double cur, + bool always_print = false, + std::ostream& out = std::clog + ); + /*! + ensures + - print_status() assumes it is called with values which are linearly + approaching target(). It will display the current progress and attempt + to predict how much time is remaining until cur becomes equal to target(). + - prints a status message to out which indicates how much more time is + left until cur is equal to target() + - if (always_print) then + - This function prints to the screen each time it is called. + - else + - This function throttles the printing so that at most 1 message is + printed each second. Note that it won't print anything to the screen + until about one second has elapsed. This means that the first call + to print_status() never prints to the screen. + - This function returns true if it prints to the screen and false + otherwise. + !*/ + + inline void finish ( + std::ostream& out = std::cout + ) const; + /*! + ensures + - This object prints the completed progress and the elapsed time to out. + It is meant to be called after the loop we are tracking the progress of. + !*/ + + private: + + double target_val; + std::chrono::time_point start_time; + double first_val; + double seen_first_val; + std::chrono::time_point last_time; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + console_progress_indicator:: + console_progress_indicator ( + double target_value + ) : + target_val(target_value), + start_time(std::chrono::steady_clock::now()), + first_val(0), + seen_first_val(false), + last_time(std::chrono::steady_clock::now()) + { + } + +// ---------------------------------------------------------------------------------------- + + bool console_progress_indicator:: + print_status ( + double cur, + bool always_print, + std::ostream& out + ) + { + const auto cur_time = std::chrono::steady_clock::now(); + + // if this is the first time print_status has been called + // then collect some information and exit. We will print status + // on the next call. + if (!seen_first_val) + { + start_time = cur_time; + last_time = cur_time; + first_val = cur; + seen_first_val = true; + return false; + } + + if ((cur_time - last_time) >= std::chrono::seconds(1) || always_print) + { + last_time = cur_time; + const auto delta_t = cur_time - start_time; + double delta_val = std::abs(cur - first_val); + + // don't do anything if cur is equal to first_val + if (delta_val < std::numeric_limits::epsilon()) + return false; + + const auto rem_time = delta_t / delta_val * std::abs(target_val - cur); + + const auto oldflags = out.flags(); + out.setf(std::ios::fixed,std::ios::floatfield); + std::streamsize ss; + + // adapt the precision based on whether the target val is an integer + if (std::trunc(target_val) == target_val) + ss = out.precision(0); + else + ss = out.precision(2); + + out << "Progress: " << cur << "/" << target_val; + ss = out.precision(2); + out << " (" << cur / target_val * 100. << "%). "; + + const auto hours = std::chrono::duration_cast(rem_time); + const auto minutes = std::chrono::duration_cast(rem_time) - hours; + const auto seconds = std::chrono::duration_cast(rem_time) - hours - minutes; + out << "Time remaining: "; + if (rem_time >= std::chrono::hours(1)) + out << hours.count() << "h "; + if (rem_time >= std::chrono::minutes(1)) + out << minutes.count() << "min "; + out << seconds.count() << "s. \r" << std::flush; + + // restore previous output flags and precision settings + out.flags(oldflags); + out.precision(ss); + + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + double console_progress_indicator:: + target ( + ) const + { + return target_val; + } + +// ---------------------------------------------------------------------------------------- + + void console_progress_indicator:: + reset ( + double target_value + ) + { + *this = console_progress_indicator(target_value); + } + +// ---------------------------------------------------------------------------------------- + + void console_progress_indicator:: + finish ( + std::ostream& out + ) const + { + const auto oldflags = out.flags(); + out.setf(std::ios::fixed,std::ios::floatfield); + std::streamsize ss; + + // adapt the precision based on whether the target val is an integer + if (std::trunc(target_val) == target_val) + ss = out.precision(0); + else + ss = out.precision(2); + + out << "Progress: " << target_val << "/" << target_val; + out << " (100.00%). "; + const auto delta_t = std::chrono::steady_clock::now() - start_time; + const auto hours = std::chrono::duration_cast(delta_t); + const auto minutes = std::chrono::duration_cast(delta_t) - hours; + const auto seconds = std::chrono::duration_cast(delta_t) - hours - minutes; + out << "Time elapsed: "; + if (delta_t >= std::chrono::hours(1)) + out << hours.count() << "h "; + if (delta_t >= std::chrono::minutes(1)) + out << minutes.count() << "min "; + out << seconds.count() << "s. " << std::endl; + + // restore previous output flags and precision settings + out.flags(oldflags); + out.precision(ss); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + diff --git a/dlib/constexpr_if.h b/dlib/constexpr_if.h new file mode 100644 index 0000000000000000000000000000000000000000..2c22ddf5e3e05b9f2c621a50d20d8092b2be3eae --- /dev/null +++ b/dlib/constexpr_if.h @@ -0,0 +1,136 @@ +// Copyright (C) 2022 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IF_CONSTEXPR_H +#define DLIB_IF_CONSTEXPR_H + +#include "overloaded.h" +#include "type_traits.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + namespace detail + { + const auto _ = [](auto&& arg) -> decltype(auto) { return std::forward(arg); }; + + template class Op, class... Args> + struct is_detected : std::false_type{}; + + template